mirror of
https://github.com/khoj-ai/khoj.git
synced 2025-02-17 08:04:21 +00:00
Pass default value matching argument types expected by text_search methods
This commit is contained in:
parent
0144e610d6
commit
1192e49307
2 changed files with 22 additions and 17 deletions
|
@ -24,6 +24,7 @@ from khoj.search_filter.word_filter import WordFilter
|
|||
from khoj.utils.helpers import log_telemetry, timer
|
||||
from khoj.utils.rawconfig import (
|
||||
FullConfig,
|
||||
ProcessorConfig,
|
||||
SearchResponse,
|
||||
TextContentConfig,
|
||||
ConversationProcessorConfig,
|
||||
|
@ -101,7 +102,10 @@ async def set_content_config_data(content_type: str, updated_config: TextContent
|
|||
|
||||
@api.post("/config/data/processor/conversation", status_code=200)
|
||||
async def set_processor_conversation_config_data(updated_config: ConversationProcessorConfig):
|
||||
state.config.processor.conversation = updated_config
|
||||
if state.config.processor is None:
|
||||
state.config.processor = ProcessorConfig(conversation=updated_config)
|
||||
else:
|
||||
state.config.processor.conversation = updated_config
|
||||
try:
|
||||
save_config_to_file_updated_state()
|
||||
return {"status": "ok"}
|
||||
|
@ -139,6 +143,7 @@ def search(
|
|||
return state.query_cache[query_cache_key]
|
||||
|
||||
# Encode query with filter terms removed
|
||||
defiltered_query = user_query
|
||||
for filter in [DateFilter(), WordFilter(), FileFilter()]:
|
||||
defiltered_query = filter.defilter(user_query)
|
||||
|
||||
|
@ -162,9 +167,9 @@ def search(
|
|||
user_query,
|
||||
state.model.org_search,
|
||||
question_embedding=encoded_asymmetric_query,
|
||||
rank_results=r,
|
||||
rank_results=r or False,
|
||||
score_threshold=score_threshold,
|
||||
dedupe=dedupe,
|
||||
dedupe=dedupe or True,
|
||||
)
|
||||
]
|
||||
|
||||
|
@ -176,9 +181,9 @@ def search(
|
|||
user_query,
|
||||
state.model.markdown_search,
|
||||
question_embedding=encoded_asymmetric_query,
|
||||
rank_results=r,
|
||||
rank_results=r or False,
|
||||
score_threshold=score_threshold,
|
||||
dedupe=dedupe,
|
||||
dedupe=dedupe or True,
|
||||
)
|
||||
]
|
||||
|
||||
|
@ -190,9 +195,9 @@ def search(
|
|||
user_query,
|
||||
state.model.pdf_search,
|
||||
question_embedding=encoded_asymmetric_query,
|
||||
rank_results=r,
|
||||
rank_results=r or False,
|
||||
score_threshold=score_threshold,
|
||||
dedupe=dedupe,
|
||||
dedupe=dedupe or True,
|
||||
)
|
||||
]
|
||||
|
||||
|
@ -203,9 +208,9 @@ def search(
|
|||
text_search.query,
|
||||
user_query,
|
||||
state.model.ledger_search,
|
||||
rank_results=r,
|
||||
rank_results=r or False,
|
||||
score_threshold=score_threshold,
|
||||
dedupe=dedupe,
|
||||
dedupe=dedupe or True,
|
||||
)
|
||||
]
|
||||
|
||||
|
@ -217,9 +222,9 @@ def search(
|
|||
user_query,
|
||||
state.model.music_search,
|
||||
question_embedding=encoded_asymmetric_query,
|
||||
rank_results=r,
|
||||
rank_results=r or False,
|
||||
score_threshold=score_threshold,
|
||||
dedupe=dedupe,
|
||||
dedupe=dedupe or True,
|
||||
)
|
||||
]
|
||||
|
||||
|
@ -237,16 +242,16 @@ def search(
|
|||
|
||||
if (t is None or t in SearchType) and state.model.plugin_search:
|
||||
# query specified plugin type
|
||||
search_future[t] += [
|
||||
search_futures[t] += [
|
||||
executor.submit(
|
||||
text_search.query,
|
||||
user_query,
|
||||
# Get plugin search model for specified search type, or the first one if none specified
|
||||
state.model.plugin_search.get(t.value) or next(iter(state.model.plugin_search.values())),
|
||||
question_embedding=encoded_asymmetric_query,
|
||||
rank_results=r,
|
||||
rank_results=r or False,
|
||||
score_threshold=score_threshold,
|
||||
dedupe=dedupe,
|
||||
dedupe=dedupe or True,
|
||||
)
|
||||
]
|
||||
|
||||
|
@ -262,12 +267,12 @@ def search(
|
|||
image_names=state.model.image_search.image_names,
|
||||
output_directory=output_directory,
|
||||
image_files_url="/static/images",
|
||||
count=results_count,
|
||||
count=results_count or 5,
|
||||
)
|
||||
else:
|
||||
hits, entries = search_future.result()
|
||||
# Collate results
|
||||
results += text_search.collate_results(hits, entries, results_count)
|
||||
results += text_search.collate_results(hits, entries, results_count or 5)
|
||||
|
||||
# Sort results across all content types
|
||||
results.sort(key=lambda x: float(x.score), reverse=True)
|
||||
|
|
|
@ -105,7 +105,7 @@ def compute_embeddings(
|
|||
def query(
|
||||
raw_query: str,
|
||||
model: TextSearchModel,
|
||||
question_embedding: torch.Tensor = None,
|
||||
question_embedding: torch.Tensor | None = None,
|
||||
rank_results: bool = False,
|
||||
score_threshold: float = -math.inf,
|
||||
dedupe: bool = True,
|
||||
|
|
Loading…
Add table
Reference in a new issue