mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-27 17:35:07 +01:00
Get any configured asymmetric search model to encode query for search
- Set image_search.query to async to use it with multi-threading This is same as text_search.query being set to an async method - Exit search early if no search_model is defined in state.model
This commit is contained in:
parent
8eae7c898c
commit
b1767f93d6
2 changed files with 19 additions and 9 deletions
|
@ -1,5 +1,4 @@
|
|||
# Standard Packages
|
||||
from collections import defaultdict
|
||||
import concurrent.futures
|
||||
import math
|
||||
import time
|
||||
|
@ -21,6 +20,7 @@ from khoj.search_type import image_search, text_search
|
|||
from khoj.search_filter.date_filter import DateFilter
|
||||
from khoj.search_filter.file_filter import FileFilter
|
||||
from khoj.search_filter.word_filter import WordFilter
|
||||
from khoj.utils.config import TextSearchModel
|
||||
from khoj.utils.helpers import log_telemetry, timer
|
||||
from khoj.utils.rawconfig import (
|
||||
ContentConfig,
|
||||
|
@ -144,10 +144,14 @@ async def search(
|
|||
):
|
||||
start_time = time.time()
|
||||
|
||||
# Run validation checks
|
||||
results: List[SearchResponse] = []
|
||||
if q is None or q == "":
|
||||
logger.warn(f"No query param (q) passed in API call to initiate search")
|
||||
return results
|
||||
if not state.model or not any(state.model.__dict__.values()):
|
||||
logger.warn(f"No search models loaded. Configure a search model before initiating search")
|
||||
return results
|
||||
|
||||
# initialize variables
|
||||
user_query = q.strip()
|
||||
|
@ -168,14 +172,20 @@ async def search(
|
|||
|
||||
encoded_asymmetric_query = None
|
||||
if t == SearchType.All or (t != SearchType.Ledger and t != SearchType.Image):
|
||||
with timer("Encoding query took", logger=logger):
|
||||
encoded_asymmetric_query = util.normalize_embeddings(
|
||||
state.model.org_search.bi_encoder.encode(
|
||||
[defiltered_query],
|
||||
convert_to_tensor=True,
|
||||
device=state.device,
|
||||
text_search_models: List[TextSearchModel] = [
|
||||
model
|
||||
for model_name, model in state.model.__dict__.items()
|
||||
if isinstance(model, TextSearchModel) and model_name != "ledger_search"
|
||||
]
|
||||
if text_search_models:
|
||||
with timer("Encoding query took", logger=logger):
|
||||
encoded_asymmetric_query = util.normalize_embeddings(
|
||||
text_search_models[0].bi_encoder.encode(
|
||||
[defiltered_query],
|
||||
convert_to_tensor=True,
|
||||
device=state.device,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
if (t == SearchType.Org or t == SearchType.All) and state.model.org_search:
|
||||
|
|
|
@ -143,7 +143,7 @@ def extract_metadata(image_name):
|
|||
return image_processed_metadata
|
||||
|
||||
|
||||
def query(raw_query, count, model: ImageSearchModel, score_threshold: float = -math.inf):
|
||||
async def query(raw_query, count, model: ImageSearchModel, score_threshold: float = -math.inf):
|
||||
# Set query to image content if query is of form file:/path/to/file.png
|
||||
if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file():
|
||||
query_imagepath = resolve_absolute_path(pathlib.Path(raw_query[5:]), strict=True)
|
||||
|
|
Loading…
Reference in a new issue