Reorder embeddings search arguments based on argument importance

This commit is contained in:
Debanjum Singh Solanky 2024-10-10 03:57:09 -07:00
parent 0eacc0b2b0
commit 6a8fd9bf33
4 changed files with 7 additions and 7 deletions

View file

@@ -1413,11 +1413,11 @@ class EntryAdapters:
 @staticmethod
 def search_with_embeddings(
-user: KhojUser,
+raw_query: str,
 embeddings: Tensor,
+user: KhojUser,
 max_results: int = 10,
 file_type_filter: str = None,
-raw_query: str = None,
 max_distance: float = math.inf,
 agent: Agent = None,
 ):

View file

@@ -160,8 +160,8 @@ async def execute_search(
 search_futures += [
 executor.submit(
 text_search.query,
-user,
 user_query,
+user,
 t,
 question_embedding=encoded_asymmetric_query,
 max_distance=max_distance,

View file

@@ -97,8 +97,8 @@ def load_embeddings(
 async def query(
-user: KhojUser,
 raw_query: str,
+user: KhojUser,
 type: SearchType = SearchType.All,
 question_embedding: Union[torch.Tensor, None] = None,
 max_distance: float = None,
@@ -125,12 +125,12 @@ async def query(
 top_k = 10
 with timer("Search Time", logger, state.device):
 hits = EntryAdapters.search_with_embeddings(
-user=user,
+raw_query=raw_query,
 embeddings=question_embedding,
 max_results=top_k,
 file_type_filter=file_type,
-raw_query=raw_query,
 max_distance=max_distance,
+user=user,
 agent=agent,
 ).all()
 hits = await sync_to_async(list)(hits) # type: ignore[call-arg]

View file

@@ -164,7 +164,7 @@ async def test_text_search(search_config: SearchConfig):
 query = "Load Khoj on Emacs?"
 # Act
-hits = await text_search.query(default_user, query)
+hits = await text_search.query(query, default_user)
 results = text_search.collate_results(hits)
 results = sorted(results, key=lambda x: float(x.score))[:1]