2025-01-02 14:27:52 -08:00
|
|
|
const path = require("path");
|
|
|
|
const fs = require("fs");
|
|
|
|
|
|
|
|
/**
 * Cross-encoder reranker that runs locally through `@xenova/transformers`.
 *
 * The transformers module, the model and the tokenizer are kept in static
 * private fields, so they are loaded once per process and shared by every
 * instance of this class.
 */
class NativeEmbeddingReranker {
  static #model = null;
  static #tokenizer = null;
  static #transformers = null;
  // In-flight initialization shared by concurrent callers of initClient().
  // Reset to null once it settles so a failed load can be retried later.
  static #initPromise = null;

  // This is a folder that Mintplex Labs hosts for those who cannot reach the HF model download
  // endpoint for various reasons. This endpoint is not guaranteed to be active or maintained
  // and may go offline at any time at Mintplex Labs's discretion.
  #fallbackHost = "https://cdn.anythingllm.com/support/models/";

  /**
   * Resolves the model cache directory (creating it when missing) and records
   * whether the model folder is already present on disk.
   */
  constructor() {
    // An alternative model to the mixedbread-ai/mxbai-rerank-xsmall-v1 model (speed on CPU is much slower for this model @ 18docs = 6s)
    // Model Card: https://huggingface.co/Xenova/ms-marco-MiniLM-L-6-v2 (speed on CPU is much faster @ 18docs = 1.6s)
    this.model = "Xenova/ms-marco-MiniLM-L-6-v2";
    this.cacheDir = path.resolve(
      process.env.STORAGE_DIR
        ? path.resolve(process.env.STORAGE_DIR, `models`)
        : path.resolve(__dirname, `../../../storage/models`)
    );
    this.modelPath = path.resolve(this.cacheDir, ...this.model.split("/"));

    // Make directory when it does not exist in existing installations.
    // `recursive` also creates missing parents, so a fresh storage directory
    // does not make the constructor throw ENOENT.
    if (!fs.existsSync(this.cacheDir)) {
      fs.mkdirSync(this.cacheDir, { recursive: true });
    }

    this.modelDownloaded = fs.existsSync(this.modelPath);
    this.log("Initialized");
  }

  /**
   * Writes a message to stdout prefixed with a colored class tag.
   * @param {string} text - The message to print.
   * @param {...any} args - Extra values forwarded to `console.log`.
   */
  log(text, ...args) {
    console.log(`\x1b[36m[NativeEmbeddingReranker]\x1b[0m ${text}`, ...args);
  }

  /**
   * This function will return the host of the current reranker suite.
   * - Before the transformers module is loaded: the default HF URL.
   * - Afterwards: the host component of `env.remoteHost`.
   * - If `env.remoteHost` cannot be parsed as a URL: the fallback host URL.
   *
   * NOTE(review): the first and last cases return a full URL while the middle
   * case returns a bare host name. The value is only used in log messages
   * inside this class, so the shapes are left as they were.
   * @returns {string} The host of the current reranker suite.
   */
  get host() {
    if (!NativeEmbeddingReranker.#transformers) return "https://huggingface.co";
    try {
      return new URL(NativeEmbeddingReranker.#transformers.env.remoteHost).host;
    } catch (e) {
      return this.#fallbackHost;
    }
  }

  /**
   * This function will preload the reranker suite and tokenizer.
   * This is useful for reducing the latency of the first rerank call and pre-downloading the models and such
   * to avoid having to wait for the models to download on the first rerank call.
   * Failures are logged and swallowed on purpose - loading is retried on the next `initClient()` call.
   * @returns {Promise<void>}
   */
  async preload() {
    try {
      this.log(`Preloading reranker suite...`);
      await this.initClient();
      this.log(
        `Preloaded reranker suite. Reranking is available as a service now.`
      );
    } catch (e) {
      console.error(e);
      this.log(
        `Failed to preload reranker suite. Reranking will be available on the first rerank call.`
      );
    }
  }

  /**
   * Ensures the transformers module, the model and the tokenizer are loaded.
   *
   * The fast path is only taken when all three singletons exist. Checking the
   * module alone would make a previously failed (or still running) load look
   * like a finished one and leave `rerank` with a null model.
   * Concurrent callers share a single in-flight load.
   * @returns {Promise<void>}
   * @throws Rethrows the loader error when the model or tokenizer cannot be loaded from any host.
   */
  async initClient() {
    if (
      NativeEmbeddingReranker.#transformers &&
      NativeEmbeddingReranker.#model &&
      NativeEmbeddingReranker.#tokenizer
    ) {
      this.log(`Reranker suite already initialized - reusing.`);
      return;
    }

    if (!NativeEmbeddingReranker.#initPromise) {
      NativeEmbeddingReranker.#initPromise = this.#loadSuite().finally(() => {
        NativeEmbeddingReranker.#initPromise = null;
      });
    }
    await NativeEmbeddingReranker.#initPromise;
  }

