From 6a8fd9bf3370dfe0e945294dc6b4db545ea36469 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Thu, 10 Oct 2024 03:57:09 -0700 Subject: [PATCH] Reorder embeddings search arguments based on argument importance --- src/khoj/database/adapters/__init__.py | 4 ++-- src/khoj/routers/api.py | 2 +- src/khoj/search_type/text_search.py | 6 +++--- tests/test_text_search.py | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 9687ec01..027490be 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -1413,11 +1413,11 @@ class EntryAdapters: @staticmethod def search_with_embeddings( - user: KhojUser, + raw_query: str, embeddings: Tensor, + user: KhojUser, max_results: int = 10, file_type_filter: str = None, - raw_query: str = None, max_distance: float = math.inf, agent: Agent = None, ): diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 46fdfa43..6a30e194 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -160,8 +160,8 @@ async def execute_search( search_futures += [ executor.submit( text_search.query, - user, user_query, + user, t, question_embedding=encoded_asymmetric_query, max_distance=max_distance, diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py index 52e23f29..ae873c33 100644 --- a/src/khoj/search_type/text_search.py +++ b/src/khoj/search_type/text_search.py @@ -97,8 +97,8 @@ def load_embeddings( async def query( - user: KhojUser, raw_query: str, + user: KhojUser, type: SearchType = SearchType.All, question_embedding: Union[torch.Tensor, None] = None, max_distance: float = None, @@ -125,12 +125,12 @@ async def query( top_k = 10 with timer("Search Time", logger, state.device): hits = EntryAdapters.search_with_embeddings( - user=user, + raw_query=raw_query, embeddings=question_embedding, max_results=top_k, file_type_filter=file_type, - raw_query=raw_query, max_distance=max_distance, + user=user, agent=agent, ).all() hits = await sync_to_async(list)(hits) # type: ignore[call-arg] diff --git a/tests/test_text_search.py b/tests/test_text_search.py index 4529aa53..712f4aba 100644 --- a/tests/test_text_search.py +++ b/tests/test_text_search.py @@ -164,7 +164,7 @@ async def test_text_search(search_config: SearchConfig): query = "Load Khoj on Emacs?" # Act - hits = await text_search.query(default_user, query) + hits = await text_search.query(query, default_user) results = text_search.collate_results(hits) results = sorted(results, key=lambda x: float(x.score))[:1]