mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 15:38:55 +01:00
Save, Load Embeddings to/from file to speed up script load time
This commit is contained in:
parent
0914f284bb
commit
eb03f57917
1 changed files with 8 additions and 2 deletions
|
@ -6,6 +6,7 @@ import time
|
|||
import gzip
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
|
||||
# We use the Bi-Encoder to encode all passages, so that we can use it with sematic search
|
||||
model_name = 'msmarco-MiniLM-L-6-v3'
|
||||
|
@ -27,8 +28,13 @@ with gzip.open(notes_filepath, 'rt', encoding='utf8') as fIn:
|
|||
|
||||
print(f"Passages: {len(passages)}")
|
||||
|
||||
# Here, we compute the corpus_embeddings from scratch (which can take a while depending on the GPU)
|
||||
corpus_embeddings = bi_encoder.encode(passages, convert_to_tensor=True, show_progress_bar=True)
|
||||
embeddings_filename = 'notes_embeddings.pt'
|
||||
# Load pre-computed embeddings from file if exists
|
||||
if os.path.exists(embeddings_filename):
|
||||
corpus_embeddings = torch.load(embeddings_filename)
|
||||
else: # Else compute the corpus_embeddings from scratch, which can take a while
|
||||
corpus_embeddings = bi_encoder.encode(passages, convert_to_tensor=True, show_progress_bar=True)
|
||||
torch.save(corpus_embeddings, 'notes_embeddings.pt')
|
||||
|
||||
# This function will search all notes for passages that answer the query
|
||||
def search(query):
|
||||
|
|
Loading…
Reference in a new issue