mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-27 17:35:07 +01:00
Get any configured asymmetric search model to encode query for search
- Set image_search.query to async to use it with multi-threading. This is the same as text_search.query being set to an async method.
- Exit search early if no search_model is defined in state.model.
This commit is contained in:
parent
8eae7c898c
commit
b1767f93d6
2 changed files with 19 additions and 9 deletions
|
@ -1,5 +1,4 @@
|
||||||
# Standard Packages
|
# Standard Packages
|
||||||
from collections import defaultdict
|
|
||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
import math
|
import math
|
||||||
import time
|
import time
|
||||||
|
@ -21,6 +20,7 @@ from khoj.search_type import image_search, text_search
|
||||||
from khoj.search_filter.date_filter import DateFilter
|
from khoj.search_filter.date_filter import DateFilter
|
||||||
from khoj.search_filter.file_filter import FileFilter
|
from khoj.search_filter.file_filter import FileFilter
|
||||||
from khoj.search_filter.word_filter import WordFilter
|
from khoj.search_filter.word_filter import WordFilter
|
||||||
|
from khoj.utils.config import TextSearchModel
|
||||||
from khoj.utils.helpers import log_telemetry, timer
|
from khoj.utils.helpers import log_telemetry, timer
|
||||||
from khoj.utils.rawconfig import (
|
from khoj.utils.rawconfig import (
|
||||||
ContentConfig,
|
ContentConfig,
|
||||||
|
@ -144,10 +144,14 @@ async def search(
|
||||||
):
|
):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
|
# Run validation checks
|
||||||
results: List[SearchResponse] = []
|
results: List[SearchResponse] = []
|
||||||
if q is None or q == "":
|
if q is None or q == "":
|
||||||
logger.warn(f"No query param (q) passed in API call to initiate search")
|
logger.warn(f"No query param (q) passed in API call to initiate search")
|
||||||
return results
|
return results
|
||||||
|
if not state.model or not any(state.model.__dict__.values()):
|
||||||
|
logger.warn(f"No search models loaded. Configure a search model before initiating search")
|
||||||
|
return results
|
||||||
|
|
||||||
# initialize variables
|
# initialize variables
|
||||||
user_query = q.strip()
|
user_query = q.strip()
|
||||||
|
@ -168,14 +172,20 @@ async def search(
|
||||||
|
|
||||||
encoded_asymmetric_query = None
|
encoded_asymmetric_query = None
|
||||||
if t == SearchType.All or (t != SearchType.Ledger and t != SearchType.Image):
|
if t == SearchType.All or (t != SearchType.Ledger and t != SearchType.Image):
|
||||||
with timer("Encoding query took", logger=logger):
|
text_search_models: List[TextSearchModel] = [
|
||||||
encoded_asymmetric_query = util.normalize_embeddings(
|
model
|
||||||
state.model.org_search.bi_encoder.encode(
|
for model_name, model in state.model.__dict__.items()
|
||||||
[defiltered_query],
|
if isinstance(model, TextSearchModel) and model_name != "ledger_search"
|
||||||
convert_to_tensor=True,
|
]
|
||||||
device=state.device,
|
if text_search_models:
|
||||||
|
with timer("Encoding query took", logger=logger):
|
||||||
|
encoded_asymmetric_query = util.normalize_embeddings(
|
||||||
|
text_search_models[0].bi_encoder.encode(
|
||||||
|
[defiltered_query],
|
||||||
|
convert_to_tensor=True,
|
||||||
|
device=state.device,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||||
if (t == SearchType.Org or t == SearchType.All) and state.model.org_search:
|
if (t == SearchType.Org or t == SearchType.All) and state.model.org_search:
|
||||||
|
|
|
@ -143,7 +143,7 @@ def extract_metadata(image_name):
|
||||||
return image_processed_metadata
|
return image_processed_metadata
|
||||||
|
|
||||||
|
|
||||||
def query(raw_query, count, model: ImageSearchModel, score_threshold: float = -math.inf):
|
async def query(raw_query, count, model: ImageSearchModel, score_threshold: float = -math.inf):
|
||||||
# Set query to image content if query is of form file:/path/to/file.png
|
# Set query to image content if query is of form file:/path/to/file.png
|
||||||
if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file():
|
if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file():
|
||||||
query_imagepath = resolve_absolute_path(pathlib.Path(raw_query[5:]), strict=True)
|
query_imagepath = resolve_absolute_path(pathlib.Path(raw_query[5:]), strict=True)
|
||||||
|
|
Loading…
Reference in a new issue