Pass default value matching argument types expected by text_search methods

This commit is contained in:
Debanjum Singh Solanky 2023-06-20 19:51:33 -07:00
parent 0144e610d6
commit 1192e49307
2 changed files with 22 additions and 17 deletions

View file

@ -24,6 +24,7 @@ from khoj.search_filter.word_filter import WordFilter
from khoj.utils.helpers import log_telemetry, timer
from khoj.utils.rawconfig import (
FullConfig,
ProcessorConfig,
SearchResponse,
TextContentConfig,
ConversationProcessorConfig,
@ -101,7 +102,10 @@ async def set_content_config_data(content_type: str, updated_config: TextContent
@api.post("/config/data/processor/conversation", status_code=200)
async def set_processor_conversation_config_data(updated_config: ConversationProcessorConfig):
state.config.processor.conversation = updated_config
if state.config.processor is None:
state.config.processor = ProcessorConfig(conversation=updated_config)
else:
state.config.processor.conversation = updated_config
try:
save_config_to_file_updated_state()
return {"status": "ok"}
@ -139,6 +143,7 @@ def search(
return state.query_cache[query_cache_key]
# Encode query with filter terms removed
defiltered_query = user_query
for filter in [DateFilter(), WordFilter(), FileFilter()]:
defiltered_query = filter.defilter(user_query)
@ -162,9 +167,9 @@ def search(
user_query,
state.model.org_search,
question_embedding=encoded_asymmetric_query,
rank_results=r,
rank_results=r or False,
score_threshold=score_threshold,
dedupe=dedupe,
dedupe=dedupe or True,
)
]
@ -176,9 +181,9 @@ def search(
user_query,
state.model.markdown_search,
question_embedding=encoded_asymmetric_query,
rank_results=r,
rank_results=r or False,
score_threshold=score_threshold,
dedupe=dedupe,
dedupe=dedupe or True,
)
]
@ -190,9 +195,9 @@ def search(
user_query,
state.model.pdf_search,
question_embedding=encoded_asymmetric_query,
rank_results=r,
rank_results=r or False,
score_threshold=score_threshold,
dedupe=dedupe,
dedupe=dedupe or True,
)
]
@ -203,9 +208,9 @@ def search(
text_search.query,
user_query,
state.model.ledger_search,
rank_results=r,
rank_results=r or False,
score_threshold=score_threshold,
dedupe=dedupe,
dedupe=dedupe or True,
)
]
@ -217,9 +222,9 @@ def search(
user_query,
state.model.music_search,
question_embedding=encoded_asymmetric_query,
rank_results=r,
rank_results=r or False,
score_threshold=score_threshold,
dedupe=dedupe,
dedupe=dedupe or True,
)
]
@ -237,16 +242,16 @@ def search(
if (t is None or t in SearchType) and state.model.plugin_search:
# query specified plugin type
search_future[t] += [
search_futures[t] += [
executor.submit(
text_search.query,
user_query,
# Get plugin search model for specified search type, or the first one if none specified
state.model.plugin_search.get(t.value) or next(iter(state.model.plugin_search.values())),
question_embedding=encoded_asymmetric_query,
rank_results=r,
rank_results=r or False,
score_threshold=score_threshold,
dedupe=dedupe,
dedupe=dedupe or True,
)
]
@ -262,12 +267,12 @@ def search(
image_names=state.model.image_search.image_names,
output_directory=output_directory,
image_files_url="/static/images",
count=results_count,
count=results_count or 5,
)
else:
hits, entries = search_future.result()
# Collate results
results += text_search.collate_results(hits, entries, results_count)
results += text_search.collate_results(hits, entries, results_count or 5)
# Sort results across all content types
results.sort(key=lambda x: float(x.score), reverse=True)

View file

@ -105,7 +105,7 @@ def compute_embeddings(
def query(
raw_query: str,
model: TextSearchModel,
question_embedding: torch.Tensor = None,
question_embedding: torch.Tensor | None = None,
rank_results: bool = False,
score_threshold: float = -math.inf,
dedupe: bool = True,