mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 15:38:55 +01:00
Save, Load Embeddings to/from file to speed up script load time
This commit is contained in:
parent
0914f284bb
commit
eb03f57917
1 changed files with 8 additions and 2 deletions
|
@ -6,6 +6,7 @@ import time
|
||||||
import gzip
|
import gzip
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import torch
|
||||||
|
|
||||||
# We use the Bi-Encoder to encode all passages, so that we can use it with sematic search
|
# We use the Bi-Encoder to encode all passages, so that we can use it with sematic search
|
||||||
model_name = 'msmarco-MiniLM-L-6-v3'
|
model_name = 'msmarco-MiniLM-L-6-v3'
|
||||||
|
@ -27,8 +28,13 @@ with gzip.open(notes_filepath, 'rt', encoding='utf8') as fIn:
|
||||||
|
|
||||||
print(f"Passages: {len(passages)}")
|
print(f"Passages: {len(passages)}")
|
||||||
|
|
||||||
# Here, we compute the corpus_embeddings from scratch (which can take a while depending on the GPU)
|
embeddings_filename = 'notes_embeddings.pt'
|
||||||
corpus_embeddings = bi_encoder.encode(passages, convert_to_tensor=True, show_progress_bar=True)
|
# Load pre-computed embeddings from file if exists
|
||||||
|
if os.path.exists(embeddings_filename):
|
||||||
|
corpus_embeddings = torch.load(embeddings_filename)
|
||||||
|
else: # Else compute the corpus_embeddings from scratch, which can take a while
|
||||||
|
corpus_embeddings = bi_encoder.encode(passages, convert_to_tensor=True, show_progress_bar=True)
|
||||||
|
torch.save(corpus_embeddings, 'notes_embeddings.pt')
|
||||||
|
|
||||||
# This function will search all notes for passages that answer the query
|
# This function will search all notes for passages that answer the query
|
||||||
def search(query):
|
def search(query):
|
||||||
|
|
Loading…
Reference in a new issue