Mirror of https://github.com/Mintplex-Labs/anything-llm.git
Synced 2025-04-17 18:18:11 +00:00
Reranker option for RAG (#2929)
* Reranker WIP
* Add caching and singleton loading
* Add field to workspaces for vectorSearchMode
  Add UI for LanceDB to change mode
  Update all search endpoints to pass in the reranker prop if the provider can use it
* Update hint text
* When reranking, swap score to rerank score
* Update optchain
Parent: bb5c3b7e0d
Commit: ad01df8790

16 changed files with 339 additions and 9 deletions
frontend/src/pages/WorkspaceSettings
server/endpoints/api/workspace
server/models
server/prisma
server/storage/models
server/utils/EmbeddingRerankers/native
server/utils/agents/aibitat/plugins
server/utils/chats
server/utils/helpers
server/utils/vectorDbProviders/lance
@@ -0,0 +1,51 @@
+import { useState } from "react";
+
+// We don't support all vectorDBs yet for reranking due to complexities of how each provider
+// returns information. We need to normalize the response data so the reranker can be used for each provider.
+const supportedVectorDBs = ["lancedb"];
+const hint = {
+  default: {
+    title: "Default",
+    description:
+      "This is the fastest performance, but may not return the most relevant results leading to model hallucinations.",
+  },
+  rerank: {
+    title: "Accuracy Optimized",
+    description:
+      "LLM responses may take longer to generate, but your responses will be more accurate and relevant.",
+  },
+};
+
+export default function VectorSearchMode({ workspace, setHasChanges }) {
+  const [selection, setSelection] = useState(
+    workspace?.vectorSearchMode ?? "default"
+  );
+  if (!workspace?.vectorDB || !supportedVectorDBs.includes(workspace?.vectorDB))
+    return null;
+
+  return (
+    <div>
+      <div className="flex flex-col">
+        <label htmlFor="name" className="block input-label">
+          Search Preference
+        </label>
+      </div>
+      <select
+        name="vectorSearchMode"
+        value={selection}
+        className="border-none bg-theme-settings-input-bg text-white text-sm mt-2 rounded-lg focus:outline-primary-button active:outline-primary-button outline-none block w-full p-2.5"
+        onChange={(e) => {
+          setSelection(e.target.value);
+          setHasChanges(true);
+        }}
+        required={true}
+      >
+        <option value="default">Default</option>
+        <option value="rerank">Accuracy Optimized</option>
+      </select>
+      <p className="text-white text-opacity-60 text-xs font-medium py-1.5">
+        {hint[selection]?.description}
+      </p>
+    </div>
+  );
+}
@@ -7,6 +7,7 @@ import MaxContextSnippets from "./MaxContextSnippets";
 import DocumentSimilarityThreshold from "./DocumentSimilarityThreshold";
 import ResetDatabase from "./ResetDatabase";
 import VectorCount from "./VectorCount";
+import VectorSearchMode from "./VectorSearchMode";
 
 export default function VectorDatabase({ workspace }) {
   const [hasChanges, setHasChanges] = useState(false);
@@ -43,6 +44,7 @@ export default function VectorDatabase({ workspace }) {
         <VectorDBIdentifier workspace={workspace} />
         <VectorCount reload={true} workspace={workspace} />
       </div>
+      <VectorSearchMode workspace={workspace} setHasChanges={setHasChanges} />
       <MaxContextSnippets workspace={workspace} setHasChanges={setHasChanges} />
       <DocumentSimilarityThreshold
         workspace={workspace}
@@ -23,6 +23,7 @@ import Members from "./Members";
 import WorkspaceAgentConfiguration from "./AgentConfig";
 import useUser from "@/hooks/useUser";
 import { useTranslation } from "react-i18next";
+import System from "@/models/system";
 
 const TABS = {
   "general-appearance": GeneralAppearance,
@@ -59,9 +60,11 @@ function ShowWorkspaceChat() {
       return;
     }
 
+    const _settings = await System.keys();
     const suggestedMessages = await Workspace.getSuggestedMessages(slug);
     setWorkspace({
       ..._workspace,
+      vectorDB: _settings?.VectorDB,
       suggestedMessages,
     });
     setLoading(false);
@@ -961,6 +961,7 @@ function apiWorkspaceEndpoints(app) {
         LLMConnector: getLLMProvider(),
         similarityThreshold: parseSimilarityThreshold(),
         topN: parseTopN(),
+        rerank: workspace?.vectorSearchMode === "rerank",
       });
 
       response.status(200).json({
@@ -34,6 +34,7 @@ const Workspace = {
     "agentProvider",
     "agentModel",
     "queryRefusalResponse",
+    "vectorSearchMode",
   ],
 
   validations: {
@@ -99,6 +100,15 @@ const Workspace = {
       if (!value || typeof value !== "string") return null;
       return String(value);
     },
+    vectorSearchMode: (value) => {
+      if (
+        !value ||
+        typeof value !== "string" ||
+        !["default", "rerank"].includes(value)
+      )
+        return "default";
+      return value;
+    },
   },
 
   /**
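The validator is defensive: anything other than "default" or "rerank" is coerced back to "default", so bad client input can never persist an unknown mode. A minimal sketch of the expected behavior (direct calls shown for illustration only, not from the diff):

    // Illustrative only - exercising the validator from the hunk above
    Workspace.validations.vectorSearchMode("rerank");  // => "rerank"
    Workspace.validations.vectorSearchMode("fastest"); // => "default" (unknown value)
    Workspace.validations.vectorSearchMode(null);      // => "default" (missing value)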
@@ -0,0 +1,2 @@
+-- AlterTable
+ALTER TABLE "workspaces" ADD COLUMN "vectorSearchMode" TEXT DEFAULT 'default';
@@ -137,6 +137,7 @@ model workspaces {
   agentProvider                String?
   agentModel                   String?
   queryRefusalResponse         String?
+  vectorSearchMode             String? @default("default")
   workspace_users              workspace_users[]
   documents                    workspace_documents[]
   workspace_suggested_messages workspace_suggested_messages[]
3 server/storage/models/.gitignore (vendored)

@@ -3,4 +3,5 @@ downloaded/*
 !downloaded/.placeholder
 openrouter
 apipie
-novita
+novita
+mixedbread-ai*
153 server/utils/EmbeddingRerankers/native/index.js (Normal file)

@@ -0,0 +1,153 @@
+const path = require("path");
+const fs = require("fs");
+
+class NativeEmbeddingReranker {
+  static #model = null;
+  static #tokenizer = null;
+  static #transformers = null;
+
+  constructor() {
+    // An alternative model to the mixedbread-ai/mxbai-rerank-xsmall-v1 model (speed on CPU is much slower for this model @ 18docs = 6s)
+    // Model Card: https://huggingface.co/Xenova/ms-marco-MiniLM-L-6-v2 (speed on CPU is much faster @ 18docs = 1.6s)
+    this.model = "Xenova/ms-marco-MiniLM-L-6-v2";
+    this.cacheDir = path.resolve(
+      process.env.STORAGE_DIR
+        ? path.resolve(process.env.STORAGE_DIR, `models`)
+        : path.resolve(__dirname, `../../../storage/models`)
+    );
+    this.modelPath = path.resolve(this.cacheDir, ...this.model.split("/"));
+    // Make directory when it does not exist in existing installations
+    if (!fs.existsSync(this.cacheDir)) fs.mkdirSync(this.cacheDir);
+    this.log("Initialized");
+  }
+
+  log(text, ...args) {
+    console.log(`\x1b[36m[NativeEmbeddingReranker]\x1b[0m ${text}`, ...args);
+  }
+
+  /**
+   * This function will preload the reranker suite and tokenizer.
+   * This is useful for reducing the latency of the first rerank call and pre-downloading the models and such
+   * to avoid having to wait for the models to download on the first rerank call.
+   */
+  async preload() {
+    try {
+      this.log(`Preloading reranker suite...`);
+      await this.initClient();
+      this.log(
+        `Preloaded reranker suite. Reranking is available as a service now.`
+      );
+      return;
+    } catch (e) {
+      console.error(e);
+      this.log(
+        `Failed to preload reranker suite. Reranking will be available on the first rerank call.`
+      );
+      return;
+    }
+  }
+
+  async initClient() {
+    if (NativeEmbeddingReranker.#transformers) {
+      this.log(`Reranker suite already initialized - reusing.`);
+      return;
+    }
+
+    await import("@xenova/transformers").then(
+      async ({ AutoModelForSequenceClassification, AutoTokenizer }) => {
+        this.log(`Loading reranker suite...`);
+        NativeEmbeddingReranker.#transformers = {
+          AutoModelForSequenceClassification,
+          AutoTokenizer,
+        };
+        await this.#getPreTrainedModel();
+        await this.#getPreTrainedTokenizer();
+      }
+    );
+    return;
+  }
+
+  async #getPreTrainedModel() {
+    if (NativeEmbeddingReranker.#model) {
+      this.log(`Loading model from singleton...`);
+      return NativeEmbeddingReranker.#model;
+    }
+
+    const model =
+      await NativeEmbeddingReranker.#transformers.AutoModelForSequenceClassification.from_pretrained(
+        this.model,
+        {
+          progress_callback: (p) =>
+            p.status === "progress" &&
+            this.log(`Loading model ${this.model}... ${p?.progress}%`),
+          cache_dir: this.cacheDir,
+        }
+      );
+    this.log(`Loaded model ${this.model}`);
+    NativeEmbeddingReranker.#model = model;
+    return model;
+  }
+
+  async #getPreTrainedTokenizer() {
+    if (NativeEmbeddingReranker.#tokenizer) {
+      this.log(`Loading tokenizer from singleton...`);
+      return NativeEmbeddingReranker.#tokenizer;
+    }
+
+    const tokenizer =
+      await NativeEmbeddingReranker.#transformers.AutoTokenizer.from_pretrained(
+        this.model,
+        {
+          progress_callback: (p) =>
+            p.status === "progress" &&
+            this.log(`Loading tokenizer ${this.model}... ${p?.progress}%`),
+          cache_dir: this.cacheDir,
+        }
+      );
+    this.log(`Loaded tokenizer ${this.model}`);
+    NativeEmbeddingReranker.#tokenizer = tokenizer;
+    return tokenizer;
+  }
+
+  /**
+   * Reranks a list of documents based on the query.
+   * @param {string} query - The query to rerank the documents against.
+   * @param {{text: string}[]} documents - The list of document text snippets to rerank. Should be output from a vector search.
+   * @param {Object} options - The options for the reranking.
+   * @param {number} options.topK - The number of top documents to return.
+   * @returns {Promise<any[]>} - The reranked list of documents.
+   */
+  async rerank(query, documents, options = { topK: 4 }) {
+    await this.initClient();
+    const model = NativeEmbeddingReranker.#model;
+    const tokenizer = NativeEmbeddingReranker.#tokenizer;
+
+    const start = Date.now();
+    this.log(`Reranking ${documents.length} documents...`);
+    const inputs = tokenizer(new Array(documents.length).fill(query), {
+      text_pair: documents.map((doc) => doc.text),
+      padding: true,
+      truncation: true,
+    });
+    const { logits } = await model(inputs);
+    const reranked = logits
+      .sigmoid()
+      .tolist()
+      .map(([score], i) => ({
+        rerank_corpus_id: i,
+        rerank_score: score,
+        ...documents[i],
+      }))
+      .sort((a, b) => b.rerank_score - a.rerank_score)
+      .slice(0, options.topK);
+
+    this.log(
+      `Reranking ${documents.length} documents to top ${options.topK} took ${Date.now() - start}ms`
+    );
+    return reranked;
+  }
+}
+
+module.exports = {
+  NativeEmbeddingReranker,
+};
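Because the model, tokenizer, and transformers import live in static fields, every new NativeEmbeddingReranker() shares one loaded copy. A minimal usage sketch, assuming the require path resolves to the file above (the bootstrap around it is illustrative, not part of this change):

    // Illustrative consumer of the class defined above
    const {
      NativeEmbeddingReranker,
    } = require("./utils/EmbeddingRerankers/native");

    (async () => {
      const reranker = new NativeEmbeddingReranker();
      await reranker.preload(); // optional warm-up: downloads/loads model + tokenizer
      const docs = [
        { text: "LanceDB is an embedded vector database." },
        { text: "Bananas are rich in potassium." },
      ];
      // Returns up to topK docs, each with rerank_corpus_id and rerank_score added
      const top = await reranker.rerank("what is lancedb?", docs, { topK: 1 });
      console.log(top[0].rerank_score, top[0].text);
    })();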
@@ -95,6 +95,7 @@ const memory = {
         input: query,
         LLMConnector,
         topN: workspace?.topN ?? 4,
+        rerank: workspace?.vectorSearchMode === "rerank",
       });
 
       if (contextTexts.length === 0) {
@@ -180,6 +180,7 @@ async function chatSync({
        similarityThreshold: workspace?.similarityThreshold,
        topN: workspace?.topN,
        filterIdentifiers: pinnedDocIdentifiers,
+       rerank: workspace?.vectorSearchMode === "rerank",
      })
    : {
        contextTexts: [],
@@ -480,6 +481,7 @@ async function streamChat({
        similarityThreshold: workspace?.similarityThreshold,
        topN: workspace?.topN,
        filterIdentifiers: pinnedDocIdentifiers,
+       rerank: workspace?.vectorSearchMode === "rerank",
      })
    : {
        contextTexts: [],
@@ -93,6 +93,7 @@ async function streamChatWithForEmbed(
        similarityThreshold: embed.workspace?.similarityThreshold,
        topN: embed.workspace?.topN,
        filterIdentifiers: pinnedDocIdentifiers,
+       rerank: embed.workspace?.vectorSearchMode === "rerank",
      })
    : {
        contextTexts: [],
@@ -89,6 +89,7 @@ async function chatSync({
        similarityThreshold: workspace?.similarityThreshold,
        topN: workspace?.topN,
        filterIdentifiers: pinnedDocIdentifiers,
+       rerank: workspace?.vectorSearchMode === "rerank",
      })
    : {
        contextTexts: [],
@@ -304,6 +305,7 @@ async function streamChat({
        similarityThreshold: workspace?.similarityThreshold,
        topN: workspace?.topN,
        filterIdentifiers: pinnedDocIdentifiers,
+       rerank: workspace?.vectorSearchMode === "rerank",
      })
    : {
        contextTexts: [],
@@ -139,6 +139,7 @@ async function streamChatWithWorkspace(
        similarityThreshold: workspace?.similarityThreshold,
        topN: workspace?.topN,
        filterIdentifiers: pinnedDocIdentifiers,
+       rerank: workspace?.vectorSearchMode === "rerank",
      })
    : {
        contextTexts: [],
@@ -56,6 +56,7 @@
  * @property {Function} totalVectors - Returns the total number of vectors in the database.
  * @property {Function} namespaceCount - Returns the count of vectors in a given namespace.
  * @property {Function} similarityResponse - Performs a similarity search on a given namespace.
+ * @property {Function} rerankedSimilarityResponse - Performs a similarity search on a given namespace with reranking (if supported by provider).
  * @property {Function} namespace - Retrieves the specified namespace collection.
  * @property {Function} hasNamespace - Checks if a namespace exists.
  * @property {Function} namespaceExists - Verifies if a namespace exists in the client.
@@ -5,6 +5,7 @@ const { SystemSettings } = require("../../../models/systemSettings");
 const { storeVectorResult, cachedVectorInformation } = require("../../files");
 const { v4: uuidv4 } = require("uuid");
 const { sourceIdentifier } = require("../../chats");
+const { NativeEmbeddingReranker } = require("../../EmbeddingRerankers/native");
 
 /**
  * LancedDB Client connection object
@@ -57,6 +58,91 @@ const LanceDb = {
     const table = await client.openTable(_namespace);
     return (await table.countRows()) || 0;
   },
+  /**
+   * Performs a SimilaritySearch + Reranking on a namespace.
+   * @param {Object} params - The parameters for the rerankedSimilarityResponse.
+   * @param {Object} params.client - The vectorDB client.
+   * @param {string} params.namespace - The namespace to search in.
+   * @param {string} params.query - The query to search for (plain text).
+   * @param {number[]} params.queryVector - The vector of the query.
+   * @param {number} params.similarityThreshold - The threshold for similarity.
+   * @param {number} params.topN - The number of results to return from this process.
+   * @param {string[]} params.filterIdentifiers - The identifiers of the documents to filter out.
+   * @returns
+   */
+  rerankedSimilarityResponse: async function ({
+    client,
+    namespace,
+    query,
+    queryVector,
+    topN = 4,
+    similarityThreshold = 0.25,
+    filterIdentifiers = [],
+  }) {
+    const reranker = new NativeEmbeddingReranker();
+    const collection = await client.openTable(namespace);
+    const totalEmbeddings = await this.namespaceCount(namespace);
+    const result = {
+      contextTexts: [],
+      sourceDocuments: [],
+      scores: [],
+    };
+
+    /**
+     * For reranking, we want to work with a larger number of results than the topN.
+     * This is because the reranker can only rerank the results it is given and we don't auto-expand the results.
+     * We want to give the reranker a larger number of results to work with.
+     *
+     * However, we cannot make this boundless as reranking is expensive and time consuming.
+     * So we limit the number of results to a maximum of 50 and a minimum of 10.
+     * This is a good balance between the number of results to rerank and the cost of reranking,
+     * and ensures workspaces with 10K embeddings will still rerank within a reasonable timeframe on base-level hardware.
+     *
+     * Benchmarks:
+     * On Intel Mac: 2.6 GHz 6-Core Intel Core i7 - 20 docs reranked in ~5.2 sec
+     */
+    const searchLimit = Math.max(
+      10,
+      Math.min(50, Math.ceil(totalEmbeddings * 0.1))
+    );
+    const vectorSearchResults = await collection
+      .vectorSearch(queryVector)
+      .distanceType("cosine")
+      .limit(searchLimit)
+      .toArray();
+
+    await reranker
+      .rerank(query, vectorSearchResults, { topK: topN })
+      .then((rerankResults) => {
+        rerankResults.forEach((item) => {
+          if (this.distanceToSimilarity(item._distance) < similarityThreshold)
+            return;
+          const { vector: _, ...rest } = item;
+          if (filterIdentifiers.includes(sourceIdentifier(rest))) {
+            console.log(
+              "LanceDB: A source was filtered from context as it's parent document is pinned."
+            );
+            return;
+          }
+          const score =
+            item?.rerank_score || this.distanceToSimilarity(item._distance);
+
+          result.contextTexts.push(rest.text);
+          result.sourceDocuments.push({
+            ...rest,
+            score,
+          });
+          result.scores.push(score);
+        });
+      })
+      .catch((e) => {
+        console.error(e);
+        console.error("LanceDB::rerankedSimilarityResponse", e.message);
+      });
+
+    return result;
+  },
+
   /**
    * Performs a SimilaritySearch on a given LanceDB namespace.
    * @param {Object} params
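The searchLimit clamp above is easy to check by hand (arithmetic only; the embedding counts are examples, not from the diff):

    // searchLimit = Math.max(10, Math.min(50, Math.ceil(totalEmbeddings * 0.1)))
    //   totalEmbeddings =     40 -> ceil(4)    -> clamped up to 10
    //   totalEmbeddings =    200 -> ceil(20)   -> 20
    //   totalEmbeddings = 10,000 -> ceil(1000) -> clamped down to 50

so the reranker always scores between 10 and 50 candidates, no matter how large the workspace grows.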
@@ -300,6 +386,7 @@ const LanceDb = {
     similarityThreshold = 0.25,
     topN = 4,
     filterIdentifiers = [],
+    rerank = false,
   }) {
     if (!namespace || !input || !LLMConnector)
       throw new Error("Invalid request to performSimilaritySearch.");
@@ -314,15 +401,26 @@
     }
 
     const queryVector = await LLMConnector.embedTextInput(input);
-    const { contextTexts, sourceDocuments } = await this.similarityResponse({
-      client,
-      namespace,
-      queryVector,
-      similarityThreshold,
-      topN,
-      filterIdentifiers,
-    });
+    const result = rerank
+      ? await this.rerankedSimilarityResponse({
+          client,
+          namespace,
+          query: input,
+          queryVector,
+          similarityThreshold,
+          topN,
+          filterIdentifiers,
+        })
+      : await this.similarityResponse({
+          client,
+          namespace,
+          queryVector,
+          similarityThreshold,
+          topN,
+          filterIdentifiers,
+        });
 
+    const { contextTexts, sourceDocuments } = result;
     const sources = sourceDocuments.map((metadata, i) => {
       return { metadata: { ...metadata, text: contextTexts[i] } };
     });
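With the hunk above, callers only toggle the rerank flag and the provider picks the search path internally. A sketch of a call site mirroring the endpoint changes earlier in this diff (workspace, LLMConnector, and the surrounding wiring are illustrative context, not part of this change):

    // Illustrative call site; variable wiring is assumed, not from the diff
    const searchResults = await LanceDb.performSimilaritySearch({
      namespace: workspace.slug,
      input: "What did the Q3 report say about revenue?",
      LLMConnector,
      similarityThreshold: workspace?.similarityThreshold,
      topN: workspace?.topN,
      filterIdentifiers: [],
      rerank: workspace?.vectorSearchMode === "rerank",
    });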