Improve handling embedding model config from admin interface

- Allow the server to start even if loading the embedding model fails with an error.
  This allows fixing the embedding model config via the admin panel.

  Previously server failed to start if embedding model was configured
  incorrectly. This prevented fixing the model config via admin panel.

- Convert boolean strings in the config JSON to actual booleans when passed
  via the admin panel, before passing them to the model and query configs

- Only create a default model if no search model has been configured by the admin.
  Return the first created search model if it's been configured by the admin.
This commit is contained in:
Debanjum 2024-10-29 13:15:47 -07:00
parent 358a6ce95d
commit d44e68ba01
4 changed files with 17 additions and 7 deletions

View file

@ -262,7 +262,7 @@ def configure_server(
initialize_content(regenerate, search_type, user)
except Exception as e:
raise e
logger.error(f"Failed to load some search models: {e}", exc_info=True)
def setup_default_agent(user: KhojUser):

View file

@ -476,10 +476,11 @@ def get_default_search_model() -> SearchModelConfig:
if default_search_model:
return default_search_model
elif SearchModelConfig.objects.count() > 0:
return SearchModelConfig.objects.first()
else:
SearchModelConfig.objects.create()
return SearchModelConfig.objects.first()
return SearchModelConfig.objects.filter(name="default").first()
def get_user_default_search_model(user: KhojUser = None) -> SearchModelConfig:

View file

@ -13,7 +13,7 @@ from tenacity import (
)
from torch import nn
from khoj.utils.helpers import get_device, merge_dicts, timer
from khoj.utils.helpers import fix_json_dict, get_device, merge_dicts, timer
from khoj.utils.rawconfig import SearchResponse
logger = logging.getLogger(__name__)
@ -31,9 +31,9 @@ class EmbeddingsModel:
):
default_query_encode_kwargs = {"show_progress_bar": False, "normalize_embeddings": True}
default_docs_encode_kwargs = {"show_progress_bar": True, "normalize_embeddings": True}
self.query_encode_kwargs = merge_dicts(query_encode_kwargs, default_query_encode_kwargs)
self.docs_encode_kwargs = merge_dicts(docs_encode_kwargs, default_docs_encode_kwargs)
self.model_kwargs = merge_dicts(model_kwargs, {"device": get_device()})
self.query_encode_kwargs = merge_dicts(fix_json_dict(query_encode_kwargs), default_query_encode_kwargs)
self.docs_encode_kwargs = merge_dicts(fix_json_dict(docs_encode_kwargs), default_docs_encode_kwargs)
self.model_kwargs = merge_dicts(fix_json_dict(model_kwargs), {"device": get_device()})
self.model_name = model_name
self.inference_endpoint = embeddings_inference_endpoint
self.api_key = embeddings_inference_endpoint_api_key

View file

@ -101,6 +101,15 @@ def merge_dicts(priority_dict: dict, default_dict: dict):
return merged_dict
def fix_json_dict(json_dict: dict) -> dict:
    """Convert stringified booleans in a JSON-derived config dict to real booleans.

    Config passed via the admin panel arrives as JSON, where boolean values may
    have been serialized as the strings "True"/"False". Replace those with actual
    bools, recursing into nested dicts. All other values are left untouched.

    NOTE: mutates ``json_dict`` in place and returns it, so callers may use
    either the argument or the return value.
    """
    # Guard against an absent config — callers wrap optional kwargs dicts, which
    # presumably may be None (TODO confirm against EmbeddingsModel callers).
    if not json_dict:
        return json_dict if json_dict is not None else {}
    for key, value in json_dict.items():
        if value in ("True", "False"):
            # Stringified Python booleans -> real booleans
            json_dict[key] = value == "True"
        elif isinstance(value, dict):
            # A string can never be a dict, so elif is equivalent to the
            # original pair of independent ifs; recurse into nested configs.
            json_dict[key] = fix_json_dict(value)
    return json_dict
def get_file_type(file_type: str, file_content: bytes) -> tuple[str, str]:
"Get file type from file mime type"