From 6674e5aab8ce41170e1ca56405e5f3fbab992608 Mon Sep 17 00:00:00 2001
From: Sean Hatfield <seanhatfield5@gmail.com>
Date: Tue, 15 Oct 2024 15:24:44 -0700
Subject: [PATCH] Support free-form input for workspace model for providers
 with no `/models` endpoint (#2397)

* support generic openai workspace model

* Update UI for free form input for some providers

---------

Co-authored-by: Timothy Carambat <rambat1010@gmail.com>
---
 .../LLMSelection/AzureAiOptions/index.jsx     |  17 ---
 .../WorkspaceLLMSelection/index.jsx           | 107 ++++++++++++------
 server/utils/AiProviders/azureOpenAi/index.js |   4 +-
 server/utils/AiProviders/bedrock/index.js     |   2 +-
 4 files changed, 76 insertions(+), 54 deletions(-)

diff --git a/frontend/src/components/LLMSelection/AzureAiOptions/index.jsx b/frontend/src/components/LLMSelection/AzureAiOptions/index.jsx
index c8b29e20f..c04f9f3fd 100644
--- a/frontend/src/components/LLMSelection/AzureAiOptions/index.jsx
+++ b/frontend/src/components/LLMSelection/AzureAiOptions/index.jsx
@@ -71,23 +71,6 @@ export default function AzureAiOptions({ settings }) {
             </option>
           </select>
         </div>
-
-        <div className="flex flex-col w-60">
-          <label className="text-white text-sm font-semibold block mb-3">
-            Embedding Deployment Name
-          </label>
-          <input
-            type="text"
-            name="AzureOpenAiEmbeddingModelPref"
-            className="bg-zinc-900 text-white placeholder:text-white/20 text-sm rounded-lg focus:outline-primary-button active:outline-primary-button outline-none block w-full p-2.5"
-            placeholder="Azure OpenAI embedding model deployment name"
-            defaultValue={settings?.AzureOpenAiEmbeddingModelPref}
-            required={true}
-            autoComplete="off"
-            spellCheck={false}
-          />
-        </div>
-        <div className="flex-flex-col w-60"></div>
       </div>
     </div>
   );
diff --git a/frontend/src/pages/WorkspaceSettings/ChatSettings/WorkspaceLLMSelection/index.jsx b/frontend/src/pages/WorkspaceSettings/ChatSettings/WorkspaceLLMSelection/index.jsx
index e065aff4a..7cd5653e3 100644
--- a/frontend/src/pages/WorkspaceSettings/ChatSettings/WorkspaceLLMSelection/index.jsx
+++ b/frontend/src/pages/WorkspaceSettings/ChatSettings/WorkspaceLLMSelection/index.jsx
@@ -8,15 +8,18 @@ import { useTranslation } from "react-i18next";
 import { Link } from "react-router-dom";
 import paths from "@/utils/paths";
 
-// Some providers can only be associated with a single model.
-// In that case there is no selection to be made so we can just move on.
-const NO_MODEL_SELECTION = [
-  "default",
-  "huggingface",
-  "generic-openai",
-  "bedrock",
-];
-const DISABLED_PROVIDERS = ["azure", "native"];
+// Some providers do not support model selection via /models.
+// In that case we allow the user to enter the model name manually; the value
+// is not validated, so it must exactly match the provider's model identifier.
+const FREE_FORM_LLM_SELECTION = ["bedrock", "azure", "generic-openai"];
+
+// Some providers do not support model selection via /models
+// and only have a fixed single-model they can use.
+const NO_MODEL_SELECTION = ["default", "huggingface"];
+
+// Some providers are fully disabled in this selector for ease of use.
+const DISABLED_PROVIDERS = ["native"];
+
 const LLM_DEFAULT = {
   name: "System default",
   value: "default",
@@ -65,8 +68,8 @@ export default function WorkspaceLLMSelection({
     );
     setFilteredLLMs(filtered);
   }, [LLMS, searchQuery, selectedLLM]);
-
   const selectedLLMObject = LLMS.find((llm) => llm.value === selectedLLM);
+
   return (
     <div className="border-b border-white/40 pb-8">
       <div className="flex flex-col">
@@ -155,30 +158,66 @@ export default function WorkspaceLLMSelection({
           </button>
         )}
       </div>
-      {NO_MODEL_SELECTION.includes(selectedLLM) ? (
-        <>
-          {selectedLLM !== "default" && (
-            <div className="w-full h-10 justify-center items-center flex mt-4">
-              <p className="text-sm font-base text-white text-opacity-60 text-center">
-                Multi-model support is not supported for this provider yet.
-                <br />
-                This workspace will use{" "}
-                <Link to={paths.settings.llmPreference()} className="underline">
-                  the model set for the system.
-                </Link>
-              </p>
-            </div>
-          )}
-        </>
-      ) : (
-        <div className="mt-4 flex flex-col gap-y-1">
-          <ChatModelSelection
-            provider={selectedLLM}
-            workspace={workspace}
-            setHasChanges={setHasChanges}
-          />
-        </div>
-      )}
+      <ModelSelector
+        selectedLLM={selectedLLM}
+        workspace={workspace}
+        setHasChanges={setHasChanges}
+      />
+    </div>
+  );
+}
+
+// TODO: Reuse this in the agent model selector and extract it into a generic component.
+function ModelSelector({ selectedLLM, workspace, setHasChanges }) {
+  if (NO_MODEL_SELECTION.includes(selectedLLM)) {
+    if (selectedLLM !== "default") {
+      return (
+        <div className="w-full h-10 justify-center items-center flex mt-4">
+          <p className="text-sm font-base text-white text-opacity-60 text-center">
+            Multi-model support is not supported for this provider yet.
+            <br />
+            This workspace will use{" "}
+            <Link to={paths.settings.llmPreference()} className="underline">
+              the model set for the system.
+            </Link>
+          </p>
+        </div>
+      );
+    }
+    return null;
+  }
+
+  if (FREE_FORM_LLM_SELECTION.includes(selectedLLM)) {
+    return (
+      <FreeFormLLMInput workspace={workspace} setHasChanges={setHasChanges} />
+    );
+  }
+
+  return (
+    <ChatModelSelection
+      provider={selectedLLM}
+      workspace={workspace}
+      setHasChanges={setHasChanges}
+    />
+  );
+}
+
+function FreeFormLLMInput({ workspace, setHasChanges }) {
+  const { t } = useTranslation();
+  return (
+    <div className="mt-4 flex flex-col gap-y-1">
+      <label className="block input-label">{t("chat.model.title")}</label>
+      <p className="text-white text-opacity-60 text-xs font-medium py-1.5">
+        {t("chat.model.description")}
+      </p>
+      <input
+        type="text"
+        name="chatModel"
+        defaultValue={workspace?.chatModel || ""}
+        onChange={() => setHasChanges(true)}
+        className="bg-zinc-900 text-white placeholder:text-white/20 text-sm rounded-lg focus:outline-primary-button active:outline-primary-button outline-none block w-full p-2.5"
+        placeholder="Enter model name exactly as referenced in the API (e.g., gpt-3.5-turbo)"
+      />
     </div>
   );
 }
