Drop embeddings of deleted text entries from index

Previously, the embeddings of deleted entries would remain in the index even after the entries themselves were removed.

parent c73feebf25
commit ef6a0044f4

2 changed files with 54 additions and 17 deletions
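The core of the fix is visible in tensor shapes alone. The toy sketch below (invented dimensions and entry layout, not Khoj's actual code) contrasts the old append-only update, which leaves a deleted entry's embedding searchable, with the new update that first selects only the surviving rows:

import torch

corpus_embeddings = torch.randn(3, 4)  # rows for entries A, B, C
new_embeddings = torch.randn(1, 4)     # entry D was added; entry B was deleted on disk

# Old behavior: append-only update keeps B's stale row in the index
stale_index = torch.cat([corpus_embeddings, new_embeddings], dim=0)
assert stale_index.shape[0] == 4  # B's embedding is still searchable

# New behavior: keep only the surviving rows (A at 0, C at 2), then append
surviving_rows = torch.tensor([0, 2])
existing_embeddings = torch.index_select(corpus_embeddings, 0, surviving_rows)
fresh_index = torch.cat([existing_embeddings, new_embeddings], dim=0)
assert fresh_index.shape[0] == 3  # B's embedding is gone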
@@ -65,7 +65,8 @@ def compute_embeddings(
    normalize=True,
):
    "Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
    new_entries = []
    new_embeddings = torch.tensor([], device=state.device)
    existing_embeddings = torch.tensor([], device=state.device)
    create_index_msg = ""
    # Load pre-computed embeddings from file if exists and update them if required
    if embeddings_file.exists() and not regenerate:
@@ -82,15 +83,16 @@ def compute_embeddings(
        new_embeddings = bi_encoder.encode(
            new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True
        )

        # Extract existing embeddings from previous corpus embeddings
        existing_entry_ids = [id for id, _ in entries_with_ids if id != -1]
        if existing_entry_ids:
            existing_embeddings = torch.index_select(
                corpus_embeddings, 0, torch.tensor(existing_entry_ids, device=state.device)
            )
        else:
            existing_embeddings = torch.tensor([], device=state.device)
        corpus_embeddings = torch.cat([existing_embeddings, new_embeddings], dim=0)

        # Set corpus embeddings to merger of existing and new embeddings
        corpus_embeddings = torch.cat([existing_embeddings, new_embeddings], dim=0)
        if normalize:
            # Normalize embeddings for faster lookup via dot product when querying
            corpus_embeddings = util.normalize_embeddings(corpus_embeddings)
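The final normalize step is why the index can be queried with a plain dot product. A small self-contained check (toy tensors; sentence-transformers' util.normalize_embeddings performs the same row-wise unit-norm scaling) shows that, on normalized vectors, dot product and cosine similarity coincide:

import torch
import torch.nn.functional as F

emb = torch.randn(3, 8)
emb = emb / emb.norm(dim=1, keepdim=True)  # row-wise unit norm, like util.normalize_embeddings
query = torch.randn(8)
query = query / query.norm()

dot_scores = emb @ query                                      # cheap matmul lookup
cos_scores = F.cosine_similarity(emb, query.unsqueeze(0), dim=1)
assert torch.allclose(dot_scores, cos_scores, atol=1e-6)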
@@ -71,8 +71,8 @@ def test_text_index_same_if_content_unchanged(content_config: ContentConfig, sea
    final_logs = caplog.text

    # Assert
    assert "📩 Saved computed text embeddings to" in initial_logs
    assert "📩 Saved computed text embeddings to" not in final_logs
    assert "Creating index from scratch." in initial_logs
    assert "Creating index from scratch." not in final_logs


# ----------------------------------------------------------------------------------------------------
@@ -192,6 +192,41 @@ def test_update_index_with_duplicate_entries_in_stable_order(
        pytest.fail(error_details)


# ----------------------------------------------------------------------------------------------------
def test_update_index_with_deleted_entry(org_config_with_only_new_file: TextContentConfig, search_models: SearchModels):
    # Arrange
    new_file_to_index = Path(org_config_with_only_new_file.input_files[0])

    # Insert two similar org-mode entries into the new org file
    new_entry = "* TODO A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n"
    with open(new_file_to_index, "w") as f:
        f.write(f"{new_entry}{new_entry} -- Tatooine")

    # Load embeddings, entries, notes model after adding new org file with 2 entries
    initial_index = text_search.setup(
        OrgToJsonl, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=True
    )

    # Remove the second entry from the org file
    with open(new_file_to_index, "w") as f:
        f.write(f"{new_entry}")

    # Act
    # Update embeddings, entries, notes model after removing an entry from the org file
    updated_index = text_search.setup(
        OrgToJsonl, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
    )

    # Assert
    # verify the deleted entry and its embedding are dropped from the index
    assert len(initial_index.entries) == len(updated_index.entries) + 1
    assert len(initial_index.corpus_embeddings) == len(updated_index.corpus_embeddings) + 1

    # verify the surviving entries and embeddings are otherwise unchanged
    error_details = compare_index(updated_index, initial_index)
    if error_details:
        pytest.fail(error_details)


# ----------------------------------------------------------------------------------------------------
def test_update_index_with_new_entry(content_config: ContentConfig, search_models: SearchModels, new_org_file: Path):
    # Arrange
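To exercise just the new regression test, a keyword-filtered run such as `python -m pytest -k test_update_index_with_deleted_entry` should suffice with a standard pytest setup; the diff view does not show the test file's path.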