Get any configured asymmetric search model to encode the query for search

- Set image_search.query to async to use it with multi-threading
  This is the same as text_search.query being set to an async method
- Exit search early if no search_model is defined in state.model
This commit is contained in:
Debanjum Singh Solanky 2023-06-28 19:53:20 -07:00
parent 8eae7c898c
commit b1767f93d6
2 changed files with 19 additions and 9 deletions

View file

@ -1,5 +1,4 @@
# Standard Packages # Standard Packages
from collections import defaultdict
import concurrent.futures import concurrent.futures
import math import math
import time import time
@ -21,6 +20,7 @@ from khoj.search_type import image_search, text_search
from khoj.search_filter.date_filter import DateFilter from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.file_filter import FileFilter from khoj.search_filter.file_filter import FileFilter
from khoj.search_filter.word_filter import WordFilter from khoj.search_filter.word_filter import WordFilter
from khoj.utils.config import TextSearchModel
from khoj.utils.helpers import log_telemetry, timer from khoj.utils.helpers import log_telemetry, timer
from khoj.utils.rawconfig import ( from khoj.utils.rawconfig import (
ContentConfig, ContentConfig,
@ -144,10 +144,14 @@ async def search(
): ):
start_time = time.time() start_time = time.time()
# Run validation checks
results: List[SearchResponse] = [] results: List[SearchResponse] = []
if q is None or q == "": if q is None or q == "":
logger.warn(f"No query param (q) passed in API call to initiate search") logger.warn(f"No query param (q) passed in API call to initiate search")
return results return results
if not state.model or not any(state.model.__dict__.values()):
logger.warn(f"No search models loaded. Configure a search model before initiating search")
return results
# initialize variables # initialize variables
user_query = q.strip() user_query = q.strip()
@ -168,14 +172,20 @@ async def search(
encoded_asymmetric_query = None encoded_asymmetric_query = None
if t == SearchType.All or (t != SearchType.Ledger and t != SearchType.Image): if t == SearchType.All or (t != SearchType.Ledger and t != SearchType.Image):
with timer("Encoding query took", logger=logger): text_search_models: List[TextSearchModel] = [
encoded_asymmetric_query = util.normalize_embeddings( model
state.model.org_search.bi_encoder.encode( for model_name, model in state.model.__dict__.items()
[defiltered_query], if isinstance(model, TextSearchModel) and model_name != "ledger_search"
convert_to_tensor=True, ]
device=state.device, if text_search_models:
with timer("Encoding query took", logger=logger):
encoded_asymmetric_query = util.normalize_embeddings(
text_search_models[0].bi_encoder.encode(
[defiltered_query],
convert_to_tensor=True,
device=state.device,
)
) )
)
with concurrent.futures.ThreadPoolExecutor() as executor: with concurrent.futures.ThreadPoolExecutor() as executor:
if (t == SearchType.Org or t == SearchType.All) and state.model.org_search: if (t == SearchType.Org or t == SearchType.All) and state.model.org_search:

View file

@ -143,7 +143,7 @@ def extract_metadata(image_name):
return image_processed_metadata return image_processed_metadata
def query(raw_query, count, model: ImageSearchModel, score_threshold: float = -math.inf): async def query(raw_query, count, model: ImageSearchModel, score_threshold: float = -math.inf):
# Set query to image content if query is of form file:/path/to/file.png # Set query to image content if query is of form file:/path/to/file.png
if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file(): if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file():
query_imagepath = resolve_absolute_path(pathlib.Path(raw_query[5:]), strict=True) query_imagepath = resolve_absolute_path(pathlib.Path(raw_query[5:]), strict=True)