Refactor Gemini to use OpenAI interface API ()

* Refactor Gemini to use OpenAI interface API

* add TODO

* handle errors better (gemini)

* remove unused code
This commit is contained in:
Timothy Carambat 2025-04-07 17:18:31 -07:00 committed by GitHub
parent 699684e301
commit 1b59295f89
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 290 additions and 257 deletions
docker
frontend/src
components/LLMSelection/GeminiLLMOptions
hooks
pages/WorkspaceSettings/AgentConfig/AgentLLMSelection
server
.env.example
models
utils
AiProviders/gemini
EmbeddingEngines/gemini
agents

View file

@ -15,7 +15,7 @@ GID='1000'
# LLM_PROVIDER='gemini'
# GEMINI_API_KEY=
# GEMINI_LLM_MODEL_PREF='gemini-pro'
# GEMINI_LLM_MODEL_PREF='gemini-2.0-flash-lite'
# LLM_PROVIDER='azure'
# AZURE_OPENAI_ENDPOINT=

View file

@ -29,6 +29,11 @@ export default function GeminiLLMOptions({ settings }) {
{!settings?.credentialsOnly && (
<>
<GeminiModelSelection apiKey={geminiApiKey} settings={settings} />
{/*
Safety setting is not supported for Gemini yet due to the openai compatible Gemini API.
We are not using the generativeAPI endpoint and therefore cannot set the safety threshold.
<div className="flex flex-col w-60">
<label className="text-white text-sm font-semibold block mb-3">
Safety Setting
@ -48,7 +53,7 @@ export default function GeminiLLMOptions({ settings }) {
</option>
<option value="BLOCK_LOW_AND_ABOVE">Block most</option>
</select>
</div>
</div> */}
</>
)}
</div>

View file

@ -10,21 +10,7 @@ export const DISABLED_PROVIDERS = [
];
const PROVIDER_DEFAULT_MODELS = {
openai: [],
gemini: [
"gemini-pro",
"gemini-1.0-pro",
"gemini-1.5-pro-latest",
"gemini-1.5-flash-latest",
"gemini-1.5-pro-exp-0801",
"gemini-1.5-pro-exp-0827",
"gemini-1.5-flash-exp-0827",
"gemini-1.5-flash-8b-exp-0827",
"gemini-exp-1114",
"gemini-exp-1121",
"gemini-exp-1206",
"learnlm-1.5-pro-experimental",
"gemini-2.0-flash-exp",
],
gemini: [],
anthropic: [],
azure: [],
lmstudio: [],

View file

@ -30,10 +30,10 @@ const ENABLED_PROVIDERS = [
"apipie",
"xai",
"nvidia-nim",
"gemini",
// TODO: More agent support.
// "cohere", // Has tool calling and will need to build explicit support
// "huggingface" // Can be done but already has issues with no-chat templated. Needs to be tested.
// "gemini", // Too rate limited and broken in several ways to use for agents.
];
const WARN_PERFORMANCE = [
"lmstudio",

View file

@ -12,7 +12,7 @@ SIG_SALT='salt' # Please generate random string at least 32 chars long.
# LLM_PROVIDER='gemini'
# GEMINI_API_KEY=
# GEMINI_LLM_MODEL_PREF='gemini-pro'
# GEMINI_LLM_MODEL_PREF='gemini-2.0-flash-lite'
# LLM_PROVIDER='azure'
# AZURE_OPENAI_ENDPOINT=

View file

@ -450,7 +450,8 @@ const SystemSettings = {
// Gemini Keys
GeminiLLMApiKey: !!process.env.GEMINI_API_KEY,
GeminiLLMModelPref: process.env.GEMINI_LLM_MODEL_PREF || "gemini-pro",
GeminiLLMModelPref:
process.env.GEMINI_LLM_MODEL_PREF || "gemini-2.0-flash-lite",
GeminiSafetySetting:
process.env.GEMINI_SAFETY_SETTING || "BLOCK_MEDIUM_AND_ABOVE",

View file

@ -5,9 +5,8 @@ const {
LLMPerformanceMonitor,
} = require("../../helpers/chat/LLMPerformanceMonitor");
const {
writeResponseChunk,
clientAbortedHandler,
formatChatHistory,
handleDefaultStreamResponseV2,
} = require("../../helpers/chat/responses");
const { MODEL_MAP } = require("../modelMap");
const { defaultGeminiModels, v1BetaModels } = require("./defaultModels");
@ -18,22 +17,31 @@ const cacheFolder = path.resolve(
: path.resolve(__dirname, `../../../storage/models/gemini`)
);
const NO_SYSTEM_PROMPT_MODELS = [
"gemma-3-1b-it",
"gemma-3-4b-it",
"gemma-3-12b-it",
"gemma-3-27b-it",
];
class GeminiLLM {
constructor(embedder = null, modelPreference = null) {
if (!process.env.GEMINI_API_KEY)
throw new Error("No Gemini API key was set.");
// Docs: https://ai.google.dev/tutorials/node_quickstart
const { GoogleGenerativeAI } = require("@google/generative-ai");
const genAI = new GoogleGenerativeAI(process.env.GEMINI_API_KEY);
const { OpenAI: OpenAIApi } = require("openai");
this.model =
modelPreference || process.env.GEMINI_LLM_MODEL_PREF || "gemini-pro";
modelPreference ||
process.env.GEMINI_LLM_MODEL_PREF ||
"gemini-2.0-flash-lite";
const isExperimental = this.isExperimentalModel(this.model);
this.gemini = genAI.getGenerativeModel(
{ model: this.model },
{ apiVersion: isExperimental ? "v1beta" : "v1" }
);
this.openai = new OpenAIApi({
apiKey: process.env.GEMINI_API_KEY,
// Even models that are v1 in gemini API can be used with v1beta/openai/ endpoint and nobody knows why.
baseURL: "https://generativelanguage.googleapis.com/v1beta/openai/",
});
this.limits = {
history: this.promptWindowLimit() * 0.15,
system: this.promptWindowLimit() * 0.15,
@ -41,8 +49,7 @@ class GeminiLLM {
};
this.embedder = embedder ?? new NativeEmbedder();
this.defaultTemp = 0.7; // not used for Gemini
this.safetyThreshold = this.#fetchSafetyThreshold();
this.defaultTemp = 0.7;
if (!fs.existsSync(cacheFolder))
fs.mkdirSync(cacheFolder, { recursive: true });
@ -53,6 +60,16 @@ class GeminiLLM {
);
}
/**
* Checks if the model supports system prompts
* This is a static list of models that are known to not support system prompts
* since this information is not available in the API model response.
* @returns {boolean}
*/
get supportsSystemPrompt() {
return !NO_SYSTEM_PROMPT_MODELS.includes(this.model);
}
#log(text, ...args) {
console.log(`\x1b[32m[GeminiLLM]\x1b[0m ${text}`, ...args);
}
@ -82,41 +99,6 @@ class GeminiLLM {
);
}
// BLOCK_NONE can be a special candidate for some fields
// https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/configure-safety-attributes#how_to_remove_automated_response_blocking_for_select_safety_attributes
// so if you are wondering why BLOCK_NONE still failed, the link above will explain why.
#fetchSafetyThreshold() {
const threshold =
process.env.GEMINI_SAFETY_SETTING ?? "BLOCK_MEDIUM_AND_ABOVE";
const safetyThresholds = [
"BLOCK_NONE",
"BLOCK_ONLY_HIGH",
"BLOCK_MEDIUM_AND_ABOVE",
"BLOCK_LOW_AND_ABOVE",
];
return safetyThresholds.includes(threshold)
? threshold
: "BLOCK_MEDIUM_AND_ABOVE";
}
#safetySettings() {
return [
{
category: "HARM_CATEGORY_HATE_SPEECH",
threshold: this.safetyThreshold,
},
{
category: "HARM_CATEGORY_SEXUALLY_EXPLICIT",
threshold: this.safetyThreshold,
},
{ category: "HARM_CATEGORY_HARASSMENT", threshold: this.safetyThreshold },
{
category: "HARM_CATEGORY_DANGEROUS_CONTENT",
threshold: this.safetyThreshold,
},
];
}
streamingEnabled() {
return "streamGetChatCompletion" in this;
}
@ -336,147 +318,114 @@ class GeminiLLM {
* @returns {string|object[]}
*/
#generateContent({ userPrompt, attachments = [] }) {
if (!attachments.length) {
return userPrompt;
}
if (!attachments.length) return userPrompt;
const content = [{ text: userPrompt }];
const content = [{ type: "text", text: userPrompt }];
for (let attachment of attachments) {
content.push({
inlineData: {
data: attachment.contentString.split("base64,")[1],
mimeType: attachment.mime,
type: "image_url",
image_url: {
url: attachment.contentString,
detail: "high",
},
});
}
return content.flat();
}
/**
* Construct the user prompt for this model.
* @param {{attachments: import("../../helpers").Attachment[]}} param0
* @returns
*/
constructPrompt({
systemPrompt = "",
contextTexts = [],
chatHistory = [],
userPrompt = "",
attachments = [],
attachments = [], // This is the specific attachment for only this prompt
}) {
const prompt = {
role: "system",
content: `${systemPrompt}${this.#appendContext(contextTexts)}`,
};
let prompt = [];
if (this.supportsSystemPrompt) {
prompt.push({
role: "system",
content: `${systemPrompt}${this.#appendContext(contextTexts)}`,
});
} else {
this.#log(
`${this.model} - does not support system prompts - emulating...`
);
prompt.push(
{
role: "user",
content: `${systemPrompt}${this.#appendContext(contextTexts)}`,
},
{
role: "assistant",
content: "Okay.",
}
);
}
return [
prompt,
{ role: "assistant", content: "Okay." },
...prompt,
...formatChatHistory(chatHistory, this.#generateContent),
{
role: "USER_PROMPT",
role: "user",
content: this.#generateContent({ userPrompt, attachments }),
},
];
}
// This will take an OpenAi format message array and only pluck valid roles from it.
formatMessages(messages = []) {
// Gemini roles are either user || model.
// and all "content" is relabeled to "parts"
const allMessages = messages
.map((message) => {
if (message.role === "system")
return { role: "user", parts: [{ text: message.content }] };
async getChatCompletion(messages = null, { temperature = 0.7 }) {
const result = await LLMPerformanceMonitor.measureAsyncFunction(
this.openai.chat.completions
.create({
model: this.model,
messages,
temperature: temperature,
})
.catch((e) => {
console.error(e);
throw new Error(e.message);
})
);
if (message.role === "user") {
// If the content is an array - then we have already formatted the context so return it directly.
if (Array.isArray(message.content))
return { role: "user", parts: message.content };
// Otherwise, this was a regular user message with no attachments
// so we need to format it for Gemini
return { role: "user", parts: [{ text: message.content }] };
}
if (message.role === "assistant")
return { role: "model", parts: [{ text: message.content }] };
return null;
})
.filter((msg) => !!msg);
// Specifically, Google cannot have the last sent message be from a user with no assistant reply
// otherwise it will crash. So if the last item is from the user, it was not completed so pop it off
// the history.
if (
allMessages.length > 0 &&
allMessages[allMessages.length - 1].role === "user"
!result.output.hasOwnProperty("choices") ||
result.output.choices.length === 0
)
allMessages.pop();
// Validate that after every user message, there is a model message
// sometimes when using gemini we try to compress messages in order to retain as
// much context as possible but this may mess up the order of the messages that the gemini model expects
// we do this check to work around the edge case where 2 user prompts may be next to each other, in the message array
for (let i = 0; i < allMessages.length; i++) {
if (
allMessages[i].role === "user" &&
i < allMessages.length - 1 &&
allMessages[i + 1].role !== "model"
) {
allMessages.splice(i + 1, 0, {
role: "model",
parts: [{ text: "Okay." }],
});
}
}
return allMessages;
}
async getChatCompletion(messages = [], _opts = {}) {
const prompt = messages.find(
(chat) => chat.role === "USER_PROMPT"
)?.content;
const chatThread = this.gemini.startChat({
history: this.formatMessages(messages),
safetySettings: this.#safetySettings(),
});
const { output: result, duration } =
await LLMPerformanceMonitor.measureAsyncFunction(
chatThread.sendMessage(prompt)
);
const responseText = result.response.text();
if (!responseText) throw new Error("Gemini: No response could be parsed.");
const promptTokens = LLMPerformanceMonitor.countTokens(messages);
const completionTokens = LLMPerformanceMonitor.countTokens([
{ content: responseText },
]);
return null;
return {
textResponse: responseText,
textResponse: result.output.choices[0].message.content,
metrics: {
prompt_tokens: promptTokens,
completion_tokens: completionTokens,
total_tokens: promptTokens + completionTokens,
outputTps: (promptTokens + completionTokens) / duration,
duration,
prompt_tokens: result.output.usage.prompt_tokens || 0,
completion_tokens: result.output.usage.completion_tokens || 0,
total_tokens: result.output.usage.total_tokens || 0,
outputTps: result.output.usage.completion_tokens / result.duration,
duration: result.duration,
},
};
}
async streamGetChatCompletion(messages = [], _opts = {}) {
const prompt = messages.find(
(chat) => chat.role === "USER_PROMPT"
)?.content;
const chatThread = this.gemini.startChat({
history: this.formatMessages(messages),
safetySettings: this.#safetySettings(),
});
const responseStream = await LLMPerformanceMonitor.measureStream(
(await chatThread.sendMessageStream(prompt)).stream,
messages
async streamGetChatCompletion(messages = null, { temperature = 0.7 }) {
const measuredStreamRequest = await LLMPerformanceMonitor.measureStream(
this.openai.chat.completions.create({
model: this.model,
stream: true,
messages,
temperature: temperature,
}),
messages,
true
);
if (!responseStream)
throw new Error("Could not stream response stream from Gemini.");
return responseStream;
return measuredStreamRequest;
}
handleStream(response, stream, responseProps) {
return handleDefaultStreamResponseV2(response, stream, responseProps);
}
async compressMessages(promptArgs = {}, rawHistory = []) {
@ -485,81 +434,6 @@ class GeminiLLM {
return await messageArrayCompressor(this, messageArray, rawHistory);
}
handleStream(response, stream, responseProps) {
const { uuid = uuidv4(), sources = [] } = responseProps;
// Usage is not available for Gemini streams
// so we need to calculate the completion tokens manually
// because 1 chunk != 1 token in gemini responses and it buffers
// many tokens before sending them to the client as a "chunk"
return new Promise(async (resolve) => {
let fullText = "";
// Establish listener to early-abort a streaming response
// in case things go sideways or the user does not like the response.
// We preserve the generated text but continue as if chat was completed
// to preserve previously generated content.
const handleAbort = () => {
stream?.endMeasurement({
completion_tokens: LLMPerformanceMonitor.countTokens([
{ content: fullText },
]),
});
clientAbortedHandler(resolve, fullText);
};
response.on("close", handleAbort);
for await (const chunk of stream) {
let chunkText;
try {
// Due to content sensitivity we cannot always get the function .text();
// https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/configure-safety-attributes#gemini-TASK-samples-nodejs
// and it is not possible to unblock or disable this safety protocol without being allowlisted by Google.
chunkText = chunk.text();
} catch (e) {
chunkText = e.message;
writeResponseChunk(response, {
uuid,
sources: [],
type: "abort",
textResponse: null,
close: true,
error: e.message,
});
stream?.endMeasurement({ completion_tokens: 0 });
resolve(e.message);
return;
}
fullText += chunkText;
writeResponseChunk(response, {
uuid,
sources: [],
type: "textResponseChunk",
textResponse: chunk.text(),
close: false,
error: false,
});
}
writeResponseChunk(response, {
uuid,
sources,
type: "textResponseChunk",
textResponse: "",
close: true,
error: false,
});
response.removeListener("close", handleAbort);
stream?.endMeasurement({
completion_tokens: LLMPerformanceMonitor.countTokens([
{ content: fullText },
]),
});
resolve(fullText);
});
}
// Simple wrapper for dynamic embedder & normalize interface for all LLM implementations
async embedTextInput(textInput) {
return await this.embedder.embedTextInput(textInput);
@ -571,4 +445,5 @@ class GeminiLLM {
module.exports = {
GeminiLLM,
NO_SYSTEM_PROMPT_MODELS,
};

View file

@ -2,6 +2,8 @@ class GeminiEmbedder {
constructor() {
if (!process.env.GEMINI_EMBEDDING_API_KEY)
throw new Error("No Gemini API key was set.");
// TODO: Deprecate this and use OpenAI interface instead - after which, remove the @google/generative-ai dependency
const { GoogleGenerativeAI } = require("@google/generative-ai");
const genAI = new GoogleGenerativeAI(process.env.GEMINI_EMBEDDING_API_KEY);
this.model = process.env.EMBEDDING_MODEL_PREF || "text-embedding-004";

View file

@ -491,9 +491,7 @@ Only return the role.
// and remove the @ from the response
const { result } = await provider.complete(messages);
const name = result?.replace(/^@/g, "");
if (this.agents.get(name)) {
return name;
}
if (this.agents.get(name)) return name;
// if the name is not in the nodes, return a random node
return availableNodes[Math.floor(Math.random() * availableNodes.length)];
@ -797,6 +795,8 @@ ${this.getHistory({ to: route.to })
return new Providers.NovitaProvider({ model: config.model });
case "ppio":
return new Providers.PPIOProvider({ model: config.model });
case "gemini":
return new Providers.GeminiProvider({ model: config.model });
default:
throw new Error(
`Unknown provider: ${config.provider}. Please use a valid provider.`

View file

@ -171,6 +171,14 @@ class Provider {
apiKey: process.env.PPIO_API_KEY ?? null,
...config,
});
case "gemini":
return new ChatOpenAI({
configuration: {
baseURL: "https://generativelanguage.googleapis.com/v1beta/openai/",
},
apiKey: process.env.GEMINI_API_KEY ?? null,
...config,
});
// OSS Model Runners
// case "anythingllm_ollama":

View file

@ -0,0 +1,154 @@
const OpenAI = require("openai");
const Provider = require("./ai-provider.js");
const InheritMultiple = require("./helpers/classes.js");
const UnTooled = require("./helpers/untooled.js");
const {
NO_SYSTEM_PROMPT_MODELS,
} = require("../../../AiProviders/gemini/index.js");
const { APIError } = require("../error.js");
/**
* The agent provider for the Gemini provider.
* We wrap Gemini in UnTooled because its tool-calling is not supported via the dedicated OpenAI API.
*/
class GeminiProvider extends InheritMultiple([Provider, UnTooled]) {
model;
constructor(config = {}) {
const { model = "gemini-2.0-flash-lite" } = config;
super();
const client = new OpenAI({
baseURL: "https://generativelanguage.googleapis.com/v1beta/openai/",
apiKey: process.env.GEMINI_API_KEY,
maxRetries: 0,
});
this._client = client;
this.model = model;
this.verbose = true;
}
get client() {
return this._client;
}
/**
* Format the messages to the format required by the Gemini API since some models do not support system prompts.
* @see {NO_SYSTEM_PROMPT_MODELS}
* @param {import("openai").OpenAI.ChatCompletionMessage[]} messages
* @returns {import("openai").OpenAI.ChatCompletionMessage[]}
*/
formatMessages(messages) {
if (!NO_SYSTEM_PROMPT_MODELS.includes(this.model)) return messages;
// Replace the system message with a user/assistant message pair
const formattedMessages = [];
for (const message of messages) {
if (message.role === "system") {
formattedMessages.push({
role: "user",
content: message.content,
});
formattedMessages.push({
role: "assistant",
content: "Okay, I'll follow your instructions.",
});
continue;
}
formattedMessages.push(message);
}
return formattedMessages;
}
async #handleFunctionCallChat({ messages = [] }) {
return await this.client.chat.completions
.create({
model: this.model,
temperature: 0,
messages: this.cleanMsgs(this.formatMessages(messages)),
})
.then((result) => {
if (!result.hasOwnProperty("choices"))
throw new Error("Gemini chat: No results!");
if (result.choices.length === 0)
throw new Error("Gemini chat: No results length!");
return result.choices[0].message.content;
})
.catch((_) => {
return null;
});
}
/**
* Create a completion based on the received messages.
*
* @param messages A list of messages to send to the API.
* @param functions
* @returns The completion.
*/
async complete(messages, functions = []) {
try {
let completion;
if (functions.length > 0) {
const { toolCall, text } = await this.functionCall(
this.cleanMsgs(this.formatMessages(messages)),
functions,
this.#handleFunctionCallChat.bind(this)
);
if (toolCall !== null) {
this.providerLog(`Valid tool call found - running ${toolCall.name}.`);
this.deduplicator.trackRun(toolCall.name, toolCall.arguments);
return {
result: null,
functionCall: {
name: toolCall.name,
arguments: toolCall.arguments,
},
cost: 0,
};
}
completion = { content: text };
}
if (!completion?.content) {
this.providerLog(
"Will assume chat completion without tool call inputs."
);
const response = await this.client.chat.completions.create({
model: this.model,
messages: this.cleanMsgs(this.formatMessages(messages)),
});
completion = response.choices[0].message;
}
// The UnTooled class inherited Deduplicator is mostly useful to prevent the agent
// from calling the exact same function over and over in a loop within a single chat exchange
// _but_ we should enable it to call previously used tools in a new chat interaction.
this.deduplicator.reset("runs");
return {
result: completion.content,
cost: 0,
};
} catch (error) {
throw new APIError(
error?.message
? `${this.constructor.name} encountered an error while executing the request: ${error.message}`
: "There was an error with the Gemini provider executing the request"
);
}
}
/**
* Get the cost of the completion.
*
* @param _usage The completion to get the cost for.
* @returns The cost of the completion.
*/
getCost(_usage) {
return 0;
}
}
module.exports = GeminiProvider;

View file

@ -21,6 +21,7 @@ const XAIProvider = require("./xai.js");
const NovitaProvider = require("./novita.js");
const NvidiaNimProvider = require("./nvidiaNim.js");
const PPIOProvider = require("./ppio.js");
const GeminiProvider = require("./gemini.js");
module.exports = {
OpenAIProvider,
@ -46,4 +47,5 @@ module.exports = {
NovitaProvider,
NvidiaNimProvider,
PPIOProvider,
GeminiProvider,
};

View file

@ -115,10 +115,6 @@ class AgentHandler {
"LocalAI must have a valid base path to use for the api."
);
break;
case "gemini":
if (!process.env.GEMINI_API_KEY)
throw new Error("Gemini API key must be provided to use agents.");
break;
case "openrouter":
if (!process.env.OPENROUTER_API_KEY)
throw new Error("OpenRouter API key must be provided to use agents.");
@ -189,6 +185,10 @@ class AgentHandler {
if (!process.env.PPIO_API_KEY)
throw new Error("PPIO API Key must be provided to use agents.");
break;
case "gemini":
if (!process.env.GEMINI_API_KEY)
throw new Error("Gemini API key must be provided to use agents.");
break;
default:
throw new Error(
@ -224,8 +224,6 @@ class AgentHandler {
return null;
case "koboldcpp":
return process.env.KOBOLD_CPP_MODEL_PREF ?? null;
case "gemini":
return process.env.GEMINI_MODEL_PREF ?? "gemini-pro";
case "localai":
return process.env.LOCAL_AI_MODEL_PREF ?? null;
case "openrouter":
@ -256,6 +254,8 @@ class AgentHandler {
return process.env.NVIDIA_NIM_LLM_MODEL_PREF ?? null;
case "ppio":
return process.env.PPIO_MODEL_PREF ?? "qwen/qwen2.5-32b-instruct";
case "gemini":
return process.env.GEMINI_LLM_MODEL_PREF ?? "gemini-2.0-flash-lite";
default:
return null;
}