Create helper function to test value, order of entries & embeddings

This helper should be used to check whether the current embeddings
remain stably sorted when the index is regenerated or incrementally
updated in the text search tests
This commit is contained in:
Debanjum Singh Solanky 2023-07-15 14:33:15 -07:00
parent 7ad96036b0
commit da98b92dd4

View file

@ -5,6 +5,7 @@ import os
# External Packages
import pytest
import torch
from khoj.utils.config import SearchModels
# Internal Packages
@ -202,3 +203,25 @@ def test_asymmetric_setup_github(content_config: ContentConfig, search_models: S
# Assert
assert len(github_model.entries) > 1
def compare_index(initial_notes_model, final_notes_model):
    """Describe where the final index diverges from the initial index.

    Every position present in the initial model is compared against the
    same position in the final model. Positions beyond the initial model's
    length are ignored, so items appended to the final model are not
    reported.

    Args:
        initial_notes_model: Object exposing ``entries`` (a sequence whose
            items provide ``to_json()``) and ``corpus_embeddings`` (indexable
            rows comparable with ``torch.equal``).
        final_notes_model: Object with the same two attributes, holding the
            index state to check against the initial one.

    Returns:
        An empty string when every initial entry and embedding is found
        unchanged at the same position in the final model. Otherwise one
        line per mismatch category listing the offending positions.
    """
    initial_entries = initial_notes_model.entries
    final_entries = final_notes_model.entries
    # A position missing from the final index is reported as a mismatch
    # rather than raising IndexError, so the caller still gets a readable report.
    mismatched_entries = [
        index
        for index, initial_entry in enumerate(initial_entries)
        if index >= len(final_entries)
        or initial_entry.to_json() != final_entries[index].to_json()
    ]

    # verify new entry embedding appended to embeddings tensor,
    # without disrupting order or content of existing embeddings
    initial_embeddings = initial_notes_model.corpus_embeddings
    final_embeddings = final_notes_model.corpus_embeddings
    mismatched_embeddings = [
        index
        for index in range(len(initial_embeddings))
        if index >= len(final_embeddings)
        or not torch.equal(final_embeddings[index], initial_embeddings[index])
    ]

    error_details = ""
    if mismatched_entries:
        mismatched_entries_str = ",".join(map(str, mismatched_entries))
        error_details += f"Entries at {mismatched_entries_str} not equal\n"
    if mismatched_embeddings:
        mismatched_embeddings_str = ", ".join(map(str, mismatched_embeddings))
        error_details += f"Embeddings at {mismatched_embeddings_str} not equal\n"

    return error_details