Fix all unit tests for test_text_search

This commit is contained in:
sabaimran 2024-03-15 16:39:03 +05:30 committed by Debanjum Singh Solanky
parent 44b3247869
commit 720139c3c1

View file

@ -57,18 +57,21 @@ def test_get_org_files_with_org_suffixed_dir_doesnt_raise_error(tmp_path, defaul
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
def test_text_search_setup_with_empty_file_creates_no_entries(
org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog
org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser
):
# Arrange
existing_entries = Entry.objects.filter(user=default_user).count()
data = get_org_files(org_config_with_only_new_file)
# Act
# Generate notes embeddings during asymmetric setup
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
# Assert
assert "Deleted 8 entries. Created 0 new entries for user " in caplog.records[-1].message
updated_entries = Entry.objects.filter(user=default_user).count()
assert existing_entries == 2
assert updated_entries == 0
verify_embeddings(0, default_user)
@ -78,6 +81,7 @@ def test_text_indexer_deletes_embedding_before_regenerate(
content_config: ContentConfig, default_user: KhojUser, caplog
):
# Arrange
existing_entries = Entry.objects.filter(user=default_user).count()
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
data = get_org_files(org_config)
@ -87,30 +91,18 @@ def test_text_indexer_deletes_embedding_before_regenerate(
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
# Assert
updated_entries = Entry.objects.filter(user=default_user).count()
assert existing_entries == 2
assert updated_entries == 2
assert "Deleting all entries for file type org" in caplog.text
assert "Deleted 8 entries. Created 13 new entries for user " in caplog.records[-1].message
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
def test_text_search_setup_batch_processes(content_config: ContentConfig, default_user: KhojUser, caplog):
# Arrange
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
data = get_org_files(org_config)
# Act
# Generate notes embeddings during asymmetric setup
with caplog.at_level(logging.DEBUG):
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
# Assert
assert "Deleted 8 entries. Created 13 new entries for user " in caplog.records[-1].message
assert "Deleted 2 entries. Created 2 new entries for user " in caplog.records[-1].message
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
def test_text_index_same_if_content_unchanged(content_config: ContentConfig, default_user: KhojUser, caplog):
# Arrange
existing_entries = Entry.objects.filter(user=default_user)
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
data = get_org_files(org_config)
@ -127,6 +119,10 @@ def test_text_index_same_if_content_unchanged(content_config: ContentConfig, def
final_logs = caplog.text
# Assert
updated_entries = Entry.objects.filter(user=default_user)
for entry in updated_entries:
assert entry in existing_entries
assert len(existing_entries) == len(updated_entries)
assert "Deleting all entries for file type org" in initial_logs
assert "Deleting all entries for file type org" not in final_logs
@ -256,10 +252,9 @@ conda activate khoj
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
def test_regenerate_index_with_new_entry(
content_config: ContentConfig, new_org_file: Path, default_user: KhojUser, caplog
):
def test_regenerate_index_with_new_entry(content_config: ContentConfig, new_org_file: Path, default_user: KhojUser):
# Arrange
existing_entries = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
initial_data = get_org_files(org_config)
@ -271,28 +266,34 @@ def test_regenerate_index_with_new_entry(
final_data = get_org_files(org_config)
# Act
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, initial_data, regenerate=True, user=default_user)
initial_logs = caplog.text
caplog.clear() # Clear logs
text_search.setup(OrgToEntries, initial_data, regenerate=True, user=default_user)
updated_entries1 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
# regenerate notes jsonl, model embeddings and model to include entry from new file
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, final_data, regenerate=True, user=default_user)
final_logs = caplog.text
text_search.setup(OrgToEntries, final_data, regenerate=True, user=default_user)
updated_entries2 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
# Assert
assert "Deleted 8 entries. Created 13 new entries for user " in initial_logs
assert "Deleted 13 entries. Created 14 new entries for user " in final_logs
verify_embeddings(14, default_user)
for entry in updated_entries1:
assert entry in updated_entries2
assert not any([new_org_file.name in entry for entry in updated_entries1])
assert not any([new_org_file.name in entry for entry in existing_entries])
assert any([new_org_file.name in entry for entry in updated_entries2])
assert any(
["Saw a super cute video of a chihuahua doing the Tango on Youtube" in entry for entry in updated_entries2]
)
verify_embeddings(3, default_user)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
def test_update_index_with_duplicate_entries_in_stable_order(
org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog
org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser
):
# Arrange
existing_entries = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
# Insert org-mode entries with same compiled form into new org file
@ -304,30 +305,33 @@ def test_update_index_with_duplicate_entries_in_stable_order(
# Act
# generate embeddings, entries, notes model from scratch after adding new org-mode file
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
initial_logs = caplog.text
caplog.clear() # Clear logs
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
updated_entries1 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
data = get_org_files(org_config_with_only_new_file)
# update embeddings, entries, notes model with no new changes
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, data, regenerate=False, user=default_user)
final_logs = caplog.text
text_search.setup(OrgToEntries, data, regenerate=False, user=default_user)
updated_entries2 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
# Assert
# verify only 1 entry added even if there are multiple duplicate entries
assert "Deleted 8 entries. Created 1 new entries for user " in initial_logs
assert "Deleted 0 entries. Created 0 new entries for user " in final_logs
for entry in existing_entries:
assert entry not in updated_entries1
for entry in updated_entries1:
assert entry in updated_entries2
assert len(existing_entries) == 2
assert len(updated_entries1) == len(updated_entries2)
verify_embeddings(1, default_user)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
def test_update_index_with_deleted_entry(org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog):
def test_update_index_with_deleted_entry(org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser):
# Arrange
existing_entries = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
# Insert org-mode entries with same compiled form into new org file
@ -344,33 +348,34 @@ def test_update_index_with_deleted_entry(org_config_with_only_new_file: LocalOrg
# Act
# load embeddings, entries, notes model after adding new org file with 2 entries
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, initial_data, regenerate=True, user=default_user)
initial_logs = caplog.text
caplog.clear() # Clear logs
text_search.setup(OrgToEntries, initial_data, regenerate=True, user=default_user)
updated_entries1 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, final_data, regenerate=False, user=default_user)
final_logs = caplog.text
text_search.setup(OrgToEntries, final_data, regenerate=False, user=default_user)
updated_entries2 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
# Assert
# verify only 1 entry added even if there are multiple duplicate entries
assert "Deleted 8 entries. Created 2 new entries for user " in initial_logs
assert "Deleted 1 entries. Created 0 new entries for user " in final_logs
for entry in existing_entries:
assert entry not in updated_entries1
# verify the entry in updated_entries2 is a subset of updated_entries1
for entry in updated_entries1:
assert entry not in updated_entries2
for entry in updated_entries2:
assert entry in updated_entries1[0]
verify_embeddings(1, default_user)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file: Path, default_user: KhojUser, caplog):
def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file: Path, default_user: KhojUser):
# Arrange
existing_entries = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
data = get_org_files(org_config)
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
initial_logs = caplog.text
caplog.clear() # Clear logs
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
# append org-mode entry to first org input file in config
with open(new_org_file, "w") as f:
@ -381,14 +386,14 @@ def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file
# Act
# update embeddings, entries with the newly added note
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, data, regenerate=False, user=default_user)
final_logs = caplog.text
text_search.setup(OrgToEntries, data, regenerate=False, user=default_user)
updated_entries1 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
# Assert
assert "Deleted 8 entries. Created 13 new entries for user " in initial_logs
assert "Deleted 0 entries. Created 1 new entries for user " in final_logs
verify_embeddings(14, default_user)
for entry in existing_entries:
assert entry not in updated_entries1
assert len(updated_entries1) == len(existing_entries) + 1
verify_embeddings(3, default_user)
# ----------------------------------------------------------------------------------------------------