diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py index edc735f2..ed3be33c 100644 --- a/src/khoj/search_type/text_search.py +++ b/src/khoj/search_type/text_search.py @@ -62,36 +62,32 @@ def compute_embeddings( ): "Compute (and Save) Embeddings or Load Pre-Computed Embeddings" new_entries = [] + create_index_msg = "" # Load pre-computed embeddings from file if exists and update them if required if embeddings_file.exists() and not regenerate: corpus_embeddings: torch.Tensor = torch.load(get_absolute_path(embeddings_file), map_location=state.device) logger.debug(f"Loaded {len(corpus_embeddings)} text embeddings from {embeddings_file}") - - # Encode any new entries in the corpus and update corpus embeddings - new_entries = [entry.compiled for id, entry in entries_with_ids if id == -1] - if new_entries: - logger.info(f"📩 Indexing {len(new_entries)} text entries.") - new_embeddings = bi_encoder.encode( - new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True - ) - existing_entry_ids = [id for id, _ in entries_with_ids if id != -1] - if existing_entry_ids: - existing_embeddings = torch.index_select( - corpus_embeddings, 0, torch.tensor(existing_entry_ids, device=state.device) - ) - else: - existing_embeddings = torch.tensor([], device=state.device) - corpus_embeddings = torch.cat([existing_embeddings, new_embeddings], dim=0) - # Else compute the corpus embeddings from scratch else: - new_entries = [entry.compiled for _, entry in entries_with_ids] - logger.info(f"📩 Indexing {len(new_entries)} text entries. Creating index from scratch.") - corpus_embeddings = bi_encoder.encode( + corpus_embeddings = torch.tensor([], device=state.device) + create_index_msg = " Creating index from scratch." + + # Encode any new entries in the corpus and update corpus embeddings + new_entries = [entry.compiled for id, entry in entries_with_ids if id == -1] + if new_entries: + logger.info(f"📩 Indexing {len(new_entries)} text entries.{create_index_msg}") + new_embeddings = bi_encoder.encode( new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True ) + existing_entry_ids = [id for id, _ in entries_with_ids if id != -1] + if existing_entry_ids: + existing_embeddings = torch.index_select( + corpus_embeddings, 0, torch.tensor(existing_entry_ids, device=state.device) + ) + else: + existing_embeddings = torch.tensor([], device=state.device) + corpus_embeddings = torch.cat([existing_embeddings, new_embeddings], dim=0) - # Save regenerated or updated embeddings to file - if new_entries: + # Save regenerated or updated embeddings to file corpus_embeddings = util.normalize_embeddings(corpus_embeddings) torch.save(corpus_embeddings, embeddings_file) logger.info(f"📩 Saved computed text embeddings to {embeddings_file}")