  /**
   * Imports `@xenova/transformers` (once) and then loads whichever of the
   * model / tokenizer singletons is still missing.
   * @returns {Promise<void>}
   */
  async #loadSuite() {
    if (!NativeEmbeddingReranker.#transformers) {
      const { AutoModelForSequenceClassification, AutoTokenizer, env } =
        await import("@xenova/transformers");
      NativeEmbeddingReranker.#transformers = {
        AutoModelForSequenceClassification,
        AutoTokenizer,
        env,
      };
    }

    this.log(`Loading reranker suite...`);
    // Attempt to load the model and tokenizer in this order:
    // 1. From local file system cache
    // 2. Download and cache from remote host (hf.co)
    // 3. Download and cache from fallback host (cdn.anythingllm.com)
    await this.#getPreTrainedModel();
    await this.#getPreTrainedTokenizer();
  }

  /**
   * Shared loader for the model and the tokenizer.
   * Calls `fromPretrained` with the model id and the cache directory. When that
   * fails and the remote host is not yet the fallback host, it switches the
   * transformers `env` to the fallback host and tries exactly once more.
   * @param {string} label - "model" or "tokenizer"; only used in log messages.
   * @param {(name: string, options: object) => Promise<any>} fromPretrained - Bound `from_pretrained` call.
   * @returns {Promise<any>} The loaded artifact.
   * @throws Rethrows the loader error when the fallback host fails as well.
   */
  async #loadPretrained(label, fromPretrained) {
    const { env } = NativeEmbeddingReranker.#transformers;
    try {
      const loaded = await fromPretrained(this.model, {
        progress_callback: (p) => {
          // Only report download progress when the model was not on disk at construction time.
          if (!this.modelDownloaded && p?.status === "progress") {
            this.log(
              `[${this.host}] Loading ${label} ${this.model}... ${p?.progress}%`
            );
          }
        },
        cache_dir: this.cacheDir,
      });
      this.log(`Loaded ${label} ${this.model}`);
      return loaded;
    } catch (e) {
      this.log(
        `Failed to load ${label} ${this.model} from ${this.host}.`,
        e.message,
        e.stack
      );
      if (env.remoteHost === this.#fallbackHost) {
        this.log(`Failed to load ${label} ${this.model} from fallback host.`);
        throw e;
      }

      this.log(`Falling back to fallback host. ${this.#fallbackHost}`);
      env.remoteHost = this.#fallbackHost;
      env.remotePathTemplate = "{model}/";
      return await this.#loadPretrained(label, fromPretrained);
    }
  }

  /**
   * This function will load the model from the local file system cache, or download and cache it from the remote host.
   * If the model cannot be loaded from the remote host, the fallback host is tried.
   * The loaded model is stored as a process-wide singleton.
   * @returns {Promise<any>} The loaded model.
   */
  async #getPreTrainedModel() {
    if (NativeEmbeddingReranker.#model) {
      this.log(`Loading model from singleton...`);
      return NativeEmbeddingReranker.#model;
    }

    const { AutoModelForSequenceClassification } =
      NativeEmbeddingReranker.#transformers;
    // Wrapped in an arrow so `from_pretrained` keeps its class as `this`.
    NativeEmbeddingReranker.#model = await this.#loadPretrained(
      "model",
      (name, options) =>
        AutoModelForSequenceClassification.from_pretrained(name, options)
    );
    return NativeEmbeddingReranker.#model;
  }

  /**
   * This function will load the tokenizer from the local file system cache, or download and cache it from the remote host.
   * If the tokenizer cannot be loaded from the remote host, the fallback host is tried.
   * The loaded tokenizer is stored as a process-wide singleton.
   * @returns {Promise<any>} The loaded tokenizer.
   */
  async #getPreTrainedTokenizer() {
    if (NativeEmbeddingReranker.#tokenizer) {
      this.log(`Loading tokenizer from singleton...`);
      return NativeEmbeddingReranker.#tokenizer;
    }

    const { AutoTokenizer } = NativeEmbeddingReranker.#transformers;
    NativeEmbeddingReranker.#tokenizer = await this.#loadPretrained(
      "tokenizer",
      (name, options) => AutoTokenizer.from_pretrained(name, options)
    );
    return NativeEmbeddingReranker.#tokenizer;
  }

  /**
   * Reranks a list of documents based on the query.
   * @param {string} query - The query to rerank the documents against.
   * @param {{text: string}[]} documents - The list of document text snippets to rerank. Should be output from a vector search.
   * @param {Object} options - The options for the reranking.
   * @param {number} options.topK - The number of top documents to return. Defaults to 4 when omitted.
   * @returns {Promise<any[]>} - The reranked list of documents, highest `rerank_score` first.
   *   Each entry is the input document plus `rerank_corpus_id` (its index in `documents`) and `rerank_score`.
   */
  async rerank(query, documents, options = { topK: 4 }) {
    // Nothing to score - skip loading the model and tokenizing an empty batch.
    if (!documents?.length) return [];
    // Default per field so `{}` or `{ topK: undefined }` still limit the result.
    const topK = options?.topK ?? 4;

    await this.initClient();
    const model = NativeEmbeddingReranker.#model;
    const tokenizer = NativeEmbeddingReranker.#tokenizer;

    const start = Date.now();
    this.log(`Reranking ${documents.length} documents...`);
    // The query is repeated once per document and paired with that document's text.
    const inputs = tokenizer(new Array(documents.length).fill(query), {
      text_pair: documents.map((doc) => doc.text),
      padding: true,
      truncation: true,
    });
    const { logits } = await model(inputs);
    const reranked = logits
      .sigmoid()
      .tolist()
      .map(([score], i) => ({
        rerank_corpus_id: i,
        rerank_score: score,
        ...documents[i],
      }))
      .sort((a, b) => b.rerank_score - a.rerank_score)
      .slice(0, topK);

    this.log(
      `Reranking ${documents.length} documents to top ${topK} took ${Date.now() - start}ms`
    );
    return reranked;
  }
}
|
|
|
|
|
|
|
|
module.exports = {
|
|
|
|
NativeEmbeddingReranker,
|
|
|
|
};
|