diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 33fb8250..e42337df 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -2,6 +2,7 @@ import sys import logging import json +from enum import Enum # External Packages import schedule @@ -14,7 +15,7 @@ from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl from khoj.search_type import image_search, text_search from khoj.utils.config import SearchType, SearchModels, ProcessorConfigModel, ConversationProcessorConfigModel from khoj.utils import state -from khoj.utils.helpers import LRU, resolve_absolute_path +from khoj.utils.helpers import LRU, resolve_absolute_path, merge_dicts from khoj.utils.rawconfig import FullConfig, ProcessorConfig from khoj.search_filter.date_filter import DateFilter from khoj.search_filter.word_filter import WordFilter @@ -40,8 +41,9 @@ def configure_server(args, required=False): # Initialize Processor from Config state.processor_config = configure_processor(args.config.processor) - # Initialize the search model from Config + # Initialize the search type and model from Config state.search_index_lock.acquire() + state.SearchType = configure_search_types(state.config) state.model = configure_search(state.model, state.config, args.regenerate) state.search_index_lock.release() @@ -54,9 +56,19 @@ def update_search_index(): logger.info("Search Index updated via Scheduler") -def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, t: SearchType = None): +def configure_search_types(config: FullConfig): + # Extract core search types + core_search_types = {e.name: e.value for e in SearchType} + # Extract configured plugin search types + plugin_search_types = {plugin_type: plugin_type for plugin_type in config.content_type.plugins.keys()} + + # Dynamically generate search type enum by merging core search types with configured plugin search types + return Enum("SearchType", merge_dicts(core_search_types, plugin_search_types)) + + +def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, t: state.SearchType = None): # Initialize Org Notes Search - if (t == SearchType.Org or t == None) and config.content_type.org: + if (t == state.SearchType.Org or t == None) and config.content_type.org: # Extract Entries, Generate Notes Embeddings model.orgmode_search = text_search.setup( OrgToJsonl, @@ -67,7 +79,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, ) # Initialize Org Music Search - if (t == SearchType.Music or t == None) and config.content_type.music: + if (t == state.SearchType.Music or t == None) and config.content_type.music: # Extract Entries, Generate Music Embeddings model.music_search = text_search.setup( OrgToJsonl, @@ -78,7 +90,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, ) # Initialize Markdown Search - if (t == SearchType.Markdown or t == None) and config.content_type.markdown: + if (t == state.SearchType.Markdown or t == None) and config.content_type.markdown: # Extract Entries, Generate Markdown Embeddings model.markdown_search = text_search.setup( MarkdownToJsonl, @@ -89,7 +101,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, ) # Initialize Ledger Search - if (t == SearchType.Ledger or t == None) and config.content_type.ledger: + if (t == state.SearchType.Ledger or t == None) and config.content_type.ledger: # Extract Entries, Generate Ledger Embeddings model.ledger_search = text_search.setup( BeancountToJsonl, @@ -100,7 +112,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, ) # Initialize Image Search - if (t == SearchType.Image or t == None) and config.content_type.image: + if (t == state.SearchType.Image or t == None) and config.content_type.image: # Extract Entries, Generate Image Embeddings model.image_search = image_search.setup( config.content_type.image, search_config=config.search_type.image, regenerate=regenerate diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 3c946d9f..2666648b 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -12,10 +12,9 @@ from khoj.configure import configure_processor, configure_search from khoj.search_type import image_search, text_search from khoj.utils.helpers import timer from khoj.utils.rawconfig import FullConfig, SearchResponse -from khoj.utils.config import SearchType +from khoj.utils.state import SearchType from khoj.utils import state, constants - # Initialize Router api = APIRouter() logger = logging.getLogger(__name__) diff --git a/src/khoj/routers/api_beta.py b/src/khoj/routers/api_beta.py index e05b53dd..eefd09fa 100644 --- a/src/khoj/routers/api_beta.py +++ b/src/khoj/routers/api_beta.py @@ -17,7 +17,7 @@ from khoj.processor.conversation.gpt import ( understand, summarize, ) -from khoj.utils.config import SearchType +from khoj.utils.state import SearchType from khoj.utils.helpers import get_from_dict, resolve_absolute_path from khoj.utils import state diff --git a/src/khoj/utils/state.py b/src/khoj/utils/state.py index 7e6abc1e..7a38bfdf 100644 --- a/src/khoj/utils/state.py +++ b/src/khoj/utils/state.py @@ -8,6 +8,7 @@ import torch from pathlib import Path # Internal Packages +from khoj.utils import config as utils_config from khoj.utils.config import SearchModels, ProcessorConfigModel from khoj.utils.helpers import LRU from khoj.utils.rawconfig import FullConfig @@ -23,6 +24,7 @@ port: int = None cli_args: List[str] = None query_cache = LRU() search_index_lock = threading.Lock() +SearchType = utils_config.SearchType if torch.cuda.is_available(): # Use CUDA GPU