diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 1454b164..df0760fb 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -262,7 +262,7 @@ def configure_server( initialize_content(regenerate, search_type, user) except Exception as e: - raise e + logger.error(f"Failed to load some search models: {e}", exc_info=True) def setup_default_agent(user: KhojUser): diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 5db857ee..f102648b 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -476,10 +476,11 @@ def get_default_search_model() -> SearchModelConfig: if default_search_model: return default_search_model + elif SearchModelConfig.objects.count() > 0: + return SearchModelConfig.objects.first() else: SearchModelConfig.objects.create() - - return SearchModelConfig.objects.first() + return SearchModelConfig.objects.filter(name="default").first() def get_user_default_search_model(user: KhojUser = None) -> SearchModelConfig: diff --git a/src/khoj/processor/embeddings.py b/src/khoj/processor/embeddings.py index a19d85fa..b224e7f5 100644 --- a/src/khoj/processor/embeddings.py +++ b/src/khoj/processor/embeddings.py @@ -13,7 +13,7 @@ from tenacity import ( ) from torch import nn -from khoj.utils.helpers import get_device, merge_dicts, timer +from khoj.utils.helpers import fix_json_dict, get_device, merge_dicts, timer from khoj.utils.rawconfig import SearchResponse logger = logging.getLogger(__name__) @@ -31,9 +31,9 @@ class EmbeddingsModel: ): default_query_encode_kwargs = {"show_progress_bar": False, "normalize_embeddings": True} default_docs_encode_kwargs = {"show_progress_bar": True, "normalize_embeddings": True} - self.query_encode_kwargs = merge_dicts(query_encode_kwargs, default_query_encode_kwargs) - self.docs_encode_kwargs = merge_dicts(docs_encode_kwargs, default_docs_encode_kwargs) - self.model_kwargs = merge_dicts(model_kwargs, {"device": get_device()}) + self.query_encode_kwargs = merge_dicts(fix_json_dict(query_encode_kwargs), default_query_encode_kwargs) + self.docs_encode_kwargs = merge_dicts(fix_json_dict(docs_encode_kwargs), default_docs_encode_kwargs) + self.model_kwargs = merge_dicts(fix_json_dict(model_kwargs), {"device": get_device()}) self.model_name = model_name self.inference_endpoint = embeddings_inference_endpoint self.api_key = embeddings_inference_endpoint_api_key diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index 664bbde9..6bfb3594 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -101,6 +101,15 @@ def merge_dicts(priority_dict: dict, default_dict: dict): return merged_dict +def fix_json_dict(json_dict: dict) -> dict: + for k, v in json_dict.items(): + if v == "True" or v == "False": + json_dict[k] = v == "True" + if isinstance(v, dict): + json_dict[k] = fix_json_dict(v) + return json_dict + + def get_file_type(file_type: str, file_content: bytes) -> tuple[str, str]: "Get file type from file mime type"