Drop embeddings of deleted text entries from index

Previously the deleted embeddings would continue to be in the index,
even after the entry was deleted
This commit is contained in:
Debanjum Singh Solanky 2023-07-16 03:47:05 -07:00
parent c73feebf25
commit ef6a0044f4
2 changed files with 54 additions and 17 deletions

View file

@ -65,7 +65,8 @@ def compute_embeddings(
normalize=True,
):
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
new_entries = []
new_embeddings = torch.tensor([], device=state.device)
existing_embeddings = torch.tensor([], device=state.device)
create_index_msg = ""
# Load pre-computed embeddings from file if exists and update them if required
if embeddings_file.exists() and not regenerate:
@ -82,22 +83,23 @@ def compute_embeddings(
new_embeddings = bi_encoder.encode(
new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True
)
existing_entry_ids = [id for id, _ in entries_with_ids if id != -1]
if existing_entry_ids:
existing_embeddings = torch.index_select(
corpus_embeddings, 0, torch.tensor(existing_entry_ids, device=state.device)
)
else:
existing_embeddings = torch.tensor([], device=state.device)
corpus_embeddings = torch.cat([existing_embeddings, new_embeddings], dim=0)
if normalize:
# Normalize embeddings for faster lookup via dot product when querying
corpus_embeddings = util.normalize_embeddings(corpus_embeddings)
# Extract existing embeddings from previous corpus embeddings
existing_entry_ids = [id for id, _ in entries_with_ids if id != -1]
if existing_entry_ids:
existing_embeddings = torch.index_select(
corpus_embeddings, 0, torch.tensor(existing_entry_ids, device=state.device)
)
# Save regenerated or updated embeddings to file
torch.save(corpus_embeddings, embeddings_file)
logger.info(f"📩 Saved computed text embeddings to {embeddings_file}")
# Set corpus embeddings to merger of existing and new embeddings
corpus_embeddings = torch.cat([existing_embeddings, new_embeddings], dim=0)
if normalize:
# Normalize embeddings for faster lookup via dot product when querying
corpus_embeddings = util.normalize_embeddings(corpus_embeddings)
# Save regenerated or updated embeddings to file
torch.save(corpus_embeddings, embeddings_file)
logger.info(f"📩 Saved computed text embeddings to {embeddings_file}")
return corpus_embeddings

View file

@ -71,8 +71,8 @@ def test_text_index_same_if_content_unchanged(content_config: ContentConfig, sea
final_logs = caplog.text
# Assert
assert "📩 Saved computed text embeddings to" in initial_logs
assert "📩 Saved computed text embeddings to" not in final_logs
assert "Creating index from scratch." in initial_logs
assert "Creating index from scratch." not in final_logs
# ----------------------------------------------------------------------------------------------------
@ -192,6 +192,41 @@ def test_update_index_with_duplicate_entries_in_stable_order(
pytest.fail(error_details)
# ----------------------------------------------------------------------------------------------------
def test_update_index_with_deleted_entry(org_config_with_only_new_file: TextContentConfig, search_models: SearchModels):
# Arrange
new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
# Insert org-mode entries with same compiled form into new org file
new_entry = "* TODO A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n"
with open(new_file_to_index, "w") as f:
f.write(f"{new_entry}{new_entry} -- Tatooine")
# load embeddings, entries, notes model after adding new org file with 2 entries
initial_index = text_search.setup(
OrgToJsonl, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=True
)
# update embeddings, entries, notes model after removing an entry from the org file
with open(new_file_to_index, "w") as f:
f.write(f"{new_entry}")
# Act
updated_index = text_search.setup(
OrgToJsonl, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
)
# Assert
# verify only 1 entry added even if there are multiple duplicate entries
assert len(initial_index.entries) == len(updated_index.entries) + 1
assert len(initial_index.corpus_embeddings) == len(updated_index.corpus_embeddings) + 1
# verify the same entry is added even when there are multiple duplicate entries
error_details = compare_index(updated_index, initial_index)
if error_details:
pytest.fail(error_details)
# ----------------------------------------------------------------------------------------------------
def test_update_index_with_new_entry(content_config: ContentConfig, search_models: SearchModels, new_org_file: Path):
# Arrange