From eb03f57917ab9124f73ba4777b107153f8c05d14 Mon Sep 17 00:00:00 2001
From: Debanjum Singh Solanky
Date: Sat, 31 Jul 2021 02:58:34 -0700
Subject: [PATCH] Save, Load Embeddings to/from file to speed up script load time

---
 asymmetric.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/asymmetric.py b/asymmetric.py
index fd4574cd..b69ea190 100644
--- a/asymmetric.py
+++ b/asymmetric.py
@@ -6,6 +6,7 @@ import time
 import gzip
 import os
 import sys
+import torch
 
 # We use the Bi-Encoder to encode all passages, so that we can use it with sematic search
 model_name = 'msmarco-MiniLM-L-6-v3'
@@ -27,8 +28,13 @@ with gzip.open(notes_filepath, 'rt', encoding='utf8') as fIn:
 
 print(f"Passages: {len(passages)}")
 
-# Here, we compute the corpus_embeddings from scratch (which can take a while depending on the GPU)
-corpus_embeddings = bi_encoder.encode(passages, convert_to_tensor=True, show_progress_bar=True)
+embeddings_filename = 'notes_embeddings.pt'
+# Load pre-computed embeddings from file if exists
+if os.path.exists(embeddings_filename):
+    corpus_embeddings = torch.load(embeddings_filename)
+else: # Else compute the corpus_embeddings from scratch, which can take a while
+    corpus_embeddings = bi_encoder.encode(passages, convert_to_tensor=True, show_progress_bar=True)
+    torch.save(corpus_embeddings, embeddings_filename)
 
 # This function will search all notes for passages that answer the query
 def search(query):