const path = require("path"); const fs = require("fs"); class NativeEmbeddingReranker { static #model = null; static #tokenizer = null; static #transformers = null; // This is a folder that Mintplex Labs hosts for those who cannot capture the HF model download // endpoint for various reasons. This endpoint is not guaranteed to be active or maintained // and may go offline at any time at Mintplex Labs's discretion. #fallbackHost = "https://cdn.anythingllm.com/support/models/"; constructor() { // An alternative model to the mixedbread-ai/mxbai-rerank-xsmall-v1 model (speed on CPU is much slower for this model @ 18docs = 6s) // Model Card: https://huggingface.co/Xenova/ms-marco-MiniLM-L-6-v2 (speed on CPU is much faster @ 18docs = 1.6s) this.model = "Xenova/ms-marco-MiniLM-L-6-v2"; this.cacheDir = path.resolve( process.env.STORAGE_DIR ? path.resolve(process.env.STORAGE_DIR, `models`) : path.resolve(__dirname, `../../../storage/models`) ); this.modelPath = path.resolve(this.cacheDir, ...this.model.split("/")); // Make directory when it does not exist in existing installations if (!fs.existsSync(this.cacheDir)) fs.mkdirSync(this.cacheDir); this.modelDownloaded = fs.existsSync( path.resolve(this.cacheDir, this.model) ); this.log("Initialized"); } log(text, ...args) { console.log(`\x1b[36m[NativeEmbeddingReranker]\x1b[0m ${text}`, ...args); } /** * This function will return the host of the current reranker suite. * If the reranker suite is not initialized, it will return the default HF host. * @returns {string} The host of the current reranker suite. */ get host() { if (!NativeEmbeddingReranker.#transformers) return "https://huggingface.co"; try { return new URL(NativeEmbeddingReranker.#transformers.env.remoteHost).host; } catch (e) { return this.#fallbackHost; } } /** * This function will preload the reranker suite and tokenizer. * This is useful for reducing the latency of the first rerank call and pre-downloading the models and such * to avoid having to wait for the models to download on the first rerank call. */ async preload() { try { this.log(`Preloading reranker suite...`); await this.initClient(); this.log( `Preloaded reranker suite. Reranking is available as a service now.` ); return; } catch (e) { console.error(e); this.log( `Failed to preload reranker suite. Reranking will be available on the first rerank call.` ); return; } } async initClient() { if (NativeEmbeddingReranker.#transformers) { this.log(`Reranker suite already initialized - reusing.`); return; } await import("@xenova/transformers").then( async ({ AutoModelForSequenceClassification, AutoTokenizer, env }) => { this.log(`Loading reranker suite...`); NativeEmbeddingReranker.#transformers = { AutoModelForSequenceClassification, AutoTokenizer, env, }; // Attempt to load the model and tokenizer in this order: // 1. From local file system cache // 2. Download and cache from remote host (hf.co) // 3. Download and cache from fallback host (cdn.anythingllm.com) await this.#getPreTrainedModel(); await this.#getPreTrainedTokenizer(); } ); return; } /** * This function will load the model from the local file system cache, or download and cache it from the remote host. * If the model is not found in the local file system cache, it will download and cache it from the remote host. * If the model is not found in the remote host, it will download and cache it from the fallback host. * @returns {Promise<any>} The loaded model. */ async #getPreTrainedModel() { if (NativeEmbeddingReranker.#model) { this.log(`Loading model from singleton...`); return NativeEmbeddingReranker.#model; } try { const model = await NativeEmbeddingReranker.#transformers.AutoModelForSequenceClassification.from_pretrained( this.model, { progress_callback: (p) => { if (!this.modelDownloaded && p.status === "progress") { this.log( `[${this.host}] Loading model ${this.model}... ${p?.progress}%` ); } }, cache_dir: this.cacheDir, } ); this.log(`Loaded model ${this.model}`); NativeEmbeddingReranker.#model = model; return model; } catch (e) { this.log( `Failed to load model ${this.model} from ${this.host}.`, e.message, e.stack ); if ( NativeEmbeddingReranker.#transformers.env.remoteHost === this.#fallbackHost ) { this.log(`Failed to load model ${this.model} from fallback host.`); throw e; } this.log(`Falling back to fallback host. ${this.#fallbackHost}`); NativeEmbeddingReranker.#transformers.env.remoteHost = this.#fallbackHost; NativeEmbeddingReranker.#transformers.env.remotePathTemplate = "{model}/"; return await this.#getPreTrainedModel(); } } /** * This function will load the tokenizer from the local file system cache, or download and cache it from the remote host. * If the tokenizer is not found in the local file system cache, it will download and cache it from the remote host. * If the tokenizer is not found in the remote host, it will download and cache it from the fallback host. * @returns {Promise<any>} The loaded tokenizer. */ async #getPreTrainedTokenizer() { if (NativeEmbeddingReranker.#tokenizer) { this.log(`Loading tokenizer from singleton...`); return NativeEmbeddingReranker.#tokenizer; } try { const tokenizer = await NativeEmbeddingReranker.#transformers.AutoTokenizer.from_pretrained( this.model, { progress_callback: (p) => { if (!this.modelDownloaded && p.status === "progress") { this.log( `[${this.host}] Loading tokenizer ${this.model}... ${p?.progress}%` ); } }, cache_dir: this.cacheDir, } ); this.log(`Loaded tokenizer ${this.model}`); NativeEmbeddingReranker.#tokenizer = tokenizer; return tokenizer; } catch (e) { this.log( `Failed to load tokenizer ${this.model} from ${this.host}.`, e.message, e.stack ); if ( NativeEmbeddingReranker.#transformers.env.remoteHost === this.#fallbackHost ) { this.log(`Failed to load tokenizer ${this.model} from fallback host.`); throw e; } this.log(`Falling back to fallback host. ${this.#fallbackHost}`); NativeEmbeddingReranker.#transformers.env.remoteHost = this.#fallbackHost; NativeEmbeddingReranker.#transformers.env.remotePathTemplate = "{model}/"; return await this.#getPreTrainedTokenizer(); } } /** * Reranks a list of documents based on the query. * @param {string} query - The query to rerank the documents against. * @param {{text: string}[]} documents - The list of document text snippets to rerank. Should be output from a vector search. * @param {Object} options - The options for the reranking. * @param {number} options.topK - The number of top documents to return. * @returns {Promise<any[]>} - The reranked list of documents. */ async rerank(query, documents, options = { topK: 4 }) { await this.initClient(); const model = NativeEmbeddingReranker.#model; const tokenizer = NativeEmbeddingReranker.#tokenizer; const start = Date.now(); this.log(`Reranking ${documents.length} documents...`); const inputs = tokenizer(new Array(documents.length).fill(query), { text_pair: documents.map((doc) => doc.text), padding: true, truncation: true, }); const { logits } = await model(inputs); const reranked = logits .sigmoid() .tolist() .map(([score], i) => ({ rerank_corpus_id: i, rerank_score: score, ...documents[i], })) .sort((a, b) => b.rerank_score - a.rerank_score) .slice(0, options.topK); this.log( `Reranking ${documents.length} documents to top ${options.topK} took ${Date.now() - start}ms` ); return reranked; } } module.exports = { NativeEmbeddingReranker, };