diff --git a/frontend/src/components/Modals/MangeWorkspace/Settings/index.jsx b/frontend/src/components/Modals/MangeWorkspace/Settings/index.jsx index 27f3892c6..29d2bfa73 100644 --- a/frontend/src/components/Modals/MangeWorkspace/Settings/index.jsx +++ b/frontend/src/components/Modals/MangeWorkspace/Settings/index.jsx @@ -17,6 +17,9 @@ function castToType(key, value) { openAiHistory: { cast: (value) => Number(value), }, + similarityThreshold: { + cast: (value) => parseFloat(value), + }, }; if (!definitions.hasOwnProperty(key)) return value; @@ -233,6 +236,39 @@ export default function WorkspaceSettings({ workspace }) { autoComplete="off" onChange={() => setHasChanges(true)} /> + <div className="mt-4"> + <div className="flex flex-col"> + <label + htmlFor="name" + className="block text-sm font-medium text-white" + > + Document similarity threshold + </label> + <p className="text-white text-opacity-60 text-xs font-medium py-1.5"> + The minimum similarity score required for a source to be + considered related to the chat. The higher the number, the + more similar the source must be to the chat. + </p> + </div> + <select + name="similarityThreshold" + defaultValue={workspace?.similarityThreshold ?? 0.25} + className="bg-zinc-900 text-white text-sm rounded-lg focus:ring-blue-500 focus:border-blue-500 block w-full p-2.5" + onChange={() => setHasChanges(true)} + required={true} + > + <option value={0.0}>No restriction</option> + <option value={0.25}> + Low (similarity score ≥ .25) + </option> + <option value={0.5}> + Medium (similarity score ≥ .50) + </option> + <option value={0.75}> + High (similarity score ≥ .75) + </option> + </select> + </div> </div> </div> </div> diff --git a/server/models/workspace.js b/server/models/workspace.js index b11283e70..059af4c32 100644 --- a/server/models/workspace.js +++ b/server/models/workspace.js @@ -13,6 +13,7 @@ const Workspace = { "openAiHistory", "lastUpdatedAt", "openAiPrompt", + "similarityThreshold", ], new: async function (name = null, creatorId = null) { diff --git a/server/prisma/migrations/20231101001441_init/migration.sql b/server/prisma/migrations/20231101001441_init/migration.sql new file mode 100644 index 000000000..1c20f46e2 --- /dev/null +++ b/server/prisma/migrations/20231101001441_init/migration.sql @@ -0,0 +1,2 @@ +-- AlterTable +ALTER TABLE "workspaces" ADD COLUMN "similarityThreshold" REAL DEFAULT 0.25; diff --git a/server/prisma/schema.prisma b/server/prisma/schema.prisma index 0f3190c97..b2661e384 100644 --- a/server/prisma/schema.prisma +++ b/server/prisma/schema.prisma @@ -82,17 +82,18 @@ model welcome_messages { } model workspaces { - id Int @id @default(autoincrement()) - name String - slug String @unique - vectorTag String? - createdAt DateTime @default(now()) - openAiTemp Float? - openAiHistory Int @default(20) - lastUpdatedAt DateTime @default(now()) - openAiPrompt String? - workspace_users workspace_users[] - documents workspace_documents[] + id Int @id @default(autoincrement()) + name String + slug String @unique + vectorTag String? + createdAt DateTime @default(now()) + openAiTemp Float? + openAiHistory Int @default(20) + lastUpdatedAt DateTime @default(now()) + openAiPrompt String? + similarityThreshold Float? @default(0.25) + workspace_users workspace_users[] + documents workspace_documents[] } model workspace_chats { diff --git a/server/utils/chats/index.js b/server/utils/chats/index.js index b2c8b8d3e..a1ed4758d 100644 --- a/server/utils/chats/index.js +++ b/server/utils/chats/index.js @@ -116,6 +116,7 @@ async function chatWithWorkspace( namespace: workspace.slug, input: message, LLMConnector, + similarityThreshold: workspace?.similarityThreshold, }); // Failed similarity search. diff --git a/server/utils/vectorDbProviders/chroma/index.js b/server/utils/vectorDbProviders/chroma/index.js index 2bdb0133d..c2f0257dd 100644 --- a/server/utils/vectorDbProviders/chroma/index.js +++ b/server/utils/vectorDbProviders/chroma/index.js @@ -59,7 +59,12 @@ const Chroma = { const namespace = await this.namespace(client, _namespace); return namespace?.vectorCount || 0; }, - similarityResponse: async function (client, namespace, queryVector) { + similarityResponse: async function ( + client, + namespace, + queryVector, + similarityThreshold = 0.25 + ) { const collection = await client.getCollection({ name: namespace }); const result = { contextTexts: [], @@ -72,6 +77,11 @@ const Chroma = { nResults: 4, }); response.ids[0].forEach((_, i) => { + if ( + this.distanceToSimilarity(response.distances[0][i]) < + similarityThreshold + ) + return; result.contextTexts.push(response.documents[0][i]); result.sourceDocuments.push(response.metadatas[0][i]); result.scores.push(this.distanceToSimilarity(response.distances[0][i])); @@ -256,6 +266,7 @@ const Chroma = { namespace = null, input = "", LLMConnector = null, + similarityThreshold = 0.25, }) { if (!namespace || !input || !LLMConnector) throw new Error("Invalid request to performSimilaritySearch."); @@ -273,7 +284,8 @@ const Chroma = { const { contextTexts, sourceDocuments } = await this.similarityResponse( client, namespace, - queryVector + queryVector, + similarityThreshold ); const sources = sourceDocuments.map((metadata, i) => { diff --git a/server/utils/vectorDbProviders/lance/index.js b/server/utils/vectorDbProviders/lance/index.js index c18766a84..69adec662 100644 --- a/server/utils/vectorDbProviders/lance/index.js +++ b/server/utils/vectorDbProviders/lance/index.js @@ -54,7 +54,12 @@ const LanceDb = { embedder: function () { return new OpenAIEmbeddings({ openAIApiKey: process.env.OPEN_AI_KEY }); }, - similarityResponse: async function (client, namespace, queryVector) { + similarityResponse: async function ( + client, + namespace, + queryVector, + similarityThreshold = 0.25 + ) { const collection = await client.openTable(namespace); const result = { contextTexts: [], @@ -69,6 +74,7 @@ const LanceDb = { .execute(); response.forEach((item) => { + if (this.distanceToSimilarity(item.score) < similarityThreshold) return; const { vector: _, ...rest } = item; result.contextTexts.push(rest.text); result.sourceDocuments.push(rest); @@ -229,6 +235,7 @@ const LanceDb = { namespace = null, input = "", LLMConnector = null, + similarityThreshold = 0.25, }) { if (!namespace || !input || !LLMConnector) throw new Error("Invalid request to performSimilaritySearch."); @@ -246,7 +253,8 @@ const LanceDb = { const { contextTexts, sourceDocuments } = await this.similarityResponse( client, namespace, - queryVector + queryVector, + similarityThreshold ); const sources = sourceDocuments.map((metadata, i) => { diff --git a/server/utils/vectorDbProviders/pinecone/index.js b/server/utils/vectorDbProviders/pinecone/index.js index f9600cf0c..3b0e09e96 100644 --- a/server/utils/vectorDbProviders/pinecone/index.js +++ b/server/utils/vectorDbProviders/pinecone/index.js @@ -36,7 +36,12 @@ const Pinecone = { const namespace = await this.namespace(pineconeIndex, _namespace); return namespace?.vectorCount || 0; }, - similarityResponse: async function (index, namespace, queryVector) { + similarityResponse: async function ( + index, + namespace, + queryVector, + similarityThreshold = 0.25 + ) { const result = { contextTexts: [], sourceDocuments: [], @@ -52,6 +57,7 @@ const Pinecone = { }); response.matches.forEach((match) => { + if (match.score < similarityThreshold) return; result.contextTexts.push(match.metadata.text); result.sourceDocuments.push(match); result.scores.push(match.score); @@ -59,6 +65,7 @@ const Pinecone = { return result; }, + namespace: async function (index, namespace = null) { if (!namespace) throw new Error("No namespace value provided."); const { namespaces } = await index.describeIndexStats1(); @@ -225,6 +232,7 @@ const Pinecone = { namespace = null, input = "", LLMConnector = null, + similarityThreshold = 0.25, }) { if (!namespace || !input || !LLMConnector) throw new Error("Invalid request to performSimilaritySearch."); @@ -239,7 +247,8 @@ const Pinecone = { const { contextTexts, sourceDocuments } = await this.similarityResponse( pineconeIndex, namespace, - queryVector + queryVector, + similarityThreshold ); const sources = sourceDocuments.map((metadata, i) => { diff --git a/server/utils/vectorDbProviders/qdrant/index.js b/server/utils/vectorDbProviders/qdrant/index.js index c565daa7a..86d941548 100644 --- a/server/utils/vectorDbProviders/qdrant/index.js +++ b/server/utils/vectorDbProviders/qdrant/index.js @@ -45,7 +45,12 @@ const QDrant = { const namespace = await this.namespace(client, _namespace); return namespace?.vectorCount || 0; }, - similarityResponse: async function (_client, namespace, queryVector) { + similarityResponse: async function ( + _client, + namespace, + queryVector, + similarityThreshold = 0.25 + ) { const { client } = await this.connect(); const result = { contextTexts: [], @@ -60,6 +65,7 @@ const QDrant = { }); responses.forEach((response) => { + if (response.score < similarityThreshold) return; result.contextTexts.push(response?.payload?.text || ""); result.sourceDocuments.push({ ...(response?.payload || {}), @@ -265,6 +271,7 @@ const QDrant = { namespace = null, input = "", LLMConnector = null, + similarityThreshold = 0.25, }) { if (!namespace || !input || !LLMConnector) throw new Error("Invalid request to performSimilaritySearch."); @@ -282,7 +289,8 @@ const QDrant = { const { contextTexts, sourceDocuments } = await this.similarityResponse( client, namespace, - queryVector + queryVector, + similarityThreshold ); const sources = sourceDocuments.map((metadata, i) => { diff --git a/server/utils/vectorDbProviders/weaviate/index.js b/server/utils/vectorDbProviders/weaviate/index.js index 052ad5861..93c63e8b9 100644 --- a/server/utils/vectorDbProviders/weaviate/index.js +++ b/server/utils/vectorDbProviders/weaviate/index.js @@ -72,7 +72,12 @@ const Weaviate = { return 0; } }, - similarityResponse: async function (client, namespace, queryVector) { + similarityResponse: async function ( + client, + namespace, + queryVector, + similarityThreshold = 0.25 + ) { const result = { contextTexts: [], sourceDocuments: [], @@ -97,6 +102,7 @@ const Weaviate = { _additional: { id, certainty }, ...rest } = response; + if (certainty < similarityThreshold) return; result.contextTexts.push(rest.text); result.sourceDocuments.push({ ...rest, id }); result.scores.push(certainty); @@ -336,6 +342,7 @@ const Weaviate = { namespace = null, input = "", LLMConnector = null, + similarityThreshold = 0.25, }) { if (!namespace || !input || !LLMConnector) throw new Error("Invalid request to performSimilaritySearch."); @@ -353,7 +360,8 @@ const Weaviate = { const { contextTexts, sourceDocuments } = await this.similarityResponse( client, namespace, - queryVector + queryVector, + similarityThreshold ); const sources = sourceDocuments.map((metadata, i) => {