Add support for gemini authenticated models endpoint ()

* Add support for gemini authenticated models endpoint
add customModels entry
add un-authed fallback to default listing
separate models by experimental status
resolves 

* add back improved logic for apiVersion decision making
Timothy Carambat 2024-12-17 15:20:26 -08:00 committed by GitHub
parent 71cd5e5b28
commit b082c8e441
5 changed files with 219 additions and 110 deletions
frontend/src/components/LLMSelection/GeminiLLMOptions
server/utils


@@ -1,4 +1,10 @@
import System from "@/models/system";
import { useEffect, useState } from "react";
export default function GeminiLLMOptions({ settings }) {
const [inputValue, setInputValue] = useState(settings?.GeminiLLMApiKey);
const [geminiApiKey, setGeminiApiKey] = useState(settings?.GeminiLLMApiKey);
return (
<div className="w-full flex flex-col">
<div className="w-full flex items-center gap-[36px] mt-1.5">
@@ -15,56 +21,14 @@ export default function GeminiLLMOptions({ settings }) {
required={true}
autoComplete="off"
spellCheck={false}
onChange={(e) => setInputValue(e.target.value)}
onBlur={() => setGeminiApiKey(inputValue)}
/>
</div>
{!settings?.credentialsOnly && (
<>
<div className="flex flex-col w-60">
<label className="text-white text-sm font-semibold block mb-3">
Chat Model Selection
</label>
<select
name="GeminiLLMModelPref"
defaultValue={settings?.GeminiLLMModelPref || "gemini-pro"}
required={true}
className="border-none bg-theme-settings-input-bg border-gray-500 text-white text-sm rounded-lg block w-full p-2.5"
>
<optgroup label="Stable Models">
{[
"gemini-pro",
"gemini-1.0-pro",
"gemini-1.5-pro-latest",
"gemini-1.5-flash-latest",
].map((model) => {
return (
<option key={model} value={model}>
{model}
</option>
);
})}
</optgroup>
<optgroup label="Experimental Models">
{[
"gemini-1.5-pro-exp-0801",
"gemini-1.5-pro-exp-0827",
"gemini-1.5-flash-exp-0827",
"gemini-1.5-flash-8b-exp-0827",
"gemini-exp-1114",
"gemini-exp-1121",
"gemini-exp-1206",
"learnlm-1.5-pro-experimental",
"gemini-2.0-flash-exp",
].map((model) => {
return (
<option key={model} value={model}>
{model}
</option>
);
})}
</optgroup>
</select>
</div>
<GeminiModelSelection apiKey={geminiApiKey} settings={settings} />
<div className="flex flex-col w-60">
<label className="text-white text-sm font-semibold block mb-3">
Safety Setting
@@ -91,3 +55,79 @@ export default function GeminiLLMOptions({ settings }) {
</div>
);
}
function GeminiModelSelection({ apiKey, settings }) {
const [groupedModels, setGroupedModels] = useState({});
const [loading, setLoading] = useState(true);
useEffect(() => {
async function findCustomModels() {
setLoading(true);
const { models } = await System.customModels("gemini", apiKey);
if (models?.length > 0) {
const modelsByOrganization = models.reduce((acc, model) => {
const group = model.experimental ? "Experimental" : "Stable";
acc[group] = acc[group] || [];
acc[group].push(model);
return acc;
}, {});
setGroupedModels(modelsByOrganization);
}
setLoading(false);
}
findCustomModels();
}, [apiKey]);
if (loading) {
return (
<div className="flex flex-col w-60">
<label className="text-white text-sm font-semibold block mb-3">
Chat Model Selection
</label>
<select
name="GeminiLLMModelPref"
disabled={true}
className="border-none bg-theme-settings-input-bg border-gray-500 text-white text-sm rounded-lg block w-full p-2.5"
>
<option disabled={true} selected={true}>
-- loading available models --
</option>
</select>
</div>
);
}
return (
<div className="flex flex-col w-60">
<label className="text-white text-sm font-semibold block mb-3">
Chat Model Selection
</label>
<select
name="GeminiLLMModelPref"
required={true}
className="border-none bg-theme-settings-input-bg border-gray-500 text-white text-sm rounded-lg block w-full p-2.5"
>
{Object.keys(groupedModels)
.sort((a, b) => {
if (a === "Stable") return -1;
if (b === "Stable") return 1;
return a.localeCompare(b);
})
.map((organization) => (
<optgroup key={organization} label={organization}>
{groupedModels[organization].map((model) => (
<option
key={model.id}
value={model.id}
selected={settings?.GeminiLLMModelPref === model.id}
>
{model.name}
</option>
))}
</optgroup>
))}
</select>
</div>
);
}
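
For reference, a minimal sketch of what the reduce in GeminiModelSelection produces, with made-up entries standing in for whatever System.customModels("gemini", apiKey) actually returns:

const models = [
  { id: "gemini-1.5-pro-latest", name: "Gemini 1.5 Pro", experimental: false },
  { id: "gemini-2.0-flash-exp", name: "Gemini 2.0 Flash Experimental", experimental: true },
];

// Bucket models by their experimental flag, same as the component does.
const grouped = models.reduce((acc, model) => {
  const group = model.experimental ? "Experimental" : "Stable";
  acc[group] = acc[group] || [];
  acc[group].push(model);
  return acc;
}, {});

console.log(Object.keys(grouped)); // ["Stable", "Experimental"]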


@@ -0,0 +1,46 @@
const { MODEL_MAP } = require("../modelMap");
const stableModels = [
"gemini-pro",
"gemini-1.0-pro",
"gemini-1.5-pro-latest",
"gemini-1.5-flash-latest",
];
const experimentalModels = [
"gemini-1.5-pro-exp-0801",
"gemini-1.5-pro-exp-0827",
"gemini-1.5-flash-exp-0827",
"gemini-1.5-flash-8b-exp-0827",
"gemini-exp-1114",
"gemini-exp-1121",
"gemini-exp-1206",
"learnlm-1.5-pro-experimental",
"gemini-2.0-flash-exp",
];
// Some models are only available in the v1beta API and some only in the v1 API.
// Generally, v1beta models have `exp` in the name, but not always,
// so we also check against a static list.
const v1BetaModels = ["gemini-1.5-pro-latest", "gemini-1.5-flash-latest"];
const defaultGeminiModels = [
...stableModels.map((model) => ({
id: model,
name: model,
contextWindow: MODEL_MAP.gemini[model],
experimental: false,
})),
...experimentalModels.map((model) => ({
id: model,
name: model,
contextWindow: MODEL_MAP.gemini[model],
experimental: true,
})),
];
module.exports = {
defaultGeminiModels,
v1BetaModels,
};
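
As a quick illustration of the fallback shape this module exports (a sketch; the require path assumes the caller sits next to this file, and the counts follow from the two lists above):

const { defaultGeminiModels } = require("./defaultModals");

// The static fallback carries the same grouping flag the UI expects.
const stable = defaultGeminiModels.filter((model) => !model.experimental);
const experimental = defaultGeminiModels.filter((model) => model.experimental);
console.log(stable.length, experimental.length); // 4 9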


@@ -7,6 +7,7 @@ const {
clientAbortedHandler,
} = require("../../helpers/chat/responses");
const { MODEL_MAP } = require("../modelMap");
const { defaultGeminiModels, v1BetaModels } = require("./defaultModals");
class GeminiLLM {
constructor(embedder = null, modelPreference = null) {
@@ -21,22 +22,17 @@
this.gemini = genAI.getGenerativeModel(
{ model: this.model },
{
// Gemini-1.5-pro-* and Gemini-1.5-flash are only available on the v1beta API.
apiVersion: [
"gemini-1.5-pro-latest",
"gemini-1.5-flash-latest",
"gemini-1.5-pro-exp-0801",
"gemini-1.5-pro-exp-0827",
"gemini-1.5-flash-exp-0827",
"gemini-1.5-flash-8b-exp-0827",
"gemini-exp-1114",
"gemini-exp-1121",
"gemini-exp-1206",
"learnlm-1.5-pro-experimental",
"gemini-2.0-flash-exp",
].includes(this.model)
? "v1beta"
: "v1",
apiVersion:
/**
* Some models are only available in the v1beta API and some only in the v1 API.
* Generally, v1beta models have `exp` in the name, but not always,
* so we also check against a static list.
* @see {v1BetaModels}
*/
this.model.includes("exp") || v1BetaModels.includes(this.model)
? "v1beta"
: "v1",
}
);
this.limits = {
@@ -48,6 +44,11 @@
this.embedder = embedder ?? new NativeEmbedder();
this.defaultTemp = 0.7; // not used for Gemini
this.safetyThreshold = this.#fetchSafetyThreshold();
this.#log(`Initialized with model: ${this.model}`);
}
#log(text, ...args) {
console.log(`\x1b[32m[GeminiLLM]\x1b[0m ${text}`, ...args);
}
#appendContext(contextTexts = []) {
@@ -109,25 +110,63 @@
return MODEL_MAP.gemini[this.model] ?? 30_720;
}
isValidChatCompletionModel(modelName = "") {
const validModels = [
"gemini-pro",
"gemini-1.0-pro",
"gemini-1.5-pro-latest",
"gemini-1.5-flash-latest",
"gemini-1.5-pro-exp-0801",
"gemini-1.5-pro-exp-0827",
"gemini-1.5-flash-exp-0827",
"gemini-1.5-flash-8b-exp-0827",
"gemini-exp-1114",
"gemini-exp-1121",
"gemini-exp-1206",
"learnlm-1.5-pro-experimental",
"gemini-2.0-flash-exp",
];
return validModels.includes(modelName);
/**
* Fetches Gemini models from the Google Generative AI API
* @param {string} apiKey - The API key to use for the request
* @param {number} limit - The maximum number of models to fetch
* @param {string} pageToken - The page token to use for pagination
* @returns {Promise<Array<{id: string, name: string, contextWindow: number, experimental: boolean}>>} A promise that resolves to an array of Gemini models
*/
static async fetchModels(apiKey, limit = 1_000, pageToken = null) {
const url = new URL(
"https://generativelanguage.googleapis.com/v1beta/models"
);
url.searchParams.set("pageSize", limit);
url.searchParams.set("key", apiKey);
if (pageToken) url.searchParams.set("pageToken", pageToken);
return fetch(url.toString(), {
method: "GET",
headers: { "Content-Type": "application/json" },
})
.then((res) => res.json())
.then((data) => {
if (data.error) throw new Error(data.error.message);
return data.models ?? [];
})
.then((models) =>
models
.filter(
(model) => !model.displayName.toLowerCase().includes("tuning")
)
.filter((model) =>
model.supportedGenerationMethods.includes("generateContent")
) // Only generateContent is supported
.map((model) => {
return {
id: model.name.split("/").pop(),
name: model.displayName,
contextWindow: model.inputTokenLimit,
experimental: model.name.includes("exp"),
};
})
)
.catch((e) => {
console.error(`Gemini:getGeminiModels`, e.message);
return defaultGeminiModels;
});
}
/**
* Checks if a model is valid for chat completion (unused)
* @deprecated
* @param {string} modelName - The name of the model to check
* @returns {Promise<boolean>} A promise that resolves to a boolean indicating if the model is valid
*/
async isValidChatCompletionModel(modelName = "") {
const models = await GeminiLLM.fetchModels(process.env.GEMINI_API_KEY); // fetchModels is static, so call it on the class
return models.some((model) => model.id === modelName);
}
/**
* Generates appropriate content array for a message + attachments.
* @param {{userPrompt:string, attachments: import("../../helpers").Attachment[]}}
@@ -218,11 +257,6 @@
}
async getChatCompletion(messages = [], _opts = {}) {
if (!this.isValidChatCompletionModel(this.model))
throw new Error(
`Gemini chat: ${this.model} is not valid for chat completion!`
);
const prompt = messages.find(
(chat) => chat.role === "USER_PROMPT"
)?.content;
@@ -256,11 +290,6 @@
}
async streamGetChatCompletion(messages = [], _opts = {}) {
if (!this.isValidChatCompletionModel(this.model))
throw new Error(
`Gemini chat: ${this.model} is not valid for chat completion!`
);
const prompt = messages.find(
(chat) => chat.role === "USER_PROMPT"
)?.content;
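
Putting the new static fetcher together, a hedged usage sketch (same require path customModels.js uses below; assumes GEMINI_API_KEY is set, and on any API error fetchModels resolves to defaultGeminiModels rather than rejecting):

const { GeminiLLM } = require("../AiProviders/gemini");

(async () => {
  // Issues GET https://generativelanguage.googleapis.com/v1beta/models?pageSize=1000&key=...
  const models = await GeminiLLM.fetchModels(process.env.GEMINI_API_KEY);
  for (const model of models) {
    console.log(model.id, model.contextWindow, model.experimental);
  }
})();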


@@ -7,6 +7,7 @@ const { ElevenLabsTTS } = require("../TextToSpeech/elevenLabs");
const { fetchNovitaModels } = require("../AiProviders/novita");
const { parseLMStudioBasePath } = require("../AiProviders/lmStudio");
const { parseNvidiaNimBasePath } = require("../AiProviders/nvidiaNim");
const { GeminiLLM } = require("../AiProviders/gemini");
const SUPPORT_CUSTOM_MODELS = [
"openai",
@@ -28,6 +29,7 @@ const SUPPORT_CUSTOM_MODELS = [
"apipie",
"novita",
"xai",
"gemini",
];
async function getCustomModels(provider = "", apiKey = null, basePath = null) {
@@ -73,6 +75,8 @@ async function getCustomModels(provider = "", apiKey = null, basePath = null) {
return await getXAIModels(apiKey);
case "nvidia-nim":
return await getNvidiaNimModels(basePath);
case "gemini":
return await getGeminiModels(apiKey);
default:
return { models: [], error: "Invalid provider for custom models" };
}
@@ -572,6 +576,17 @@ async function getNvidiaNimModels(basePath = null) {
}
}
async function getGeminiModels(_apiKey = null) {
const apiKey =
_apiKey === true
? process.env.GEMINI_API_KEY
: _apiKey || process.env.GEMINI_API_KEY || null;
const models = await GeminiLLM.fetchModels(apiKey);
// API key was valid, so save it for future use
if (models.length > 0 && !!apiKey) process.env.GEMINI_API_KEY = apiKey;
return { models, error: null };
}
module.exports = {
getCustomModels,
};
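
End to end, the helper can then be exercised like this (hypothetical require path; with no key the underlying request errors and the catch inside fetchModels falls back to defaultGeminiModels, which is the un-authed fallback named in the commit message):

const { getCustomModels } = require("./utils/helpers/customModels"); // hypothetical path

(async () => {
  const { models, error } = await getCustomModels("gemini", process.env.GEMINI_API_KEY);
  if (!error) console.log(`Found ${models.length} Gemini models`);
})();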


@@ -52,7 +52,7 @@ const KEY_MAPPING = {
},
GeminiLLMModelPref: {
envKey: "GEMINI_LLM_MODEL_PREF",
checks: [isNotEmpty, validGeminiModel],
checks: [isNotEmpty],
},
GeminiSafetySetting: {
envKey: "GEMINI_SAFETY_SETTING",
@@ -724,27 +724,6 @@ function supportedTranscriptionProvider(input = "") {
: `${input} is not a valid transcription model provider.`;
}
function validGeminiModel(input = "") {
const validModels = [
"gemini-pro",
"gemini-1.0-pro",
"gemini-1.5-pro-latest",
"gemini-1.5-flash-latest",
"gemini-1.5-pro-exp-0801",
"gemini-1.5-pro-exp-0827",
"gemini-1.5-flash-exp-0827",
"gemini-1.5-flash-8b-exp-0827",
"gemini-exp-1114",
"gemini-exp-1121",
"gemini-exp-1206",
"learnlm-1.5-pro-experimental",
"gemini-2.0-flash-exp",
];
return validModels.includes(input)
? null
: `Invalid Model type. Must be one of ${validModels.join(", ")}.`;
}
function validGeminiSafetySetting(input = "") {
const validModes = [
"BLOCK_NONE",