Mirror of https://github.com/khoj-ai/khoj.git, synced 2024-11-30 19:03:01 +01:00
Improve null and type checks

parent d5fb4196de
commit 2cd3e799d3
3 changed files with 26 additions and 21 deletions
@@ -3,6 +3,7 @@ import sys
 import logging
 import json
 from enum import Enum
+from typing import Optional
 import requests

 # External Packages
@@ -78,16 +79,20 @@ def configure_search_types(config: FullConfig):
     core_search_types = {e.name: e.value for e in SearchType}
     # Extract configured plugin search types
     plugin_search_types = {}
-    if config.content_type.plugins:
+    if config.content_type and config.content_type.plugins:
         plugin_search_types = {plugin_type: plugin_type for plugin_type in config.content_type.plugins.keys()}

     # Dynamically generate search type enum by merging core search types with configured plugin search types
     return Enum("SearchType", merge_dicts(core_search_types, plugin_search_types))


-def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, t: state.SearchType = None):
+def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, t: Optional[state.SearchType] = None):
+    if config.content_type is None or config.search_type is None:
+        logger.error("🚨 Content Type or Search Type not configured.")
+        return
+
     # Initialize Org Notes Search
-    if (t == state.SearchType.Org or t == None) and config.content_type.org:
+    if (t == state.SearchType.Org or t == None) and config.content_type.org and config.search_type.asymmetric:
         logger.info("🦄 Setting up search for orgmode notes")
         # Extract Entries, Generate Notes Embeddings
         model.org_search = text_search.setup(
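Note on the hunk above: configure_search_types builds the SearchType enum at runtime by merging core and plugin type names and handing the result to the functional Enum API, which is why the null check on config.content_type has to come before .plugins is touched. A minimal sketch of that enum pattern, with merge_dicts as a local stand-in rather than the project's utility:

    from enum import Enum

    def merge_dicts(primary: dict, secondary: dict) -> dict:
        # Stand-in for the project's merge_dicts helper: keep primary's values on clashes.
        merged = dict(primary)
        for key, value in secondary.items():
            merged.setdefault(key, value)
        return merged

    core_search_types = {"Org": "org", "Markdown": "markdown"}
    plugin_search_types = {"recipes": "recipes"}  # plugin types map name -> name
    SearchType = Enum("SearchType", merge_dicts(core_search_types, plugin_search_types))

    print(SearchType.Org.value)      # org
    print(SearchType.recipes.value)  # recipes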
@@ -99,7 +104,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
         )

     # Initialize Org Music Search
-    if (t == state.SearchType.Music or t == None) and config.content_type.music:
+    if (t == state.SearchType.Music or t == None) and config.content_type.music and config.search_type.asymmetric:
         logger.info("🎺 Setting up search for org-music")
         # Extract Entries, Generate Music Embeddings
         model.music_search = text_search.setup(

@@ -111,7 +116,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
         )

     # Initialize Markdown Search
-    if (t == state.SearchType.Markdown or t == None) and config.content_type.markdown:
+    if (t == state.SearchType.Markdown or t == None) and config.content_type.markdown and config.search_type.asymmetric:
         logger.info("💎 Setting up search for markdown notes")
         # Extract Entries, Generate Markdown Embeddings
         model.markdown_search = text_search.setup(

@@ -123,7 +128,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
         )

     # Initialize Ledger Search
-    if (t == state.SearchType.Ledger or t == None) and config.content_type.ledger:
+    if (t == state.SearchType.Ledger or t == None) and config.content_type.ledger and config.search_type.symmetric:
         logger.info("💸 Setting up search for ledger")
         # Extract Entries, Generate Ledger Embeddings
         model.ledger_search = text_search.setup(

@@ -135,7 +140,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
         )

     # Initialize PDF Search
-    if (t == state.SearchType.Pdf or t == None) and config.content_type.pdf:
+    if (t == state.SearchType.Pdf or t == None) and config.content_type.pdf and config.search_type.asymmetric:
         logger.info("🖨️ Setting up search for pdf")
         # Extract Entries, Generate PDF Embeddings
         model.pdf_search = text_search.setup(

@@ -147,14 +152,14 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
         )

     # Initialize Image Search
-    if (t == state.SearchType.Image or t == None) and config.content_type.image:
+    if (t == state.SearchType.Image or t == None) and config.content_type.image and config.search_type.image:
         logger.info("🌄 Setting up search for images")
         # Extract Entries, Generate Image Embeddings
         model.image_search = image_search.setup(
             config.content_type.image, search_config=config.search_type.image, regenerate=regenerate
         )

-    if (t == state.SearchType.Github or t == None) and config.content_type.github:
+    if (t == state.SearchType.Github or t == None) and config.content_type.github and config.search_type.asymmetric:
         logger.info("🐙 Setting up search for github")
         # Extract Entries, Generate Github Embeddings
         model.github_search = text_search.setup(
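The hunks above all apply the same guard: a searcher is only initialized when its content config and the matching search model config are both present. Pulled out as a sketch (these names are illustrative, not the project's API):

    from typing import Any, Optional

    def should_initialize(
        requested_type: Optional[str],
        this_type: str,
        content_config: Optional[Any],
        search_config: Optional[Any],
    ) -> bool:
        # Initialize when no specific type was requested (or this one was),
        # and both the content config and the search model config exist.
        type_requested = requested_type is None or requested_type == this_type
        return type_requested and content_config is not None and search_config is not None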
@@ -134,7 +134,7 @@ async def search(
     user_query = q.strip()
     results_count = n
     score_threshold = score_threshold if score_threshold is not None else -math.inf
-    search_futures = defaultdict(list)
+    search_futures: list[concurrent.futures.Future] = []

     # return cached results, if available
     query_cache_key = f"{user_query}-{n}-{t}-{r}-{score_threshold}-{dedupe}"
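Replacing the per-type defaultdict with one flat, typed list keeps every submitted query visible to concurrent.futures.as_completed later in the endpoint, including when t is None and several content types are searched at once. A minimal sketch of that submit-and-collect pattern (fake_query stands in for the real query functions):

    import concurrent.futures

    def fake_query(source: str, query: str) -> str:
        # Stand-in for text_search.query / image_search.query.
        return f"{source}: results for {query!r}"

    search_futures: list[concurrent.futures.Future] = []
    with concurrent.futures.ThreadPoolExecutor() as executor:
        for source in ["org", "markdown", "pdf"]:
            search_futures += [executor.submit(fake_query, source, "trip notes")]
        for future in concurrent.futures.as_completed(search_futures):
            print(future.result())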
@@ -161,7 +161,7 @@ async def search(
     with concurrent.futures.ThreadPoolExecutor() as executor:
         if (t == SearchType.Org or t == None) and state.model.org_search:
             # query org-mode notes
-            search_futures[t] += [
+            search_futures += [
                 executor.submit(
                     text_search.query,
                     user_query,

@@ -175,7 +175,7 @@ async def search(

         if (t == SearchType.Markdown or t == None) and state.model.markdown_search:
             # query markdown notes
-            search_futures[t] += [
+            search_futures += [
                 executor.submit(
                     text_search.query,
                     user_query,

@@ -189,7 +189,7 @@ async def search(

         if (t == SearchType.Pdf or t == None) and state.model.pdf_search:
             # query pdf files
-            search_futures[t] += [
+            search_futures += [
                 executor.submit(
                     text_search.query,
                     user_query,

@@ -203,7 +203,7 @@ async def search(

         if (t == SearchType.Ledger) and state.model.ledger_search:
             # query transactions
-            search_futures[t] += [
+            search_futures += [
                 executor.submit(
                     text_search.query,
                     user_query,

@@ -216,7 +216,7 @@ async def search(

         if (t == SearchType.Music or t == None) and state.model.music_search:
             # query music library
-            search_futures[t] += [
+            search_futures += [
                 executor.submit(
                     text_search.query,
                     user_query,

@@ -230,7 +230,7 @@ async def search(

         if (t == SearchType.Image) and state.model.image_search:
             # query images
-            search_futures[t] += [
+            search_futures += [
                 executor.submit(
                     image_search.query,
                     user_query,

@@ -242,7 +242,7 @@ async def search(

         if (t is None or t in SearchType) and state.model.plugin_search:
             # query specified plugin type
-            search_futures[t] += [
+            search_futures += [
                 executor.submit(
                     text_search.query,
                     user_query,

@@ -257,7 +257,7 @@ async def search(

         # Query across each requested content types in parallel
         with timer("Query took", logger):
-            for search_future in concurrent.futures.as_completed(search_futures[t]):
+            for search_future in concurrent.futures.as_completed(search_futures):
                 if t == SearchType.Image:
                     hits = await search_future.result()
                     output_directory = constants.web_directory / "images"
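In the loop above, image hits are fetched with await search_future.result(): when an async function is submitted to a ThreadPoolExecutor, the worker thread only creates the coroutine, so the future's result still has to be awaited on the event loop. A small sketch of that interaction, with async_image_query as a hypothetical stand-in:

    import asyncio
    import concurrent.futures

    async def async_image_query(query: str) -> str:
        # Stand-in for an async image search function.
        await asyncio.sleep(0)
        return f"image hits for {query!r}"

    async def main() -> None:
        with concurrent.futures.ThreadPoolExecutor() as executor:
            future = executor.submit(async_image_query, "sunset over a lake")
            # future.result() is a coroutine object, not the hits themselves.
            print(await future.result())

    asyncio.run(main())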
@@ -288,7 +288,7 @@ async def search(
     state.previous_query = user_query

     end_time = time.time()
-    logger.debug(f"🔍 Search took: {end_time - start_time:.2f} seconds")
+    logger.debug(f"🔍 Search took: {end_time - start_time:.3f} seconds")

     return results

@@ -297,7 +297,7 @@ async def search(
 def update(t: Optional[SearchType] = None, force: Optional[bool] = False, client: Optional[str] = None):
     try:
         state.search_index_lock.acquire()
-        state.model = configure_search(state.model, state.config, regenerate=force, t=t)
+        state.model = configure_search(state.model, state.config, regenerate=force or False, t=t)
         state.search_index_lock.release()
     except ValueError as e:
         logger.error(e)
@@ -181,7 +181,7 @@ def setup(
     previous_entries = (
         extract_entries(config.compressed_jsonl) if config.compressed_jsonl.exists() and not regenerate else None
     )
-    entries_with_indices = text_to_jsonl(config).process(previous_entries)
+    entries_with_indices = text_to_jsonl(config).process(previous_entries or [])

     # Extract Updated Entries
     entries = extract_entries(config.compressed_jsonl)
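The or [] default above is the same None-coalescing idiom as the force or False change earlier in this commit: on a first run, when no compressed_jsonl exists yet, previous_entries is None, and coercing it to an empty list lets the processor iterate without its own null check. For example:

    def process(previous_entries: list) -> int:
        # Stand-in for text_to_jsonl(config).process(...).
        return len(previous_entries)

    previous_entries = None                 # first run: nothing indexed yet
    print(process(previous_entries or []))  # 0, instead of a TypeError on None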