From 28eba636e941e8c6d2a54e5209b0dbcc2c892480 Mon Sep 17 00:00:00 2001
From: Timothy Carambat <rambat1010@gmail.com>
Date: Mon, 20 May 2024 13:17:00 -0500
Subject: [PATCH] Allow setting of safety thresholds for Gemini (#1466)

* Allow setting of safety thresholds for Gemini

* linting
---
 .../LLMSelection/GeminiLLMOptions/index.jsx   | 60 +++++++++++++------
 server/models/systemSettings.js               |  2 +
 server/utils/AiProviders/gemini/index.js      | 38 ++++++++++++
 server/utils/helpers/updateENV.js             | 16 +++++
 4 files changed, 97 insertions(+), 19 deletions(-)

diff --git a/frontend/src/components/LLMSelection/GeminiLLMOptions/index.jsx b/frontend/src/components/LLMSelection/GeminiLLMOptions/index.jsx
index 8cb513f31..d2846704d 100644
--- a/frontend/src/components/LLMSelection/GeminiLLMOptions/index.jsx
+++ b/frontend/src/components/LLMSelection/GeminiLLMOptions/index.jsx
@@ -19,25 +19,47 @@ export default function GeminiLLMOptions({ settings }) {
         </div>
 
         {!settings?.credentialsOnly && (
-          <div className="flex flex-col w-60">
-            <label className="text-white text-sm font-semibold block mb-4">
-              Chat Model Selection
-            </label>
-            <select
-              name="GeminiLLMModelPref"
-              defaultValue={settings?.GeminiLLMModelPref || "gemini-pro"}
-              required={true}
-              className="bg-zinc-900 border-gray-500 text-white text-sm rounded-lg block w-full p-2.5"
-            >
-              {["gemini-pro", "gemini-1.5-pro-latest"].map((model) => {
-                return (
-                  <option key={model} value={model}>
-                    {model}
-                  </option>
-                );
-              })}
-            </select>
-          </div>
+          <>
+            <div className="flex flex-col w-60">
+              <label className="text-white text-sm font-semibold block mb-4">
+                Chat Model Selection
+              </label>
+              <select
+                name="GeminiLLMModelPref"
+                defaultValue={settings?.GeminiLLMModelPref || "gemini-pro"}
+                required={true}
+                className="bg-zinc-900 border-gray-500 text-white text-sm rounded-lg block w-full p-2.5"
+              >
+                {["gemini-pro", "gemini-1.5-pro-latest"].map((model) => {
+                  return (
+                    <option key={model} value={model}>
+                      {model}
+                    </option>
+                  );
+                })}
+              </select>
+            </div>
+            <div className="flex flex-col w-60">
+              <label className="text-white text-sm font-semibold block mb-4">
+                Safety Setting
+              </label>
+              <select
+                name="GeminiSafetySetting"
+                defaultValue={
+                  settings?.GeminiSafetySetting || "BLOCK_MEDIUM_AND_ABOVE"
+                }
+                required={true}
+                className="bg-zinc-900 border-gray-500 text-white text-sm rounded-lg block w-full p-2.5"
+              >
+                <option value="BLOCK_NONE">None</option>
+                <option value="BLOCK_ONLY_HIGH">Block few</option>
+                <option value="BLOCK_MEDIUM_AND_ABOVE">
+                  Block some (default)
+                </option>
+                <option value="BLOCK_LOW_AND_ABOVE">Block most</option>
+              </select>
+            </div>
+          </>
         )}
       </div>
     </div>
diff --git a/server/models/systemSettings.js b/server/models/systemSettings.js
index a5bb6a23c..70913fd9d 100644
--- a/server/models/systemSettings.js
+++ b/server/models/systemSettings.js
@@ -354,6 +354,8 @@ const SystemSettings = {
       // Gemini Keys
       GeminiLLMApiKey: !!process.env.GEMINI_API_KEY,
       GeminiLLMModelPref: process.env.GEMINI_LLM_MODEL_PREF || "gemini-pro",
+      GeminiSafetySetting:
+        process.env.GEMINI_SAFETY_SETTING || "BLOCK_MEDIUM_AND_ABOVE",
 
       // LMStudio Keys
       LMStudioBasePath: process.env.LMSTUDIO_BASE_PATH,
diff --git a/server/utils/AiProviders/gemini/index.js b/server/utils/AiProviders/gemini/index.js
index 3dd307aff..0c2cc7697 100644
--- a/server/utils/AiProviders/gemini/index.js
+++ b/server/utils/AiProviders/gemini/index.js
@@ -29,6 +29,7 @@ class GeminiLLM {
 
     this.embedder = embedder ?? new NativeEmbedder();
     this.defaultTemp = 0.7; // not used for Gemini
+    this.safetyThreshold = this.#fetchSafetyThreshold();
   }
 
   #appendContext(contextTexts = []) {
@@ -43,6 +44,41 @@ class GeminiLLM {
     );
   }
 
+  // Resolve the threshold from GEMINI_SAFETY_SETTING; a missing or unrecognized
+  // value falls back to BLOCK_MEDIUM_AND_ABOVE. If BLOCK_NONE still gets blocked,
+  // it is a special case for some safety attributes - the link explains why:
+  // https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/configure-safety-attributes#how_to_remove_automated_response_blocking_for_select_safety_attributes
+  #fetchSafetyThreshold() {
+    const fallback = "BLOCK_MEDIUM_AND_ABOVE";
+    const requested = process.env.GEMINI_SAFETY_SETTING ?? fallback;
+    const allowed = [
+      "BLOCK_NONE",
+      "BLOCK_ONLY_HIGH",
+      "BLOCK_MEDIUM_AND_ABOVE",
+      "BLOCK_LOW_AND_ABOVE",
+    ];
+    if (allowed.includes(requested)) return requested;
+    return fallback;
+  }
+
+  /**
+   * Builds the `safetySettings` array passed to `startChat`.
+   * @returns {{category: string, threshold: string}[]} One entry per harm
+   * category listed below; every entry carries the same threshold, taken
+   * from `this.safetyThreshold`.
+   */
+  #safetySettings() {
+    const threshold = this.safetyThreshold;
+    const categories = [
+      "HARM_CATEGORY_HATE_SPEECH",
+      "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+      "HARM_CATEGORY_HARASSMENT",
+      "HARM_CATEGORY_DANGEROUS_CONTENT",
+    ];
+    // Shorthand keeps the key order: category first, then threshold.
+    return categories.map((category) => ({ category, threshold }));
+  }
+
   streamingEnabled() {
     return "streamGetChatCompletion" in this;
   }
