diff --git a/frontend/src/components/EmbeddingSelection/MistralAiOptions/index.jsx b/frontend/src/components/EmbeddingSelection/MistralAiOptions/index.jsx new file mode 100644 index 000000000..6012b3192 --- /dev/null +++ b/frontend/src/components/EmbeddingSelection/MistralAiOptions/index.jsx @@ -0,0 +1,46 @@ +export default function MistralAiOptions({ settings }) { + return ( + <div className="w-full flex flex-col gap-y-4"> + <div className="w-full flex items-center gap-[36px] mt-1.5"> + <div className="flex flex-col w-60"> + <label className="text-white text-sm font-semibold block mb-3"> + API Key + </label> + <input + type="password" + name="MistralAiApiKey" + className="bg-theme-settings-input-bg text-white placeholder:text-theme-settings-input-placeholder text-sm rounded-lg focus:outline-primary-button active:outline-primary-button outline-none block w-full p-2.5" + placeholder="Mistral AI API Key" + defaultValue={settings?.MistralApiKey ? "*".repeat(20) : ""} + required={true} + autoComplete="off" + spellCheck={false} + /> + </div> + <div className="flex flex-col w-60"> + <label className="text-white text-sm font-semibold block mb-3"> + Model Preference + </label> + <select + name="EmbeddingModelPref" + required={true} + defaultValue={settings?.EmbeddingModelPref} + className="bg-theme-settings-input-bg border-gray-500 text-white text-sm rounded-lg block w-full p-2.5" + > + <optgroup label="Available embedding models"> + {[ + "mistral-embed", + ].map((model) => { + return ( + <option key={model} value={model}> + {model} + </option> + ); + })} + </optgroup> + </select> + </div> + </div> + </div> + ); +} diff --git a/frontend/src/pages/GeneralSettings/EmbeddingPreference/index.jsx b/frontend/src/pages/GeneralSettings/EmbeddingPreference/index.jsx index 4f3dd8ef5..f2f3884b4 100644 --- a/frontend/src/pages/GeneralSettings/EmbeddingPreference/index.jsx +++ b/frontend/src/pages/GeneralSettings/EmbeddingPreference/index.jsx @@ -13,6 +13,7 @@ import CohereLogo from "@/media/llmprovider/cohere.png"; import VoyageAiLogo from "@/media/embeddingprovider/voyageai.png"; import LiteLLMLogo from "@/media/llmprovider/litellm.png"; import GenericOpenAiLogo from "@/media/llmprovider/generic-openai.png"; +import MistralAiLogo from "@/media/llmprovider/mistral.jpeg"; import PreLoader from "@/components/Preloader"; import ChangeWarningModal from "@/components/ChangeWarning"; @@ -33,6 +34,7 @@ import { useModal } from "@/hooks/useModal"; import ModalWrapper from "@/components/ModalWrapper"; import CTAButton from "@/components/lib/CTAButton"; import { useTranslation } from "react-i18next"; +import MistralAiOptions from "@/components/EmbeddingSelection/MistralAiOptions"; const EMBEDDERS = [ { @@ -100,6 +102,13 @@ const EMBEDDERS = [ options: (settings) => <LiteLLMOptions settings={settings} />, description: "Run powerful embedding models from LiteLLM.", }, + { + name: "Mistral AI", + value: "mistral", + logo: MistralAiLogo, + options: (settings) => <MistralAiOptions settings={settings} />, + description: "Run powerful embedding models from Mistral AI.", + }, { name: "Generic OpenAI", value: "generic-openai", diff --git a/frontend/src/pages/OnboardingFlow/Steps/DataHandling/index.jsx b/frontend/src/pages/OnboardingFlow/Steps/DataHandling/index.jsx index 44fbaed66..ab83a5af2 100644 --- a/frontend/src/pages/OnboardingFlow/Steps/DataHandling/index.jsx +++ b/frontend/src/pages/OnboardingFlow/Steps/DataHandling/index.jsx @@ -349,6 +349,13 @@ export const EMBEDDING_ENGINE_PRIVACY = { ], logo: VoyageAiLogo, }, + mistral: { + name: "Mistral AI", + description: [ + "Data sent to Mistral AI's servers is shared according to the terms of service of https://mistral.ai.", + ], + logo: MistralLogo, + }, litellm: { name: "LiteLLM", description: [ diff --git a/server/utils/EmbeddingEngines/mistral/index.js b/server/utils/EmbeddingEngines/mistral/index.js new file mode 100644 index 000000000..1d4f73514 --- /dev/null +++ b/server/utils/EmbeddingEngines/mistral/index.js @@ -0,0 +1,43 @@ +class MistralEmbedder { + constructor() { + if (!process.env.MISTRAL_API_KEY) + throw new Error("No Mistral API key was set."); + + const { OpenAI: OpenAIApi } = require("openai"); + this.openai = new OpenAIApi({ + baseURL: "https://api.mistral.ai/v1", + apiKey: process.env.MISTRAL_API_KEY ?? null, + }); + this.model = process.env.EMBEDDING_MODEL_PREF || "mistral-embed"; + } + + async embedTextInput(textInput) { + try { + const response = await this.openai.embeddings.create({ + model: this.model, + input: textInput, + }); + return response?.data[0]?.embedding || []; + } catch (error) { + console.error("Failed to get embedding from Mistral.", error.message); + return []; + } + } + + async embedChunks(textChunks = []) { + try { + const response = await this.openai.embeddings.create({ + model: this.model, + input: textChunks, + }); + return response?.data?.map((emb) => emb.embedding) || []; + } catch (error) { + console.error("Failed to get embeddings from Mistral.", error.message); + return new Array(textChunks.length).fill([]); + } + } +} + +module.exports = { + MistralEmbedder, +}; \ No newline at end of file diff --git a/server/utils/helpers/index.js b/server/utils/helpers/index.js index 57ec191e7..cbf07fbd0 100644 --- a/server/utils/helpers/index.js +++ b/server/utils/helpers/index.js @@ -214,6 +214,9 @@ function getEmbeddingEngineSelection() { case "litellm": const { LiteLLMEmbedder } = require("../EmbeddingEngines/liteLLM"); return new LiteLLMEmbedder(); + case "mistral": + const { MistralEmbedder } = require("../EmbeddingEngines/mistral"); + return new MistralEmbedder(); case "generic-openai": const { GenericOpenAiEmbedder, diff --git a/server/utils/helpers/updateENV.js b/server/utils/helpers/updateENV.js index ede372427..d547930a5 100644 --- a/server/utils/helpers/updateENV.js +++ b/server/utils/helpers/updateENV.js @@ -753,6 +753,7 @@ function supportedEmbeddingModel(input = "") { "voyageai", "litellm", "generic-openai", + "mistral", ]; return supported.includes(input) ? null