From d44e68ba01346c90f518bce2996c1903f18ecabd Mon Sep 17 00:00:00 2001 From: Debanjum Date: Tue, 29 Oct 2024 13:15:47 -0700 Subject: [PATCH] Improve handling embedding model config from admin interface - Allow server to start if loading embedding model fails with an error. This allows fixing the embedding model config via admin panel. Previously server failed to start if embedding model was configured incorrectly. This prevented fixing the model config via admin panel. - Convert boolean string in config json to actual booleans when passed via admin panel as json before passing to model, query configs - Only create default model if no search model configured by admin. Return first created search model if its been configured by admin. --- src/khoj/configure.py | 2 +- src/khoj/database/adapters/__init__.py | 5 +++-- src/khoj/processor/embeddings.py | 8 ++++---- src/khoj/utils/helpers.py | 9 +++++++++ 4 files changed, 17 insertions(+), 7 deletions(-) diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 1454b164..df0760fb 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -262,7 +262,7 @@ def configure_server( initialize_content(regenerate, search_type, user) except Exception as e: - raise e + logger.error(f"Failed to load some search models: {e}", exc_info=True) def setup_default_agent(user: KhojUser): diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 5db857ee..f102648b 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -476,10 +476,11 @@ def get_default_search_model() -> SearchModelConfig: if default_search_model: return default_search_model + elif SearchModelConfig.objects.count() > 0: + return SearchModelConfig.objects.first() else: SearchModelConfig.objects.create() - - return SearchModelConfig.objects.first() + return SearchModelConfig.objects.filter(name="default").first() def get_user_default_search_model(user: KhojUser = None) -> SearchModelConfig: diff --git a/src/khoj/processor/embeddings.py b/src/khoj/processor/embeddings.py index a19d85fa..b224e7f5 100644 --- a/src/khoj/processor/embeddings.py +++ b/src/khoj/processor/embeddings.py @@ -13,7 +13,7 @@ from tenacity import ( ) from torch import nn -from khoj.utils.helpers import get_device, merge_dicts, timer +from khoj.utils.helpers import fix_json_dict, get_device, merge_dicts, timer from khoj.utils.rawconfig import SearchResponse logger = logging.getLogger(__name__) @@ -31,9 +31,9 @@ class EmbeddingsModel: ): default_query_encode_kwargs = {"show_progress_bar": False, "normalize_embeddings": True} default_docs_encode_kwargs = {"show_progress_bar": True, "normalize_embeddings": True} - self.query_encode_kwargs = merge_dicts(query_encode_kwargs, default_query_encode_kwargs) - self.docs_encode_kwargs = merge_dicts(docs_encode_kwargs, default_docs_encode_kwargs) - self.model_kwargs = merge_dicts(model_kwargs, {"device": get_device()}) + self.query_encode_kwargs = merge_dicts(fix_json_dict(query_encode_kwargs), default_query_encode_kwargs) + self.docs_encode_kwargs = merge_dicts(fix_json_dict(docs_encode_kwargs), default_docs_encode_kwargs) + self.model_kwargs = merge_dicts(fix_json_dict(model_kwargs), {"device": get_device()}) self.model_name = model_name self.inference_endpoint = embeddings_inference_endpoint self.api_key = embeddings_inference_endpoint_api_key diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index 664bbde9..6bfb3594 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -101,6 +101,15 @@ def merge_dicts(priority_dict: dict, default_dict: dict): return merged_dict +def fix_json_dict(json_dict: dict) -> dict: + for k, v in json_dict.items(): + if v == "True" or v == "False": + json_dict[k] = v == "True" + if isinstance(v, dict): + json_dict[k] = fix_json_dict(v) + return json_dict + + def get_file_type(file_type: str, file_content: bytes) -> tuple[str, str]: "Get file type from file mime type"