Mistral embedding engine support

* add mistral embedding engine support

* remove console log + fix data handling onboarding

* update data handling description

---------

Co-authored-by: Timothy Carambat <rambat1010@gmail.com>
This commit is contained in:
Sean Hatfield 2024-11-21 11:05:55 -08:00 committed by GitHub
parent 246152c024
commit 9f38b9337b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 109 additions and 0 deletions
frontend/src
components/EmbeddingSelection/MistralAiOptions
pages
GeneralSettings/EmbeddingPreference
OnboardingFlow/Steps/DataHandling
server/utils
EmbeddingEngines/mistral
helpers

View file

@ -0,0 +1,46 @@
export default function MistralAiOptions({ settings }) {
return (
<div className="w-full flex flex-col gap-y-4">
<div className="w-full flex items-center gap-[36px] mt-1.5">
<div className="flex flex-col w-60">
<label className="text-white text-sm font-semibold block mb-3">
API Key
</label>
<input
type="password"
name="MistralAiApiKey"
className="bg-theme-settings-input-bg text-white placeholder:text-theme-settings-input-placeholder text-sm rounded-lg focus:outline-primary-button active:outline-primary-button outline-none block w-full p-2.5"
placeholder="Mistral AI API Key"
defaultValue={settings?.MistralApiKey ? "*".repeat(20) : ""}
required={true}
autoComplete="off"
spellCheck={false}
/>
</div>
<div className="flex flex-col w-60">
<label className="text-white text-sm font-semibold block mb-3">
Model Preference
</label>
<select
name="EmbeddingModelPref"
required={true}
defaultValue={settings?.EmbeddingModelPref}
className="bg-theme-settings-input-bg border-gray-500 text-white text-sm rounded-lg block w-full p-2.5"
>
<optgroup label="Available embedding models">
{[
"mistral-embed",
].map((model) => {
return (
<option key={model} value={model}>
{model}
</option>
);
})}
</optgroup>
</select>
</div>
</div>
</div>
);
}

View file

@ -13,6 +13,7 @@ import CohereLogo from "@/media/llmprovider/cohere.png";
import VoyageAiLogo from "@/media/embeddingprovider/voyageai.png"; import VoyageAiLogo from "@/media/embeddingprovider/voyageai.png";
import LiteLLMLogo from "@/media/llmprovider/litellm.png"; import LiteLLMLogo from "@/media/llmprovider/litellm.png";
import GenericOpenAiLogo from "@/media/llmprovider/generic-openai.png"; import GenericOpenAiLogo from "@/media/llmprovider/generic-openai.png";
import MistralAiLogo from "@/media/llmprovider/mistral.jpeg";
import PreLoader from "@/components/Preloader"; import PreLoader from "@/components/Preloader";
import ChangeWarningModal from "@/components/ChangeWarning"; import ChangeWarningModal from "@/components/ChangeWarning";
@ -33,6 +34,7 @@ import { useModal } from "@/hooks/useModal";
import ModalWrapper from "@/components/ModalWrapper"; import ModalWrapper from "@/components/ModalWrapper";
import CTAButton from "@/components/lib/CTAButton"; import CTAButton from "@/components/lib/CTAButton";
import { useTranslation } from "react-i18next"; import { useTranslation } from "react-i18next";
import MistralAiOptions from "@/components/EmbeddingSelection/MistralAiOptions";
const EMBEDDERS = [ const EMBEDDERS = [
{ {
@ -100,6 +102,13 @@ const EMBEDDERS = [
options: (settings) => <LiteLLMOptions settings={settings} />, options: (settings) => <LiteLLMOptions settings={settings} />,
description: "Run powerful embedding models from LiteLLM.", description: "Run powerful embedding models from LiteLLM.",
}, },
{
name: "Mistral AI",
value: "mistral",
logo: MistralAiLogo,
options: (settings) => <MistralAiOptions settings={settings} />,
description: "Run powerful embedding models from Mistral AI.",
},
{ {
name: "Generic OpenAI", name: "Generic OpenAI",
value: "generic-openai", value: "generic-openai",

View file

@ -349,6 +349,13 @@ export const EMBEDDING_ENGINE_PRIVACY = {
], ],
logo: VoyageAiLogo, logo: VoyageAiLogo,
}, },
mistral: {
name: "Mistral AI",
description: [
"Data sent to Mistral AI's servers is shared according to the terms of service of https://mistral.ai.",
],
logo: MistralLogo,
},
litellm: { litellm: {
name: "LiteLLM", name: "LiteLLM",
description: [ description: [

View file

@ -0,0 +1,43 @@
class MistralEmbedder {
constructor() {
if (!process.env.MISTRAL_API_KEY)
throw new Error("No Mistral API key was set.");
const { OpenAI: OpenAIApi } = require("openai");
this.openai = new OpenAIApi({
baseURL: "https://api.mistral.ai/v1",
apiKey: process.env.MISTRAL_API_KEY ?? null,
});
this.model = process.env.EMBEDDING_MODEL_PREF || "mistral-embed";
}
async embedTextInput(textInput) {
try {
const response = await this.openai.embeddings.create({
model: this.model,
input: textInput,
});
return response?.data[0]?.embedding || [];
} catch (error) {
console.error("Failed to get embedding from Mistral.", error.message);
return [];
}
}
async embedChunks(textChunks = []) {
try {
const response = await this.openai.embeddings.create({
model: this.model,
input: textChunks,
});
return response?.data?.map((emb) => emb.embedding) || [];
} catch (error) {
console.error("Failed to get embeddings from Mistral.", error.message);
return new Array(textChunks.length).fill([]);
}
}
}
module.exports = {
MistralEmbedder,
};

View file

@ -214,6 +214,9 @@ function getEmbeddingEngineSelection() {
case "litellm": case "litellm":
const { LiteLLMEmbedder } = require("../EmbeddingEngines/liteLLM"); const { LiteLLMEmbedder } = require("../EmbeddingEngines/liteLLM");
return new LiteLLMEmbedder(); return new LiteLLMEmbedder();
case "mistral":
const { MistralEmbedder } = require("../EmbeddingEngines/mistral");
return new MistralEmbedder();
case "generic-openai": case "generic-openai":
const { const {
GenericOpenAiEmbedder, GenericOpenAiEmbedder,

View file

@ -753,6 +753,7 @@ function supportedEmbeddingModel(input = "") {
"voyageai", "voyageai",
"litellm", "litellm",
"generic-openai", "generic-openai",
"mistral",
]; ];
return supported.includes(input) return supported.includes(input)
? null ? null