mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-30 19:03:01 +01:00
Improve null and type checks
This commit is contained in:
parent
d5fb4196de
commit
2cd3e799d3
3 changed files with 26 additions and 21 deletions
|
@ -3,6 +3,7 @@ import sys
|
|||
import logging
|
||||
import json
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
import requests
|
||||
|
||||
# External Packages
|
||||
|
@ -78,16 +79,20 @@ def configure_search_types(config: FullConfig):
|
|||
core_search_types = {e.name: e.value for e in SearchType}
|
||||
# Extract configured plugin search types
|
||||
plugin_search_types = {}
|
||||
if config.content_type.plugins:
|
||||
if config.content_type and config.content_type.plugins:
|
||||
plugin_search_types = {plugin_type: plugin_type for plugin_type in config.content_type.plugins.keys()}
|
||||
|
||||
# Dynamically generate search type enum by merging core search types with configured plugin search types
|
||||
return Enum("SearchType", merge_dicts(core_search_types, plugin_search_types))
|
||||
|
||||
|
||||
def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, t: state.SearchType = None):
|
||||
def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, t: Optional[state.SearchType] = None):
|
||||
if config.content_type is None or config.search_type is None:
|
||||
logger.error("🚨 Content Type or Search Type not configured.")
|
||||
return
|
||||
|
||||
# Initialize Org Notes Search
|
||||
if (t == state.SearchType.Org or t == None) and config.content_type.org:
|
||||
if (t == state.SearchType.Org or t == None) and config.content_type.org and config.search_type.asymmetric:
|
||||
logger.info("🦄 Setting up search for orgmode notes")
|
||||
# Extract Entries, Generate Notes Embeddings
|
||||
model.org_search = text_search.setup(
|
||||
|
@ -99,7 +104,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
|
|||
)
|
||||
|
||||
# Initialize Org Music Search
|
||||
if (t == state.SearchType.Music or t == None) and config.content_type.music:
|
||||
if (t == state.SearchType.Music or t == None) and config.content_type.music and config.search_type.asymmetric:
|
||||
logger.info("🎺 Setting up search for org-music")
|
||||
# Extract Entries, Generate Music Embeddings
|
||||
model.music_search = text_search.setup(
|
||||
|
@ -111,7 +116,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
|
|||
)
|
||||
|
||||
# Initialize Markdown Search
|
||||
if (t == state.SearchType.Markdown or t == None) and config.content_type.markdown:
|
||||
if (t == state.SearchType.Markdown or t == None) and config.content_type.markdown and config.search_type.asymmetric:
|
||||
logger.info("💎 Setting up search for markdown notes")
|
||||
# Extract Entries, Generate Markdown Embeddings
|
||||
model.markdown_search = text_search.setup(
|
||||
|
@ -123,7 +128,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
|
|||
)
|
||||
|
||||
# Initialize Ledger Search
|
||||
if (t == state.SearchType.Ledger or t == None) and config.content_type.ledger:
|
||||
if (t == state.SearchType.Ledger or t == None) and config.content_type.ledger and config.search_type.symmetric:
|
||||
logger.info("💸 Setting up search for ledger")
|
||||
# Extract Entries, Generate Ledger Embeddings
|
||||
model.ledger_search = text_search.setup(
|
||||
|
@ -135,7 +140,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
|
|||
)
|
||||
|
||||
# Initialize PDF Search
|
||||
if (t == state.SearchType.Pdf or t == None) and config.content_type.pdf:
|
||||
if (t == state.SearchType.Pdf or t == None) and config.content_type.pdf and config.search_type.asymmetric:
|
||||
logger.info("🖨️ Setting up search for pdf")
|
||||
# Extract Entries, Generate PDF Embeddings
|
||||
model.pdf_search = text_search.setup(
|
||||
|
@ -147,14 +152,14 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
|
|||
)
|
||||
|
||||
# Initialize Image Search
|
||||
if (t == state.SearchType.Image or t == None) and config.content_type.image:
|
||||
if (t == state.SearchType.Image or t == None) and config.content_type.image and config.search_type.image:
|
||||
logger.info("🌄 Setting up search for images")
|
||||
# Extract Entries, Generate Image Embeddings
|
||||
model.image_search = image_search.setup(
|
||||
config.content_type.image, search_config=config.search_type.image, regenerate=regenerate
|
||||
)
|
||||
|
||||
if (t == state.SearchType.Github or t == None) and config.content_type.github:
|
||||
if (t == state.SearchType.Github or t == None) and config.content_type.github and config.search_type.asymmetric:
|
||||
logger.info("🐙 Setting up search for github")
|
||||
# Extract Entries, Generate Github Embeddings
|
||||
model.github_search = text_search.setup(
|
||||
|
|
|
@ -134,7 +134,7 @@ async def search(
|
|||
user_query = q.strip()
|
||||
results_count = n
|
||||
score_threshold = score_threshold if score_threshold is not None else -math.inf
|
||||
search_futures = defaultdict(list)
|
||||
search_futures: list[concurrent.futures.Future] = []
|
||||
|
||||
# return cached results, if available
|
||||
query_cache_key = f"{user_query}-{n}-{t}-{r}-{score_threshold}-{dedupe}"
|
||||
|
@ -161,7 +161,7 @@ async def search(
|
|||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
if (t == SearchType.Org or t == None) and state.model.org_search:
|
||||
# query org-mode notes
|
||||
search_futures[t] += [
|
||||
search_futures += [
|
||||
executor.submit(
|
||||
text_search.query,
|
||||
user_query,
|
||||
|
@ -175,7 +175,7 @@ async def search(
|
|||
|
||||
if (t == SearchType.Markdown or t == None) and state.model.markdown_search:
|
||||
# query markdown notes
|
||||
search_futures[t] += [
|
||||
search_futures += [
|
||||
executor.submit(
|
||||
text_search.query,
|
||||
user_query,
|
||||
|
@ -189,7 +189,7 @@ async def search(
|
|||
|
||||
if (t == SearchType.Pdf or t == None) and state.model.pdf_search:
|
||||
# query pdf files
|
||||
search_futures[t] += [
|
||||
search_futures += [
|
||||
executor.submit(
|
||||
text_search.query,
|
||||
user_query,
|
||||
|
@ -203,7 +203,7 @@ async def search(
|
|||
|
||||
if (t == SearchType.Ledger) and state.model.ledger_search:
|
||||
# query transactions
|
||||
search_futures[t] += [
|
||||
search_futures += [
|
||||
executor.submit(
|
||||
text_search.query,
|
||||
user_query,
|
||||
|
@ -216,7 +216,7 @@ async def search(
|
|||
|
||||
if (t == SearchType.Music or t == None) and state.model.music_search:
|
||||
# query music library
|
||||
search_futures[t] += [
|
||||
search_futures += [
|
||||
executor.submit(
|
||||
text_search.query,
|
||||
user_query,
|
||||
|
@ -230,7 +230,7 @@ async def search(
|
|||
|
||||
if (t == SearchType.Image) and state.model.image_search:
|
||||
# query images
|
||||
search_futures[t] += [
|
||||
search_futures += [
|
||||
executor.submit(
|
||||
image_search.query,
|
||||
user_query,
|
||||
|
@ -242,7 +242,7 @@ async def search(
|
|||
|
||||
if (t is None or t in SearchType) and state.model.plugin_search:
|
||||
# query specified plugin type
|
||||
search_futures[t] += [
|
||||
search_futures += [
|
||||
executor.submit(
|
||||
text_search.query,
|
||||
user_query,
|
||||
|
@ -257,7 +257,7 @@ async def search(
|
|||
|
||||
# Query each requested content type in parallel
|
||||
with timer("Query took", logger):
|
||||
for search_future in concurrent.futures.as_completed(search_futures[t]):
|
||||
for search_future in concurrent.futures.as_completed(search_futures):
|
||||
if t == SearchType.Image:
|
||||
hits = await search_future.result()
|
||||
output_directory = constants.web_directory / "images"
|
||||
|
@ -288,7 +288,7 @@ async def search(
|
|||
state.previous_query = user_query
|
||||
|
||||
end_time = time.time()
|
||||
logger.debug(f"🔍 Search took: {end_time - start_time:.2f} seconds")
|
||||
logger.debug(f"🔍 Search took: {end_time - start_time:.3f} seconds")
|
||||
|
||||
return results
|
||||
|
||||
|
@ -297,7 +297,7 @@ async def search(
|
|||
def update(t: Optional[SearchType] = None, force: Optional[bool] = False, client: Optional[str] = None):
|
||||
try:
|
||||
state.search_index_lock.acquire()
|
||||
state.model = configure_search(state.model, state.config, regenerate=force, t=t)
|
||||
state.model = configure_search(state.model, state.config, regenerate=force or False, t=t)
|
||||
state.search_index_lock.release()
|
||||
except ValueError as e:
|
||||
logger.error(e)
|
||||
|
|
|
@ -181,7 +181,7 @@ def setup(
|
|||
previous_entries = (
|
||||
extract_entries(config.compressed_jsonl) if config.compressed_jsonl.exists() and not regenerate else None
|
||||
)
|
||||
entries_with_indices = text_to_jsonl(config).process(previous_entries)
|
||||
entries_with_indices = text_to_jsonl(config).process(previous_entries or [])
|
||||
|
||||
# Extract Updated Entries
|
||||
entries = extract_entries(config.compressed_jsonl)
|
||||
|
|
Loading…
Reference in a new issue