Reranker option for RAG ()

* Reranker WIP

* add caching and singleton loading

* Add field to workspaces for vectorSearchMode
Add UI for LanceDB to change the mode
update all search endpoints to pass in the reranker prop if the provider can use it

* update hint text

* When reranking, swap score to rerank score

* update optchain
Timothy Carambat 2025-01-02 14:27:52 -08:00 committed by GitHub
parent bb5c3b7e0d
commit ad01df8790
16 changed files with 339 additions and 9 deletions
frontend/src/pages/WorkspaceSettings
  VectorDatabase
    VectorSearchMode
      index.jsx
    index.jsx
server
  endpoints/api/workspace
  models
  prisma
    migrations/20250102204948_init
    schema.prisma
  storage/models
  utils
    EmbeddingRerankers/native
    agents/aibitat/plugins
    chats
    helpers
    vectorDbProviders/lance


@@ -0,0 +1,51 @@
import { useState } from "react";
// We don't support reranking for all vector DBs yet due to the complexities of how each provider
// returns information. We need to normalize the response data so the reranker can be used with each provider.
const supportedVectorDBs = ["lancedb"];
const hint = {
default: {
title: "Default",
description:
"This is the fastest performance, but may not return the most relevant results leading to model hallucinations.",
},
rerank: {
title: "Accuracy Optimized",
description:
"LLM responses may take longer to generate, but your responses will be more accurate and relevant.",
},
};
export default function VectorSearchMode({ workspace, setHasChanges }) {
const [selection, setSelection] = useState(
workspace?.vectorSearchMode ?? "default"
);
if (!workspace?.vectorDB || !supportedVectorDBs.includes(workspace?.vectorDB))
return null;
return (
<div>
<div className="flex flex-col">
<label htmlFor="name" className="block input-label">
Search Preference
</label>
</div>
<select
name="vectorSearchMode"
value={selection}
className="border-none bg-theme-settings-input-bg text-white text-sm mt-2 rounded-lg focus:outline-primary-button active:outline-primary-button outline-none block w-full p-2.5"
onChange={(e) => {
setSelection(e.target.value);
setHasChanges(true);
}}
required={true}
>
<option value="default">Default</option>
<option value="rerank">Accuracy Optimized</option>
</select>
<p className="text-white text-opacity-60 text-xs font-medium py-1.5">
{hint[selection]?.description}
</p>
</div>
);
}


@@ -7,6 +7,7 @@ import MaxContextSnippets from "./MaxContextSnippets";
import DocumentSimilarityThreshold from "./DocumentSimilarityThreshold";
import ResetDatabase from "./ResetDatabase";
import VectorCount from "./VectorCount";
import VectorSearchMode from "./VectorSearchMode";
export default function VectorDatabase({ workspace }) {
const [hasChanges, setHasChanges] = useState(false);
@@ -43,6 +44,7 @@ export default function VectorDatabase({ workspace }) {
<VectorDBIdentifier workspace={workspace} />
<VectorCount reload={true} workspace={workspace} />
</div>
<VectorSearchMode workspace={workspace} setHasChanges={setHasChanges} />
<MaxContextSnippets workspace={workspace} setHasChanges={setHasChanges} />
<DocumentSimilarityThreshold
workspace={workspace}


@@ -23,6 +23,7 @@ import Members from "./Members";
import WorkspaceAgentConfiguration from "./AgentConfig";
import useUser from "@/hooks/useUser";
import { useTranslation } from "react-i18next";
import System from "@/models/system";
const TABS = {
"general-appearance": GeneralAppearance,
@@ -59,9 +60,11 @@ function ShowWorkspaceChat() {
return;
}
const _settings = await System.keys();
const suggestedMessages = await Workspace.getSuggestedMessages(slug);
setWorkspace({
..._workspace,
vectorDB: _settings?.VectorDB,
suggestedMessages,
});
setLoading(false);


@@ -961,6 +961,7 @@ function apiWorkspaceEndpoints(app) {
LLMConnector: getLLMProvider(),
similarityThreshold: parseSimilarityThreshold(),
topN: parseTopN(),
rerank: workspace?.vectorSearchMode === "rerank",
});
response.status(200).json({


@@ -34,6 +34,7 @@ const Workspace = {
"agentProvider",
"agentModel",
"queryRefusalResponse",
"vectorSearchMode",
],
validations: {
@@ -99,6 +100,15 @@ const Workspace = {
if (!value || typeof value !== "string") return null;
return String(value);
},
vectorSearchMode: (value) => {
if (
!value ||
typeof value !== "string" ||
!["default", "rerank"].includes(value)
)
return "default";
return value;
},
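// e.g., "default" and "rerank" pass through; any other value (null, a number, an unknown string) coerces to "default".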
},
/**


@@ -0,0 +1,2 @@
-- AlterTable
ALTER TABLE "workspaces" ADD COLUMN "vectorSearchMode" TEXT DEFAULT 'default';


@@ -137,6 +137,7 @@ model workspaces {
agentProvider String?
agentModel String?
queryRefusalResponse String?
vectorSearchMode String? @default("default")
workspace_users workspace_users[]
documents workspace_documents[]
workspace_suggested_messages workspace_suggested_messages[]


@@ -3,4 +3,5 @@ downloaded/*
!downloaded/.placeholder
openrouter
apipie
novita
mixedbread-ai*


@@ -0,0 +1,153 @@
const path = require("path");
const fs = require("fs");
class NativeEmbeddingReranker {
static #model = null;
static #tokenizer = null;
static #transformers = null;
constructor() {
// Model Card: https://huggingface.co/Xenova/ms-marco-MiniLM-L-6-v2 (fast on CPU: 18 docs = ~1.6s)
// An alternative is the mixedbread-ai/mxbai-rerank-xsmall-v1 model, but it is much slower on CPU (18 docs = ~6s).
this.model = "Xenova/ms-marco-MiniLM-L-6-v2";
this.cacheDir = path.resolve(
process.env.STORAGE_DIR
? path.resolve(process.env.STORAGE_DIR, `models`)
: path.resolve(__dirname, `../../../storage/models`)
);
this.modelPath = path.resolve(this.cacheDir, ...this.model.split("/"));
// Create the models directory if it does not already exist (e.g., in pre-existing installations)
if (!fs.existsSync(this.cacheDir)) fs.mkdirSync(this.cacheDir);
this.log("Initialized");
}
log(text, ...args) {
console.log(`\x1b[36m[NativeEmbeddingReranker]\x1b[0m ${text}`, ...args);
}
/**
* Preloads the reranker model and tokenizer.
* This reduces the latency of the first rerank call by downloading and
* initializing the models ahead of time, so that call does not have to wait.
*/
async preload() {
try {
this.log(`Preloading reranker suite...`);
await this.initClient();
this.log(
`Preloaded reranker suite. Reranking is available as a service now.`
);
return;
} catch (e) {
console.error(e);
this.log(
`Failed to preload reranker suite. Reranking will be available on the first rerank call.`
);
return;
}
}
async initClient() {
if (NativeEmbeddingReranker.#transformers) {
this.log(`Reranker suite already initialized - reusing.`);
return;
}
await import("@xenova/transformers").then(
async ({ AutoModelForSequenceClassification, AutoTokenizer }) => {
this.log(`Loading reranker suite...`);
NativeEmbeddingReranker.#transformers = {
AutoModelForSequenceClassification,
AutoTokenizer,
};
await this.#getPreTrainedModel();
await this.#getPreTrainedTokenizer();
}
);
return;
}
async #getPreTrainedModel() {
if (NativeEmbeddingReranker.#model) {
this.log(`Loading model from singleton...`);
return NativeEmbeddingReranker.#model;
}
const model =
await NativeEmbeddingReranker.#transformers.AutoModelForSequenceClassification.from_pretrained(
this.model,
{
progress_callback: (p) =>
p.status === "progress" &&
this.log(`Loading model ${this.model}... ${p?.progress}%`),
cache_dir: this.cacheDir,
}
);
this.log(`Loaded model ${this.model}`);
NativeEmbeddingReranker.#model = model;
return model;
}
async #getPreTrainedTokenizer() {
if (NativeEmbeddingReranker.#tokenizer) {
this.log(`Loading tokenizer from singleton...`);
return NativeEmbeddingReranker.#tokenizer;
}
const tokenizer =
await NativeEmbeddingReranker.#transformers.AutoTokenizer.from_pretrained(
this.model,
{
progress_callback: (p) =>
p.status === "progress" &&
this.log(`Loading tokenizer ${this.model}... ${p?.progress}%`),
cache_dir: this.cacheDir,
}
);
this.log(`Loaded tokenizer ${this.model}`);
NativeEmbeddingReranker.#tokenizer = tokenizer;
return tokenizer;
}
/**
* Reranks a list of documents based on the query.
* @param {string} query - The query to rerank the documents against.
* @param {{text: string}[]} documents - The list of document text snippets to rerank. Should be output from a vector search.
* @param {Object} options - The options for the reranking.
* @param {number} options.topK - The number of top documents to return.
* @returns {Promise<any[]>} - The reranked list of documents.
*/
async rerank(query, documents, options = { topK: 4 }) {
await this.initClient();
const model = NativeEmbeddingReranker.#model;
const tokenizer = NativeEmbeddingReranker.#tokenizer;
const start = Date.now();
this.log(`Reranking ${documents.length} documents...`);
const inputs = tokenizer(new Array(documents.length).fill(query), {
text_pair: documents.map((doc) => doc.text),
padding: true,
truncation: true,
});
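// The tokenizer call above builds (query, document-text) pairs; the cross-encoder model scores each pair jointly.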
const { logits } = await model(inputs);
const reranked = logits
.sigmoid()
.tolist()
.map(([score], i) => ({
rerank_corpus_id: i,
rerank_score: score,
...documents[i],
}))
.sort((a, b) => b.rerank_score - a.rerank_score)
.slice(0, options.topK);
this.log(
`Reranking ${documents.length} documents to top ${options.topK} took ${Date.now() - start}ms`
);
return reranked;
}
}
module.exports = {
NativeEmbeddingReranker,
};
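For reference, a minimal usage sketch of the class above (hypothetical caller: the query, documents, and require path are illustrative; preload(), rerank(), and the returned rerank_score/rerank_corpus_id fields are as defined in this file):

const { NativeEmbeddingReranker } = require("./utils/EmbeddingRerankers/native");

(async () => {
  const reranker = new NativeEmbeddingReranker();
  // Optional warm-up; otherwise the first rerank() call loads the model via initClient().
  await reranker.preload();
  const reranked = await reranker.rerank(
    "How does vector search work?", // illustrative query
    [
      { text: "Vector search returns the embeddings nearest to the query vector." },
      { text: "An unrelated snippet about billing." },
    ],
    { topK: 1 }
  );
  // Each result keeps its original fields plus rerank_score and rerank_corpus_id.
  console.log(reranked);
})();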


@@ -95,6 +95,7 @@ const memory = {
input: query,
LLMConnector,
topN: workspace?.topN ?? 4,
rerank: workspace?.vectorSearchMode === "rerank",
});
if (contextTexts.length === 0) {


@@ -180,6 +180,7 @@ async function chatSync({
similarityThreshold: workspace?.similarityThreshold,
topN: workspace?.topN,
filterIdentifiers: pinnedDocIdentifiers,
rerank: workspace?.vectorSearchMode === "rerank",
})
: {
contextTexts: [],
@@ -480,6 +481,7 @@ async function streamChat({
similarityThreshold: workspace?.similarityThreshold,
topN: workspace?.topN,
filterIdentifiers: pinnedDocIdentifiers,
rerank: workspace?.vectorSearchMode === "rerank",
})
: {
contextTexts: [],


@@ -93,6 +93,7 @@ async function streamChatWithForEmbed(
similarityThreshold: embed.workspace?.similarityThreshold,
topN: embed.workspace?.topN,
filterIdentifiers: pinnedDocIdentifiers,
rerank: embed.workspace?.vectorSearchMode === "rerank",
})
: {
contextTexts: [],


@@ -89,6 +89,7 @@ async function chatSync({
similarityThreshold: workspace?.similarityThreshold,
topN: workspace?.topN,
filterIdentifiers: pinnedDocIdentifiers,
rerank: workspace?.vectorSearchMode === "rerank",
})
: {
contextTexts: [],
@@ -304,6 +305,7 @@ async function streamChat({
similarityThreshold: workspace?.similarityThreshold,
topN: workspace?.topN,
filterIdentifiers: pinnedDocIdentifiers,
rerank: workspace?.vectorSearchMode === "rerank",
})
: {
contextTexts: [],


@@ -139,6 +139,7 @@ async function streamChatWithWorkspace(
similarityThreshold: workspace?.similarityThreshold,
topN: workspace?.topN,
filterIdentifiers: pinnedDocIdentifiers,
rerank: workspace?.vectorSearchMode === "rerank",
})
: {
contextTexts: [],


@@ -56,6 +56,7 @@
* @property {Function} totalVectors - Returns the total number of vectors in the database.
* @property {Function} namespaceCount - Returns the count of vectors in a given namespace.
* @property {Function} similarityResponse - Performs a similarity search on a given namespace.
* @property {Function} rerankedSimilarityResponse - Performs a similarity search on a given namespace with reranking (if supported by provider).
* @property {Function} namespace - Retrieves the specified namespace collection.
* @property {Function} hasNamespace - Checks if a namespace exists.
* @property {Function} namespaceExists - Verifies if a namespace exists in the client.


@@ -5,6 +5,7 @@ const { SystemSettings } = require("../../../models/systemSettings");
const { storeVectorResult, cachedVectorInformation } = require("../../files");
const { v4: uuidv4 } = require("uuid");
const { sourceIdentifier } = require("../../chats");
const { NativeEmbeddingReranker } = require("../../EmbeddingRerankers/native");
/**
* LanceDB Client connection object
@@ -57,6 +58,91 @@ const LanceDb = {
const table = await client.openTable(_namespace);
return (await table.countRows()) || 0;
},
/**
* Performs a SimilaritySearch + Reranking on a namespace.
* @param {Object} params - The parameters for the rerankedSimilarityResponse.
* @param {Object} params.client - The vectorDB client.
* @param {string} params.namespace - The namespace to search in.
* @param {string} params.query - The query to search for (plain text).
* @param {number[]} params.queryVector - The vector of the query.
* @param {number} params.similarityThreshold - The threshold for similarity.
* @param {number} params.topN - The number of results to return from this process.
* @param {string[]} params.filterIdentifiers - The identifiers of the documents to filter out.
* @returns {Promise<{contextTexts: string[], sourceDocuments: object[], scores: number[]}>} The reranked similarity search result.
*/
rerankedSimilarityResponse: async function ({
client,
namespace,
query,
queryVector,
topN = 4,
similarityThreshold = 0.25,
filterIdentifiers = [],
}) {
const reranker = new NativeEmbeddingReranker();
const collection = await client.openTable(namespace);
const totalEmbeddings = await this.namespaceCount(namespace);
const result = {
contextTexts: [],
sourceDocuments: [],
scores: [],
};
/**
* For reranking, we want to work with a larger pool of results than the topN.
* This is because the reranker can only rerank the results it is given and we don't
* auto-expand the result set, so we give it a larger pool of candidates to work with.
*
* However, we cannot make this boundless as reranking is expensive and time consuming,
* so we limit the pool to a maximum of 50 results and a minimum of 10.
* This is a good balance between the number of candidates and the cost of reranking,
* and ensures workspaces with 10K embeddings will still rerank within a reasonable
* timeframe on base-level hardware.
*
* Benchmarks:
* On Intel Mac: 2.6 GHz 6-Core Intel Core i7 - 20 docs reranked in ~5.2 sec
*/
const searchLimit = Math.max(
10,
Math.min(50, Math.ceil(totalEmbeddings * 0.1))
);
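// e.g., 10,000 embeddings -> ceil(1,000) capped to 50; 200 embeddings -> 20; 60 embeddings -> ceil(6) floored to 10.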
const vectorSearchResults = await collection
.vectorSearch(queryVector)
.distanceType("cosine")
.limit(searchLimit)
.toArray();
await reranker
.rerank(query, vectorSearchResults, { topK: topN })
.then((rerankResults) => {
rerankResults.forEach((item) => {
if (this.distanceToSimilarity(item._distance) < similarityThreshold)
return;
const { vector: _, ...rest } = item;
if (filterIdentifiers.includes(sourceIdentifier(rest))) {
console.log(
"LanceDB: A source was filtered from context as it's parent document is pinned."
);
return;
}
const score =
item?.rerank_score || this.distanceToSimilarity(item._distance);
result.contextTexts.push(rest.text);
result.sourceDocuments.push({
...rest,
score,
});
result.scores.push(score);
});
})
.catch((e) => {
console.error(e);
console.error("LanceDB::rerankedSimilarityResponse", e.message);
});
return result;
},
/**
* Performs a SimilaritySearch on a given LanceDB namespace.
* @param {Object} params
@@ -300,6 +386,7 @@ const LanceDb = {
similarityThreshold = 0.25,
topN = 4,
filterIdentifiers = [],
rerank = false,
}) {
if (!namespace || !input || !LLMConnector)
throw new Error("Invalid request to performSimilaritySearch.");
@@ -314,15 +401,26 @@ const LanceDb = {
}
const queryVector = await LLMConnector.embedTextInput(input);
const result = rerank
? await this.rerankedSimilarityResponse({
client,
namespace,
query: input,
queryVector,
similarityThreshold,
topN,
filterIdentifiers,
})
: await this.similarityResponse({
client,
namespace,
queryVector,
similarityThreshold,
topN,
filterIdentifiers,
});
const { contextTexts, sourceDocuments } = result;
const sources = sourceDocuments.map((metadata, i) => {
return { metadata: { ...metadata, text: contextTexts[i] } };
});
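Taken together, every chat handler in this change calls the provider with the same shape. A minimal sketch of that call (hedged: getVectorDbClass is assumed to be the codebase's helper for resolving the active vector DB provider and is not part of this diff; workspace, message, and pinnedDocIdentifiers are illustrative):

const { getVectorDbClass, getLLMProvider } = require("./utils/helpers");

async function searchWorkspace(workspace, message, pinnedDocIdentifiers = []) {
  const VectorDb = getVectorDbClass(); // assumed helper; resolves to LanceDb when LanceDB is active
  return await VectorDb.performSimilaritySearch({
    namespace: workspace.slug, // illustrative; the namespace is the workspace's vector collection
    input: message,
    LLMConnector: getLLMProvider(),
    similarityThreshold: workspace?.similarityThreshold,
    topN: workspace?.topN,
    filterIdentifiers: pinnedDocIdentifiers,
    rerank: workspace?.vectorSearchMode === "rerank", // true routes to rerankedSimilarityResponse above
  });
}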