@@ -143,6 +179,7 @@ class GeminiLLM {
     )?.content;
     const chatThread = this.gemini.startChat({
       history: this.formatMessages(messages),
+      safetySettings: this.#safetySettings(),
     });
     const result = await chatThread.sendMessage(prompt);
     const response = result.response;
@@ -164,6 +201,7 @@ class GeminiLLM {
     )?.content;
     const chatThread = this.gemini.startChat({
       history: this.formatMessages(messages),
+      safetySettings: this.#safetySettings(),
     });
     const responseStream = await chatThread.sendMessageStream(prompt);
     if (!responseStream.stream)
diff --git a/server/utils/helpers/updateENV.js b/server/utils/helpers/updateENV.js
index 401541634..c8811c9de 100644
--- a/server/utils/helpers/updateENV.js
+++ b/server/utils/helpers/updateENV.js
@@ -52,6 +52,10 @@ const KEY_MAPPING = {
     envKey: "GEMINI_LLM_MODEL_PREF",
     checks: [isNotEmpty, validGeminiModel],
   },
+  GeminiSafetySetting: {
+    envKey: "GEMINI_SAFETY_SETTING",
+    checks: [validGeminiSafetySetting],
+  },
 
   // LMStudio Settings
   LMStudioBasePath: {
@@ -528,6 +532,18 @@ function validGeminiModel(input = "") {
     : `Invalid Model type. Must be one of ${validModels.join(", ")}.`;
 }
 
+// Returns null when `input` is a known safety threshold, else an error string.
+function validGeminiSafetySetting(input = "") {
+  const allowed = [
+    "BLOCK_NONE",
+    "BLOCK_ONLY_HIGH",
+    "BLOCK_MEDIUM_AND_ABOVE",
+    "BLOCK_LOW_AND_ABOVE",
+  ];
+  if (allowed.includes(input)) return null;
+  return `Invalid Safety setting. Must be one of ${allowed.join(", ")}.`;
+}
+
 function validAnthropicModel(input = "") {
   const validModels = [
     "claude-instant-1.2",