Unify logic to generate embeddings from scratch and incrementally

This simplifies the `compute_embeddings` method and avoids potential
later divergence in handling the index regenerate vs. update scenarios.
Debanjum Singh Solanky 2023-07-16 00:22:14 -07:00
parent 6a0297cc86
commit 89c7819cb7
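
The unified approach boils down to: when there is no existing index to load, start from an empty tensor, so the same encode-and-concatenate path serves both the regenerate and the incremental-update scenarios. A minimal, self-contained sketch of that pattern (illustrative names only, not the actual Khoj helpers):

    import torch

    def encode(texts):
        # Stand-in for the bi-encoder; returns one 4-dim vector per text.
        return torch.randn(len(texts), 4)

    def build_index(new_texts, existing=None):
        # Start from an empty tensor when there is nothing to load, so the same
        # encode-and-concatenate path covers both "create from scratch" and "update".
        corpus = existing if existing is not None else torch.empty(0, 4)
        if new_texts:
            corpus = torch.cat([corpus, encode(new_texts)], dim=0)
        return corpus

    # Regenerating and updating now share one code path:
    fresh_index = build_index(["a note", "another note"])                # from scratch
    updated_index = build_index(["a newer note"], existing=fresh_index)  # incremental

The diff below applies the same idea inside compute_embeddings.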


@@ -62,15 +62,19 @@ def compute_embeddings(
 ):
     "Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
     new_entries = []
+    create_index_msg = ""
     # Load pre-computed embeddings from file if exists and update them if required
     if embeddings_file.exists() and not regenerate:
         corpus_embeddings: torch.Tensor = torch.load(get_absolute_path(embeddings_file), map_location=state.device)
         logger.debug(f"Loaded {len(corpus_embeddings)} text embeddings from {embeddings_file}")
+    else:
+        corpus_embeddings = torch.tensor([], device=state.device)
+        create_index_msg = " Creating index from scratch."
 
     # Encode any new entries in the corpus and update corpus embeddings
     new_entries = [entry.compiled for id, entry in entries_with_ids if id == -1]
     if new_entries:
-        logger.info(f"📩 Indexing {len(new_entries)} text entries.")
+        logger.info(f"📩 Indexing {len(new_entries)} text entries.{create_index_msg}")
         new_embeddings = bi_encoder.encode(
             new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True
         )
@@ -82,16 +86,8 @@ def compute_embeddings(
         else:
             existing_embeddings = torch.tensor([], device=state.device)
         corpus_embeddings = torch.cat([existing_embeddings, new_embeddings], dim=0)
-    # Else compute the corpus embeddings from scratch
-    else:
-        new_entries = [entry.compiled for _, entry in entries_with_ids]
-        logger.info(f"📩 Indexing {len(new_entries)} text entries. Creating index from scratch.")
-        corpus_embeddings = bi_encoder.encode(
-            new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True
-        )
-
     # Save regenerated or updated embeddings to file
     if new_entries:
         corpus_embeddings = util.normalize_embeddings(corpus_embeddings)
         torch.save(corpus_embeddings, embeddings_file)
         logger.info(f"📩 Saved computed text embeddings to {embeddings_file}")
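
Because the embeddings are normalized with util.normalize_embeddings before being saved, a later search can score the stored index with a plain dot product. A small sketch of that read side, using sentence-transformers' util helpers with stand-in tensors (the real code would torch.load the embeddings_file and encode the query with the same bi_encoder):

    import torch
    from sentence_transformers import util

    # Stand-in corpus; in practice this is torch.load(embeddings_file, map_location=state.device).
    corpus_embeddings = util.normalize_embeddings(torch.randn(5, 384))

    # Stand-in query embedding; in practice produced by the same bi-encoder and normalized too.
    query_embedding = util.normalize_embeddings(torch.randn(1, 384))

    # With both sides normalized, dot-product scores equal cosine similarity.
    scores = util.dot_score(query_embedding, corpus_embeddings)
    best_match = int(scores.argmax())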