diff --git a/src/main.py b/src/main.py index 069dc903..ce18b6a1 100644 --- a/src/main.py +++ b/src/main.py @@ -11,9 +11,11 @@ from fastapi import FastAPI from search_type import asymmetric, symmetric_ledger, image_search from utils.helpers import get_from_dict from utils.cli import cli -from utils.config import SearchType, SearchSettings +from utils.config import SearchType, SearchSettings, SearchModels +# Application Global State +model = SearchModels() search_settings = SearchSettings() app = FastAPI() @@ -29,16 +31,10 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None): if (t == SearchType.Notes or t == None) and search_settings.notes_search_enabled: # query notes - hits = asymmetric.query_notes( - user_query, - corpus_embeddings, - entries, - bi_encoder, - cross_encoder, - top_k) + hits = asymmetric.query_notes(user_query, model.notes_search) # collate and return results - return asymmetric.collate_results(hits, entries, results_count) + return asymmetric.collate_results(hits, model.notes_search.entries, results_count) if (t == SearchType.Music or t == None) and search_settings.music_search_enabled: # query music library @@ -90,9 +86,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None): def regenerate(t: Optional[SearchType] = None): if (t == SearchType.Notes or t == None) and search_settings.notes_search_enabled: # Extract Entries, Generate Embeddings - global corpus_embeddings - global entries - entries, corpus_embeddings, _, _, _ = asymmetric.setup( + models.notes_search = asymmetric.setup( org_config['input-files'], org_config['input-filter'], pathlib.Path(org_config['compressed-jsonl']), @@ -146,7 +140,7 @@ if __name__ == '__main__': org_config = get_from_dict(args.config, 'content-type', 'org') if org_config and ('input-files' in org_config or 'input-filter' in org_config): search_settings.notes_search_enabled = True - entries, corpus_embeddings, bi_encoder, cross_encoder, top_k = asymmetric.setup( + model.notes_search = asymmetric.setup( org_config['input-files'], org_config['input-filter'], pathlib.Path(org_config['compressed-jsonl']), diff --git a/src/search_type/asymmetric.py b/src/search_type/asymmetric.py index 17c5367e..e5ca2858 100644 --- a/src/search_type/asymmetric.py +++ b/src/search_type/asymmetric.py @@ -17,6 +17,7 @@ from sentence_transformers import SentenceTransformer, CrossEncoder, util # Internal Packages from utils.helpers import get_absolute_path, resolve_absolute_path from processor.org_mode.org_to_jsonl import org_to_jsonl +from utils.config import AsymmetricSearchModel def initialize_model(): @@ -64,7 +65,7 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, v return corpus_embeddings -def query_notes(raw_query, corpus_embeddings, entries, bi_encoder, cross_encoder, top_k=100): +def query_notes(raw_query: str, model: AsymmetricSearchModel): "Search all notes for entries that answer the query" # Separate natural query from explicit required, blocked words filters query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")]) @@ -72,20 +73,22 @@ def query_notes(raw_query, corpus_embeddings, entries, bi_encoder, cross_encoder blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")]) # Encode the query using the bi-encoder - question_embedding = bi_encoder.encode(query, convert_to_tensor=True) + question_embedding = model.bi_encoder.encode(query, convert_to_tensor=True) # Find relevant entries for the query - hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k) + hits = util.semantic_search(question_embedding, model.corpus_embeddings, top_k=model.top_k) hits = hits[0] # Get the hits for the first query # Filter results using explicit filters - hits = explicit_filter(hits, [entry[0] for entry in entries], required_words, blocked_words) + hits = explicit_filter(hits, + [entry[0] for entry in model.entries], + required_words,blocked_words) if hits is None or len(hits) == 0: return hits # Score all retrieved entries using the cross-encoder - cross_inp = [[query, entries[hit['corpus_id']][0]] for hit in hits] - cross_scores = cross_encoder.predict(cross_inp) + cross_inp = [[query, model.entries[hit['corpus_id']][0]] for hit in hits] + cross_scores = model.cross_encoder.predict(cross_inp) # Store cross-encoder scores in results dictionary for ranking for idx in range(len(cross_scores)): @@ -161,7 +164,7 @@ def setup(input_files, input_filter, compressed_jsonl, embeddings, regenerate=Fa # Compute or Load Embeddings corpus_embeddings = compute_embeddings(entries, bi_encoder, embeddings, regenerate=regenerate, verbose=verbose) - return entries, corpus_embeddings, bi_encoder, cross_encoder, top_k + return AsymmetricSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k) if __name__ == '__main__': diff --git a/src/tests/test_main.py b/src/tests/test_main.py index 764d9b22..d8221568 100644 --- a/src/tests/test_main.py +++ b/src/tests/test_main.py @@ -6,7 +6,7 @@ import pytest from fastapi.testclient import TestClient # Internal Packages -from main import app +from main import app, search_settings, model from search_type import asymmetric @@ -55,18 +55,33 @@ def test_regenerate_with_valid_search_type(): # ---------------------------------------------------------------------------------------------------- -def test_asymmetric_setup(): +def test_notes_search(): # Arrange input_files = [Path('tests/data/main_readme.org'), Path('tests/data/interface_emacs_readme.org')] input_filter = None compressed_jsonl = Path('tests/data/.test.jsonl.gz') embeddings = Path('tests/data/.test_embeddings.pt') - regenerate = False - verbose = 1 # Act - entries, corpus_embeddings, bi_encoder, cross_encoder, top_k = asymmetric.setup(input_files, input_filter, compressed_jsonl, embeddings, regenerate, verbose) + # Regenerate embeddings during asymmetric setup + notes_model = asymmetric.setup(input_files, input_filter, compressed_jsonl, embeddings, regenerate=True, verbose=0) # Assert - assert len(entries) == 10 - assert len(corpus_embeddings) == 10 + assert len(notes_model.entries) == 10 + assert len(notes_model.corpus_embeddings) == 10 + + # Arrange + model.notes_search = notes_model + search_settings.notes_search_enabled = True + user_query = "How to call semantic search from Emacs?" + + # Act + response = client.get(f"/search?q={user_query}&n=1&t=notes") + + # Assert + assert response.status_code == 200 + # assert actual_data contains "Semantic Search via Emacs" + search_result = response.json()[0]["Entry"] + assert "Semantic Search via Emacs" in search_result + + diff --git a/src/utils/config.py b/src/utils/config.py index eeecc6a6..dc09397b 100644 --- a/src/utils/config.py +++ b/src/utils/config.py @@ -18,3 +18,15 @@ class SearchSettings(): image_search_enabled: bool = False +class AsymmetricSearchModel(): + def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, top_k): + self.entries = entries + self.corpus_embeddings = corpus_embeddings + self.bi_encoder = bi_encoder + self.cross_encoder = cross_encoder + self.top_k = top_k + + +@dataclass +class SearchModels(): + notes_search: AsymmetricSearchModel = None