Get any configured asymmetric search model to encode the query for search

- Make image_search.query async so it can be used with multi-threading.
  This is the same as text_search.query being an async method
- Exit search early if no search_model is defined in state.model
This commit is contained in:
Debanjum Singh Solanky 2023-06-28 19:53:20 -07:00
parent 8eae7c898c
commit b1767f93d6
2 changed files with 19 additions and 9 deletions

View file

@ -1,5 +1,4 @@
# Standard Packages
from collections import defaultdict
import concurrent.futures
import math
import time
@ -21,6 +20,7 @@ from khoj.search_type import image_search, text_search
from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.file_filter import FileFilter
from khoj.search_filter.word_filter import WordFilter
from khoj.utils.config import TextSearchModel
from khoj.utils.helpers import log_telemetry, timer
from khoj.utils.rawconfig import (
ContentConfig,
@ -144,10 +144,14 @@ async def search(
):
start_time = time.time()
# Run validation checks
results: List[SearchResponse] = []
if q is None or q == "":
logger.warn(f"No query param (q) passed in API call to initiate search")
return results
if not state.model or not any(state.model.__dict__.values()):
logger.warn(f"No search models loaded. Configure a search model before initiating search")
return results
# initialize variables
user_query = q.strip()
@ -168,14 +172,20 @@ async def search(
encoded_asymmetric_query = None
if t == SearchType.All or (t != SearchType.Ledger and t != SearchType.Image):
with timer("Encoding query took", logger=logger):
encoded_asymmetric_query = util.normalize_embeddings(
state.model.org_search.bi_encoder.encode(
[defiltered_query],
convert_to_tensor=True,
device=state.device,
text_search_models: List[TextSearchModel] = [
model
for model_name, model in state.model.__dict__.items()
if isinstance(model, TextSearchModel) and model_name != "ledger_search"
]
if text_search_models:
with timer("Encoding query took", logger=logger):
encoded_asymmetric_query = util.normalize_embeddings(
text_search_models[0].bi_encoder.encode(
[defiltered_query],
convert_to_tensor=True,
device=state.device,
)
)
)
with concurrent.futures.ThreadPoolExecutor() as executor:
if (t == SearchType.Org or t == SearchType.All) and state.model.org_search:

View file

@ -143,7 +143,7 @@ def extract_metadata(image_name):
return image_processed_metadata
def query(raw_query, count, model: ImageSearchModel, score_threshold: float = -math.inf):
async def query(raw_query, count, model: ImageSearchModel, score_threshold: float = -math.inf):
# Set query to image content if query is of form file:/path/to/file.png
if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file():
query_imagepath = resolve_absolute_path(pathlib.Path(raw_query[5:]), strict=True)