From 28eba636e941e8c6d2a54e5209b0dbcc2c892480 Mon Sep 17 00:00:00 2001 From: Timothy Carambat <rambat1010@gmail.com> Date: Mon, 20 May 2024 13:17:00 -0500 Subject: [PATCH] Allow setting of safety thresholds for Gemini (#1466) * Allow setting of safety thresholds for Gemini * linting --- .../LLMSelection/GeminiLLMOptions/index.jsx | 60 +++++++++++++------ server/models/systemSettings.js | 2 + server/utils/AiProviders/gemini/index.js | 38 ++++++++++++ server/utils/helpers/updateENV.js | 16 +++++ 4 files changed, 97 insertions(+), 19 deletions(-) diff --git a/frontend/src/components/LLMSelection/GeminiLLMOptions/index.jsx b/frontend/src/components/LLMSelection/GeminiLLMOptions/index.jsx index 8cb513f31..d2846704d 100644 --- a/frontend/src/components/LLMSelection/GeminiLLMOptions/index.jsx +++ b/frontend/src/components/LLMSelection/GeminiLLMOptions/index.jsx @@ -19,25 +19,47 @@ export default function GeminiLLMOptions({ settings }) { </div> {!settings?.credentialsOnly && ( - <div className="flex flex-col w-60"> - <label className="text-white text-sm font-semibold block mb-4"> - Chat Model Selection - </label> - <select - name="GeminiLLMModelPref" - defaultValue={settings?.GeminiLLMModelPref || "gemini-pro"} - required={true} - className="bg-zinc-900 border-gray-500 text-white text-sm rounded-lg block w-full p-2.5" - > - {["gemini-pro", "gemini-1.5-pro-latest"].map((model) => { - return ( - <option key={model} value={model}> - {model} - </option> - ); - })} - </select> - </div> + <> + <div className="flex flex-col w-60"> + <label className="text-white text-sm font-semibold block mb-4"> + Chat Model Selection + </label> + <select + name="GeminiLLMModelPref" + defaultValue={settings?.GeminiLLMModelPref || "gemini-pro"} + required={true} + className="bg-zinc-900 border-gray-500 text-white text-sm rounded-lg block w-full p-2.5" + > + {["gemini-pro", "gemini-1.5-pro-latest"].map((model) => { + return ( + <option key={model} value={model}> + {model} + </option> + ); + })} + </select> + </div> + <div className="flex flex-col w-60"> + <label className="text-white text-sm font-semibold block mb-4"> + Safety Setting + </label> + <select + name="GeminiSafetySetting" + defaultValue={ + settings?.GeminiSafetySetting || "BLOCK_MEDIUM_AND_ABOVE" + } + required={true} + className="bg-zinc-900 border-gray-500 text-white text-sm rounded-lg block w-full p-2.5" + > + <option value="BLOCK_NONE">None</option> + <option value="BLOCK_ONLY_HIGH">Block few</option> + <option value="BLOCK_MEDIUM_AND_ABOVE"> + Block some (default) + </option> + <option value="BLOCK_LOW_AND_ABOVE">Block most</option> + </select> + </div> + </> )} </div> </div> diff --git a/server/models/systemSettings.js b/server/models/systemSettings.js index a5bb6a23c..70913fd9d 100644 --- a/server/models/systemSettings.js +++ b/server/models/systemSettings.js @@ -354,6 +354,8 @@ const SystemSettings = { // Gemini Keys GeminiLLMApiKey: !!process.env.GEMINI_API_KEY, GeminiLLMModelPref: process.env.GEMINI_LLM_MODEL_PREF || "gemini-pro", + GeminiSafetySetting: + process.env.GEMINI_SAFETY_SETTING || "BLOCK_MEDIUM_AND_ABOVE", // LMStudio Keys LMStudioBasePath: process.env.LMSTUDIO_BASE_PATH, diff --git a/server/utils/AiProviders/gemini/index.js b/server/utils/AiProviders/gemini/index.js index 3dd307aff..0c2cc7697 100644 --- a/server/utils/AiProviders/gemini/index.js +++ b/server/utils/AiProviders/gemini/index.js @@ -29,6 +29,7 @@ class GeminiLLM { this.embedder = embedder ?? new NativeEmbedder(); this.defaultTemp = 0.7; // not used for Gemini + this.safetyThreshold = this.#fetchSafetyThreshold(); } #appendContext(contextTexts = []) { @@ -43,6 +44,41 @@ class GeminiLLM { ); } + // BLOCK_NONE can be a special candidate for some fields + // https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/configure-safety-attributes#how_to_remove_automated_response_blocking_for_select_safety_attributes + // so if you are wondering why BLOCK_NONE still failed, the link above will explain why. + #fetchSafetyThreshold() { + const threshold = + process.env.GEMINI_SAFETY_SETTING ?? "BLOCK_MEDIUM_AND_ABOVE"; + const safetyThresholds = [ + "BLOCK_NONE", + "BLOCK_ONLY_HIGH", + "BLOCK_MEDIUM_AND_ABOVE", + "BLOCK_LOW_AND_ABOVE", + ]; + return safetyThresholds.includes(threshold) + ? threshold + : "BLOCK_MEDIUM_AND_ABOVE"; + } + + #safetySettings() { + return [ + { + category: "HARM_CATEGORY_HATE_SPEECH", + threshold: this.safetyThreshold, + }, + { + category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", + threshold: this.safetyThreshold, + }, + { category: "HARM_CATEGORY_HARASSMENT", threshold: this.safetyThreshold }, + { + category: "HARM_CATEGORY_DANGEROUS_CONTENT", + threshold: this.safetyThreshold, + }, + ]; + } + streamingEnabled() { return "streamGetChatCompletion" in this; } @@ -143,6 +179,7 @@ class GeminiLLM { )?.content; const chatThread = this.gemini.startChat({ history: this.formatMessages(messages), + safetySettings: this.#safetySettings(), }); const result = await chatThread.sendMessage(prompt); const response = result.response; @@ -164,6 +201,7 @@ class GeminiLLM { )?.content; const chatThread = this.gemini.startChat({ history: this.formatMessages(messages), + safetySettings: this.#safetySettings(), }); const responseStream = await chatThread.sendMessageStream(prompt); if (!responseStream.stream) diff --git a/server/utils/helpers/updateENV.js b/server/utils/helpers/updateENV.js index 401541634..c8811c9de 100644 --- a/server/utils/helpers/updateENV.js +++ b/server/utils/helpers/updateENV.js @@ -52,6 +52,10 @@ const KEY_MAPPING = { envKey: "GEMINI_LLM_MODEL_PREF", checks: [isNotEmpty, validGeminiModel], }, + GeminiSafetySetting: { + envKey: "GEMINI_SAFETY_SETTING", + checks: [validGeminiSafetySetting], + }, // LMStudio Settings LMStudioBasePath: { @@ -528,6 +532,18 @@ function validGeminiModel(input = "") { : `Invalid Model type. Must be one of ${validModels.join(", ")}.`; } +function validGeminiSafetySetting(input = "") { + const validModes = [ + "BLOCK_NONE", + "BLOCK_ONLY_HIGH", + "BLOCK_MEDIUM_AND_ABOVE", + "BLOCK_LOW_AND_ABOVE", + ]; + return validModes.includes(input) + ? null + : `Invalid Safety setting. Must be one of ${validModes.join(", ")}.`; +} + function validAnthropicModel(input = "") { const validModels = [ "claude-instant-1.2",