diff --git a/server/utils/AiProviders/azureOpenAi/index.js b/server/utils/AiProviders/azureOpenAi/index.js
index 2a293d053..98c6d5153 100644
--- a/server/utils/AiProviders/azureOpenAi/index.js
+++ b/server/utils/AiProviders/azureOpenAi/index.js
@@ -5,7 +5,7 @@ const {
 } = require("../../helpers/chat/responses");
 
 class AzureOpenAiLLM {
-  constructor(embedder = null, _modelPreference = null) {
+  constructor(embedder = null, modelPreference = null) {
     const { OpenAIClient, AzureKeyCredential } = require("@azure/openai");
     if (!process.env.AZURE_OPENAI_ENDPOINT)
       throw new Error("No Azure API endpoint was set.");
@@ -16,7 +16,7 @@ class AzureOpenAiLLM {
       process.env.AZURE_OPENAI_ENDPOINT,
       new AzureKeyCredential(process.env.AZURE_OPENAI_KEY)
     );
-    this.model = process.env.OPEN_MODEL_PREF;
+    this.model = modelPreference ?? process.env.OPEN_MODEL_PREF;
     this.limits = {
       history: this.promptWindowLimit() * 0.15,
       system: this.promptWindowLimit() * 0.15,
diff --git a/server/utils/AiProviders/bedrock/index.js b/server/utils/AiProviders/bedrock/index.js
index ebff7ea29..28d0c2ce3 100644
--- a/server/utils/AiProviders/bedrock/index.js
+++ b/server/utils/AiProviders/bedrock/index.js
@@ -32,7 +32,7 @@ class AWSBedrockLLM {
   #bedrockClient({ temperature = 0.7 }) {
     const { ChatBedrockConverse } = require("@langchain/aws");
     return new ChatBedrockConverse({
-      model: process.env.AWS_BEDROCK_LLM_MODEL_PREFERENCE,
+      model: this.model,
       region: process.env.AWS_BEDROCK_LLM_REGION,
       credentials: {
         accessKeyId: process.env.AWS_BEDROCK_LLM_ACCESS_KEY_ID,