Save, Load Embeddings to/from file to speed up script load time

This commit is contained in:
Debanjum Singh Solanky 2021-07-31 02:58:34 -07:00
parent 0914f284bb
commit eb03f57917

View file

@ -6,6 +6,7 @@ import time
import gzip
import os
import sys
import torch
# We use the Bi-Encoder to encode all passages, so that we can use it with sematic search
model_name = 'msmarco-MiniLM-L-6-v3'
@ -27,8 +28,13 @@ with gzip.open(notes_filepath, 'rt', encoding='utf8') as fIn:
print(f"Passages: {len(passages)}")
# Here, we compute the corpus_embeddings from scratch (which can take a while depending on the GPU)
corpus_embeddings = bi_encoder.encode(passages, convert_to_tensor=True, show_progress_bar=True)
embeddings_filename = 'notes_embeddings.pt'
# Load pre-computed embeddings from file if exists
if os.path.exists(embeddings_filename):
corpus_embeddings = torch.load(embeddings_filename)
else: # Else compute the corpus_embeddings from scratch, which can take a while
corpus_embeddings = bi_encoder.encode(passages, convert_to_tensor=True, show_progress_bar=True)
torch.save(corpus_embeddings, 'notes_embeddings.pt')
# This function will search all notes for passages that answer the query
def search(query):