Improve handling embedding model config from admin interface

- Allow the server to start even if loading the embedding model fails.
  This allows fixing the embedding model config via the admin panel.

  Previously server failed to start if embedding model was configured
  incorrectly. This prevented fixing the model config via admin panel.

- Convert boolean strings in the config JSON submitted via the admin
  panel into actual booleans before passing them to the model and query configs

- Only create default model if no search model configured by admin.
  Return the first created search model if it's been configured by an admin.
This commit is contained in:
Debanjum 2024-10-29 13:15:47 -07:00
parent 358a6ce95d
commit d44e68ba01
4 changed files with 17 additions and 7 deletions

View file

@ -262,7 +262,7 @@ def configure_server(
initialize_content(regenerate, search_type, user) initialize_content(regenerate, search_type, user)
except Exception as e: except Exception as e:
raise e logger.error(f"Failed to load some search models: {e}", exc_info=True)
def setup_default_agent(user: KhojUser): def setup_default_agent(user: KhojUser):

View file

@ -476,10 +476,11 @@ def get_default_search_model() -> SearchModelConfig:
if default_search_model: if default_search_model:
return default_search_model return default_search_model
elif SearchModelConfig.objects.count() > 0:
return SearchModelConfig.objects.first()
else: else:
SearchModelConfig.objects.create() SearchModelConfig.objects.create()
return SearchModelConfig.objects.filter(name="default").first()
return SearchModelConfig.objects.first()
def get_user_default_search_model(user: KhojUser = None) -> SearchModelConfig: def get_user_default_search_model(user: KhojUser = None) -> SearchModelConfig:

View file

@ -13,7 +13,7 @@ from tenacity import (
) )
from torch import nn from torch import nn
from khoj.utils.helpers import get_device, merge_dicts, timer from khoj.utils.helpers import fix_json_dict, get_device, merge_dicts, timer
from khoj.utils.rawconfig import SearchResponse from khoj.utils.rawconfig import SearchResponse
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -31,9 +31,9 @@ class EmbeddingsModel:
): ):
default_query_encode_kwargs = {"show_progress_bar": False, "normalize_embeddings": True} default_query_encode_kwargs = {"show_progress_bar": False, "normalize_embeddings": True}
default_docs_encode_kwargs = {"show_progress_bar": True, "normalize_embeddings": True} default_docs_encode_kwargs = {"show_progress_bar": True, "normalize_embeddings": True}
self.query_encode_kwargs = merge_dicts(query_encode_kwargs, default_query_encode_kwargs) self.query_encode_kwargs = merge_dicts(fix_json_dict(query_encode_kwargs), default_query_encode_kwargs)
self.docs_encode_kwargs = merge_dicts(docs_encode_kwargs, default_docs_encode_kwargs) self.docs_encode_kwargs = merge_dicts(fix_json_dict(docs_encode_kwargs), default_docs_encode_kwargs)
self.model_kwargs = merge_dicts(model_kwargs, {"device": get_device()}) self.model_kwargs = merge_dicts(fix_json_dict(model_kwargs), {"device": get_device()})
self.model_name = model_name self.model_name = model_name
self.inference_endpoint = embeddings_inference_endpoint self.inference_endpoint = embeddings_inference_endpoint
self.api_key = embeddings_inference_endpoint_api_key self.api_key = embeddings_inference_endpoint_api_key

View file

@ -101,6 +101,15 @@ def merge_dicts(priority_dict: dict, default_dict: dict):
return merged_dict return merged_dict
def fix_json_dict(json_dict: dict) -> dict:
    """Normalize a JSON-derived config dict in place and return it.

    The admin panel submits booleans as the literal strings "True"/"False";
    replace those with real booleans and recurse into nested dicts so the
    resulting kwargs can be passed straight to model/query configs.
    """
    for key in json_dict:
        value = json_dict[key]
        # Stringified booleans become real booleans.
        if value in ("True", "False"):
            json_dict[key] = value == "True"
        # Nested config sections get the same treatment recursively.
        if isinstance(value, dict):
            json_dict[key] = fix_json_dict(value)
    return json_dict
def get_file_type(file_type: str, file_content: bytes) -> tuple[str, str]: def get_file_type(file_type: str, file_content: bytes) -> tuple[str, str]:
"Get file type from file mime type" "Get file type from file mime type"