mirror of
https://github.com/khoj-ai/khoj.git
synced 2025-02-17 08:04:21 +00:00
Unify logic to generate embeddings from scratch and incrementally
This simplifies the `compute_embeddings` method and avoids potential later divergence in handling the index regenerate vs. update scenarios
This commit is contained in:
parent
6a0297cc86
commit
89c7819cb7
1 changed file with 18 additions and 22 deletions
|
@ -62,36 +62,32 @@ def compute_embeddings(
|
||||||
):
|
):
|
||||||
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
|
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
|
||||||
new_entries = []
|
new_entries = []
|
||||||
|
create_index_msg = ""
|
||||||
# Load pre-computed embeddings from file if exists and update them if required
|
# Load pre-computed embeddings from file if exists and update them if required
|
||||||
if embeddings_file.exists() and not regenerate:
|
if embeddings_file.exists() and not regenerate:
|
||||||
corpus_embeddings: torch.Tensor = torch.load(get_absolute_path(embeddings_file), map_location=state.device)
|
corpus_embeddings: torch.Tensor = torch.load(get_absolute_path(embeddings_file), map_location=state.device)
|
||||||
logger.debug(f"Loaded {len(corpus_embeddings)} text embeddings from {embeddings_file}")
|
logger.debug(f"Loaded {len(corpus_embeddings)} text embeddings from {embeddings_file}")
|
||||||
|
|
||||||
# Encode any new entries in the corpus and update corpus embeddings
|
|
||||||
new_entries = [entry.compiled for id, entry in entries_with_ids if id == -1]
|
|
||||||
if new_entries:
|
|
||||||
logger.info(f"📩 Indexing {len(new_entries)} text entries.")
|
|
||||||
new_embeddings = bi_encoder.encode(
|
|
||||||
new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True
|
|
||||||
)
|
|
||||||
existing_entry_ids = [id for id, _ in entries_with_ids if id != -1]
|
|
||||||
if existing_entry_ids:
|
|
||||||
existing_embeddings = torch.index_select(
|
|
||||||
corpus_embeddings, 0, torch.tensor(existing_entry_ids, device=state.device)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
existing_embeddings = torch.tensor([], device=state.device)
|
|
||||||
corpus_embeddings = torch.cat([existing_embeddings, new_embeddings], dim=0)
|
|
||||||
# Else compute the corpus embeddings from scratch
|
|
||||||
else:
|
else:
|
||||||
new_entries = [entry.compiled for _, entry in entries_with_ids]
|
corpus_embeddings = torch.tensor([], device=state.device)
|
||||||
logger.info(f"📩 Indexing {len(new_entries)} text entries. Creating index from scratch.")
|
create_index_msg = " Creating index from scratch."
|
||||||
corpus_embeddings = bi_encoder.encode(
|
|
||||||
|
# Encode any new entries in the corpus and update corpus embeddings
|
||||||
|
new_entries = [entry.compiled for id, entry in entries_with_ids if id == -1]
|
||||||
|
if new_entries:
|
||||||
|
logger.info(f"📩 Indexing {len(new_entries)} text entries.{create_index_msg}")
|
||||||
|
new_embeddings = bi_encoder.encode(
|
||||||
new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True
|
new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True
|
||||||
)
|
)
|
||||||
|
existing_entry_ids = [id for id, _ in entries_with_ids if id != -1]
|
||||||
|
if existing_entry_ids:
|
||||||
|
existing_embeddings = torch.index_select(
|
||||||
|
corpus_embeddings, 0, torch.tensor(existing_entry_ids, device=state.device)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
existing_embeddings = torch.tensor([], device=state.device)
|
||||||
|
corpus_embeddings = torch.cat([existing_embeddings, new_embeddings], dim=0)
|
||||||
|
|
||||||
# Save regenerated or updated embeddings to file
|
# Save regenerated or updated embeddings to file
|
||||||
if new_entries:
|
|
||||||
corpus_embeddings = util.normalize_embeddings(corpus_embeddings)
|
corpus_embeddings = util.normalize_embeddings(corpus_embeddings)
|
||||||
torch.save(corpus_embeddings, embeddings_file)
|
torch.save(corpus_embeddings, embeddings_file)
|
||||||
logger.info(f"📩 Saved computed text embeddings to {embeddings_file}")
|
logger.info(f"📩 Saved computed text embeddings to {embeddings_file}")
|
||||||
|
|
Loading…
Add table
Reference in a new issue