From 88d4808c5247a4f336ac73178eacc1bfb20bc1f8 Mon Sep 17 00:00:00 2001
From: Timothy Carambat <rambat1010@gmail.com>
Date: Mon, 6 Nov 2023 16:49:29 -0800
Subject: [PATCH] 315 show citations based on relevancy score (#316)

* settings for similarity score threshold and prisma schema updated

* prisma schema migration for adding similarityThreshold setting

* WIP

* Min score default change

* added similarityThreshold checking for all vectordb providers

* linting

---------

Co-authored-by: shatfield4 <seanhatfield5@gmail.com>
---
 .../Modals/MangeWorkspace/Settings/index.jsx  | 36 +++++++++++++++++++
 server/models/workspace.js                    |  1 +
 .../20231101001441_init/migration.sql         |  2 ++
 server/prisma/schema.prisma                   | 23 ++++++------
 server/utils/chats/index.js                   |  1 +
 .../utils/vectorDbProviders/chroma/index.js   | 16 +++++++--
 server/utils/vectorDbProviders/lance/index.js | 12 +++++--
 .../utils/vectorDbProviders/pinecone/index.js | 13 +++++--
 .../utils/vectorDbProviders/qdrant/index.js   | 12 +++++--
 .../utils/vectorDbProviders/weaviate/index.js | 12 +++++--
 10 files changed, 107 insertions(+), 21 deletions(-)
 create mode 100644 server/prisma/migrations/20231101001441_init/migration.sql

diff --git a/frontend/src/components/Modals/MangeWorkspace/Settings/index.jsx b/frontend/src/components/Modals/MangeWorkspace/Settings/index.jsx
index 27f3892c6..29d2bfa73 100644
--- a/frontend/src/components/Modals/MangeWorkspace/Settings/index.jsx
+++ b/frontend/src/components/Modals/MangeWorkspace/Settings/index.jsx
@@ -17,6 +17,9 @@ function castToType(key, value) {
     openAiHistory: {
       cast: (value) => Number(value),
     },
+    similarityThreshold: {
+      cast: (value) => parseFloat(value),
+    },
   };
 
   if (!definitions.hasOwnProperty(key)) return value;
@@ -233,6 +236,39 @@ export default function WorkspaceSettings({ workspace }) {
                   autoComplete="off"
                   onChange={() => setHasChanges(true)}
                 />
+                <div className="mt-4">
+                  <div className="flex flex-col">
+                    <label
+                      htmlFor="name"
+                      className="block text-sm font-medium text-white"
+                    >
+                      Document similarity threshold
+                    </label>
+                    <p className="text-white text-opacity-60 text-xs font-medium py-1.5">
+                      The minimum similarity score required for a source to be
+                      considered related to the chat. The higher the number, the
+                      more similar the source must be to the chat.
+                    </p>
+                  </div>
+                  <select
+                    name="similarityThreshold"
+                    defaultValue={workspace?.similarityThreshold ?? 0.25}
+                    className="bg-zinc-900 text-white text-sm rounded-lg focus:ring-blue-500 focus:border-blue-500 block w-full p-2.5"
+                    onChange={() => setHasChanges(true)}
+                    required={true}
+                  >
+                    <option value={0.0}>No restriction</option>
+                    <option value={0.25}>
+                      Low (similarity score &ge; .25)
+                    </option>
+                    <option value={0.5}>
+                      Medium (similarity score &ge; .50)
+                    </option>
+                    <option value={0.75}>
+                      High (similarity score &ge; .75)
+                    </option>
+                  </select>
+                </div>
               </div>
             </div>
           </div>
diff --git a/server/models/workspace.js b/server/models/workspace.js
index b11283e70..059af4c32 100644
--- a/server/models/workspace.js
+++ b/server/models/workspace.js
@@ -13,6 +13,7 @@ const Workspace = {
     "openAiHistory",
     "lastUpdatedAt",
     "openAiPrompt",
+    "similarityThreshold",
   ],
 
   new: async function (name = null, creatorId = null) {
diff --git a/server/prisma/migrations/20231101001441_init/migration.sql b/server/prisma/migrations/20231101001441_init/migration.sql
new file mode 100644
index 000000000..1c20f46e2
--- /dev/null
+++ b/server/prisma/migrations/20231101001441_init/migration.sql
@@ -0,0 +1,2 @@
+-- AlterTable
+ALTER TABLE "workspaces" ADD COLUMN "similarityThreshold" REAL DEFAULT 0.25;
diff --git a/server/prisma/schema.prisma b/server/prisma/schema.prisma
index 0f3190c97..b2661e384 100644
--- a/server/prisma/schema.prisma
+++ b/server/prisma/schema.prisma
@@ -82,17 +82,18 @@ model welcome_messages {
 }
 
 model workspaces {
-  id              Int                   @id @default(autoincrement())
-  name            String
-  slug            String                @unique
-  vectorTag       String?
-  createdAt       DateTime              @default(now())
-  openAiTemp      Float?
-  openAiHistory   Int                   @default(20)
-  lastUpdatedAt   DateTime              @default(now())
-  openAiPrompt    String?
-  workspace_users workspace_users[]
-  documents       workspace_documents[]
+  id                  Int                   @id @default(autoincrement())
+  name                String
+  slug                String                @unique
+  vectorTag           String?
+  createdAt           DateTime              @default(now())
+  openAiTemp          Float?
+  openAiHistory       Int                   @default(20)
+  lastUpdatedAt       DateTime              @default(now())
+  openAiPrompt        String?
+  similarityThreshold Float?                @default(0.25)
+  workspace_users     workspace_users[]
+  documents           workspace_documents[]
 }
 
 model workspace_chats {
diff --git a/server/utils/chats/index.js b/server/utils/chats/index.js
index b2c8b8d3e..a1ed4758d 100644
--- a/server/utils/chats/index.js
+++ b/server/utils/chats/index.js
@@ -116,6 +116,7 @@ async function chatWithWorkspace(
     namespace: workspace.slug,
     input: message,
     LLMConnector,
+    similarityThreshold: workspace?.similarityThreshold,
   });
 
   // Failed similarity search.
diff --git a/server/utils/vectorDbProviders/chroma/index.js b/server/utils/vectorDbProviders/chroma/index.js
index 2bdb0133d..c2f0257dd 100644
--- a/server/utils/vectorDbProviders/chroma/index.js
+++ b/server/utils/vectorDbProviders/chroma/index.js
@@ -59,7 +59,12 @@ const Chroma = {
     const namespace = await this.namespace(client, _namespace);
     return namespace?.vectorCount || 0;
   },
-  similarityResponse: async function (client, namespace, queryVector) {
+  similarityResponse: async function (
+    client,
+    namespace,
+    queryVector,
+    similarityThreshold = 0.25
+  ) {
     const collection = await client.getCollection({ name: namespace });
     const result = {
       contextTexts: [],
@@ -72,6 +77,11 @@ const Chroma = {
       nResults: 4,
     });
     response.ids[0].forEach((_, i) => {
+      if (
+        this.distanceToSimilarity(response.distances[0][i]) <
+        similarityThreshold
+      )
+        return;
       result.contextTexts.push(response.documents[0][i]);
       result.sourceDocuments.push(response.metadatas[0][i]);
       result.scores.push(this.distanceToSimilarity(response.distances[0][i]));
@@ -256,6 +266,7 @@ const Chroma = {
     namespace = null,
     input = "",
     LLMConnector = null,
+    similarityThreshold = 0.25,
   }) {
     if (!namespace || !input || !LLMConnector)
       throw new Error("Invalid request to performSimilaritySearch.");
@@ -273,7 +284,8 @@ const Chroma = {
     const { contextTexts, sourceDocuments } = await this.similarityResponse(
       client,
       namespace,
-      queryVector
+      queryVector,
+      similarityThreshold
     );
 
     const sources = sourceDocuments.map((metadata, i) => {
diff --git a/server/utils/vectorDbProviders/lance/index.js b/server/utils/vectorDbProviders/lance/index.js
index c18766a84..69adec662 100644
--- a/server/utils/vectorDbProviders/lance/index.js
+++ b/server/utils/vectorDbProviders/lance/index.js
@@ -54,7 +54,12 @@ const LanceDb = {
   embedder: function () {
     return new OpenAIEmbeddings({ openAIApiKey: process.env.OPEN_AI_KEY });
   },
-  similarityResponse: async function (client, namespace, queryVector) {
+  similarityResponse: async function (
+    client,
+    namespace,
+    queryVector,
+    similarityThreshold = 0.25
+  ) {
     const collection = await client.openTable(namespace);
     const result = {
       contextTexts: [],
@@ -69,6 +74,7 @@ const LanceDb = {
       .execute();
 
     response.forEach((item) => {
+      if (this.distanceToSimilarity(item.score) < similarityThreshold) return;
       const { vector: _, ...rest } = item;
       result.contextTexts.push(rest.text);
       result.sourceDocuments.push(rest);
@@ -229,6 +235,7 @@ const LanceDb = {
     namespace = null,
     input = "",
     LLMConnector = null,
+    similarityThreshold = 0.25,
   }) {
     if (!namespace || !input || !LLMConnector)
       throw new Error("Invalid request to performSimilaritySearch.");
@@ -246,7 +253,8 @@ const LanceDb = {
     const { contextTexts, sourceDocuments } = await this.similarityResponse(
       client,
       namespace,
-      queryVector
+      queryVector,
+      similarityThreshold
     );
 
     const sources = sourceDocuments.map((metadata, i) => {
diff --git a/server/utils/vectorDbProviders/pinecone/index.js b/server/utils/vectorDbProviders/pinecone/index.js
index f9600cf0c..3b0e09e96 100644
--- a/server/utils/vectorDbProviders/pinecone/index.js
+++ b/server/utils/vectorDbProviders/pinecone/index.js
@@ -36,7 +36,12 @@ const Pinecone = {
     const namespace = await this.namespace(pineconeIndex, _namespace);
     return namespace?.vectorCount || 0;
   },
-  similarityResponse: async function (index, namespace, queryVector) {
+  similarityResponse: async function (
+    index,
+    namespace,
+    queryVector,
+    similarityThreshold = 0.25
+  ) {
     const result = {
       contextTexts: [],
       sourceDocuments: [],
@@ -52,6 +57,7 @@ const Pinecone = {
     });
 
     response.matches.forEach((match) => {
+      if (match.score < similarityThreshold) return;
       result.contextTexts.push(match.metadata.text);
       result.sourceDocuments.push(match);
       result.scores.push(match.score);
@@ -59,6 +65,7 @@ const Pinecone = {
 
     return result;
   },
+
   namespace: async function (index, namespace = null) {
     if (!namespace) throw new Error("No namespace value provided.");
     const { namespaces } = await index.describeIndexStats1();
@@ -225,6 +232,7 @@ const Pinecone = {
     namespace = null,
     input = "",
     LLMConnector = null,
+    similarityThreshold = 0.25,
   }) {
     if (!namespace || !input || !LLMConnector)
       throw new Error("Invalid request to performSimilaritySearch.");
@@ -239,7 +247,8 @@ const Pinecone = {
     const { contextTexts, sourceDocuments } = await this.similarityResponse(
       pineconeIndex,
       namespace,
-      queryVector
+      queryVector,
+      similarityThreshold
     );
 
     const sources = sourceDocuments.map((metadata, i) => {
diff --git a/server/utils/vectorDbProviders/qdrant/index.js b/server/utils/vectorDbProviders/qdrant/index.js
index c565daa7a..86d941548 100644
--- a/server/utils/vectorDbProviders/qdrant/index.js
+++ b/server/utils/vectorDbProviders/qdrant/index.js
@@ -45,7 +45,12 @@ const QDrant = {
     const namespace = await this.namespace(client, _namespace);
     return namespace?.vectorCount || 0;
   },
-  similarityResponse: async function (_client, namespace, queryVector) {
+  similarityResponse: async function (
+    _client,
+    namespace,
+    queryVector,
+    similarityThreshold = 0.25
+  ) {
     const { client } = await this.connect();
     const result = {
       contextTexts: [],
@@ -60,6 +65,7 @@ const QDrant = {
     });
 
     responses.forEach((response) => {
+      if (response.score < similarityThreshold) return;
       result.contextTexts.push(response?.payload?.text || "");
       result.sourceDocuments.push({
         ...(response?.payload || {}),
@@ -265,6 +271,7 @@ const QDrant = {
     namespace = null,
     input = "",
     LLMConnector = null,
+    similarityThreshold = 0.25,
   }) {
     if (!namespace || !input || !LLMConnector)
       throw new Error("Invalid request to performSimilaritySearch.");
@@ -282,7 +289,8 @@ const QDrant = {
     const { contextTexts, sourceDocuments } = await this.similarityResponse(
       client,
       namespace,
-      queryVector
+      queryVector,
+      similarityThreshold
     );
 
     const sources = sourceDocuments.map((metadata, i) => {
diff --git a/server/utils/vectorDbProviders/weaviate/index.js b/server/utils/vectorDbProviders/weaviate/index.js
index 052ad5861..93c63e8b9 100644
--- a/server/utils/vectorDbProviders/weaviate/index.js
+++ b/server/utils/vectorDbProviders/weaviate/index.js
@@ -72,7 +72,12 @@ const Weaviate = {
       return 0;
     }
   },
-  similarityResponse: async function (client, namespace, queryVector) {
+  similarityResponse: async function (
+    client,
+    namespace,
+    queryVector,
+    similarityThreshold = 0.25
+  ) {
     const result = {
       contextTexts: [],
       sourceDocuments: [],
@@ -97,6 +102,7 @@ const Weaviate = {
         _additional: { id, certainty },
         ...rest
       } = response;
+      if (certainty < similarityThreshold) return;
       result.contextTexts.push(rest.text);
       result.sourceDocuments.push({ ...rest, id });
       result.scores.push(certainty);
@@ -336,6 +342,7 @@ const Weaviate = {
     namespace = null,
     input = "",
     LLMConnector = null,
+    similarityThreshold = 0.25,
   }) {
     if (!namespace || !input || !LLMConnector)
       throw new Error("Invalid request to performSimilaritySearch.");
@@ -353,7 +360,8 @@ const Weaviate = {
     const { contextTexts, sourceDocuments } = await this.similarityResponse(
       client,
       namespace,
-      queryVector
+      queryVector,
+      similarityThreshold
     );
 
     const sources = sourceDocuments.map((metadata, i) => {