From 6d94d6e75a28a7becf2c947f81a8255db0cf67dc Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Tue, 20 Jun 2023 01:17:21 -0700 Subject: [PATCH] Encode the asymmetric, symmetric search queries in parallel for speed Use timer to measure time to encode queries and total search time --- src/khoj/routers/api.py | 33 +++++++++++++++++++++++++-------- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 35216343..a1aef1a1 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -2,6 +2,7 @@ from collections import defaultdict import concurrent.futures import math +import time import yaml import logging from datetime import datetime @@ -118,6 +119,8 @@ def search( dedupe: Optional[bool] = True, client: Optional[str] = None, ): + start_time = time.time() + results: List[SearchResponse] = [] if q is None or q == "": logger.warn(f"No query param (q) passed in API call to initiate search") @@ -139,15 +142,26 @@ def search( for filter in [DateFilter(), WordFilter(), FileFilter()]: defiltered_query = filter.defilter(user_query) - encoded_asymmetric_query = state.model.org_search.bi_encoder.encode( - [defiltered_query], convert_to_tensor=True, device=state.device - ) - encoded_asymmetric_query = util.normalize_embeddings(encoded_asymmetric_query) + with concurrent.futures.ThreadPoolExecutor() as executor: + with timer("Encoding query for asymmetric search took", logger=logger): + encode_asymmetric_futures = executor.submit( + state.model.org_search.bi_encoder.encode, + [defiltered_query], + convert_to_tensor=True, + device=state.device, + ) - encoded_symmetric_query = state.model.org_search.bi_encoder.encode( - [defiltered_query], convert_to_tensor=True, device=state.device - ) - encoded_symmetric_query = util.normalize_embeddings(encoded_symmetric_query) + with timer("Encoding query for symmetric search took", logger=logger): + encode_symmetric_futures = executor.submit( + state.model.org_search.bi_encoder.encode, + [defiltered_query], + convert_to_tensor=True, + device=state.device, + ) + + with timer("Normalizing query embeddings took", logger=logger): + encoded_asymmetric_query = util.normalize_embeddings(encode_asymmetric_futures.result()) + encoded_symmetric_query = util.normalize_embeddings(encode_symmetric_futures.result()) with concurrent.futures.ThreadPoolExecutor() as executor: if (t == SearchType.Org or t == None) and state.model.org_search: @@ -279,6 +293,9 @@ def search( ] state.previous_query = user_query + end_time = time.time() + logger.debug(f"🔍 Search took {end_time - start_time:.2f} seconds") + return results