Improve handling embedding model config from admin interface
- Allow the server to start even if loading the embedding model fails with an error, so the embedding model config can be fixed via the admin panel. Previously the server failed to start when the embedding model was configured incorrectly, which made it impossible to fix the config from the admin panel.
- Convert boolean strings in config JSON passed via the admin panel into actual booleans before passing them to the model and query configs.
- Only create the default search model if no search model has been configured by the admin. If the admin has configured one, return the first configured search model instead.
parent 358a6ce95d
commit d44e68ba01

4 changed files with 17 additions and 7 deletions
@@ -262,7 +262,7 @@ def configure_server(
             initialize_content(regenerate, search_type, user)

     except Exception as e:
-        raise e
+        logger.error(f"Failed to load some search models: {e}", exc_info=True)


 def setup_default_agent(user: KhojUser):
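The hunk above only shows the changed except clause. A minimal sketch of the surrounding pattern (the function signature and stub loader are assumed here, not copied from the repo) shows why logging instead of re-raising lets startup continue, so a broken embedding model config can be repaired from the admin panel:

import logging

logger = logging.getLogger(__name__)

def initialize_content(regenerate, search_type, user):
    # Placeholder for the real loader, which builds search models and indexes content.
    raise RuntimeError("embedding model misconfigured")

def configure_server(regenerate=False, search_type=None, user=None):
    try:
        initialize_content(regenerate, search_type, user)
    except Exception as e:
        # Previously this re-raised and aborted startup; now the error is only
        # logged, so the server keeps booting and the admin panel stays reachable.
        logger.error(f"Failed to load some search models: {e}", exc_info=True)

configure_server()  # logs the failure instead of crashing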
@@ -476,10 +476,11 @@ def get_default_search_model() -> SearchModelConfig:

     if default_search_model:
         return default_search_model
+    elif SearchModelConfig.objects.count() > 0:
+        return SearchModelConfig.objects.first()
     else:
         SearchModelConfig.objects.create()
-        return SearchModelConfig.objects.filter(name="default").first()
+        return SearchModelConfig.objects.first()


 def get_user_default_search_model(user: KhojUser = None) -> SearchModelConfig:
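Putting the visible lines together, the resulting lookup presumably reads roughly as below; the initial queryset for default_search_model is not shown in the hunk and is assumed here:

def get_default_search_model() -> SearchModelConfig:
    # Assumed lookup: prefer a model the admin explicitly named "default".
    default_search_model = SearchModelConfig.objects.filter(name="default").first()

    if default_search_model:
        return default_search_model
    # New branch: if the admin configured any search model, reuse the first one
    # instead of creating another default entry.
    elif SearchModelConfig.objects.count() > 0:
        return SearchModelConfig.objects.first()
    else:
        # No model configured at all: create the stock default and return it.
        SearchModelConfig.objects.create()
        return SearchModelConfig.objects.first()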
@@ -13,7 +13,7 @@ from tenacity import (
 )
 from torch import nn

-from khoj.utils.helpers import get_device, merge_dicts, timer
+from khoj.utils.helpers import fix_json_dict, get_device, merge_dicts, timer
 from khoj.utils.rawconfig import SearchResponse

 logger = logging.getLogger(__name__)
@@ -31,9 +31,9 @@ class EmbeddingsModel:
     ):
         default_query_encode_kwargs = {"show_progress_bar": False, "normalize_embeddings": True}
         default_docs_encode_kwargs = {"show_progress_bar": True, "normalize_embeddings": True}
-        self.query_encode_kwargs = merge_dicts(query_encode_kwargs, default_query_encode_kwargs)
-        self.docs_encode_kwargs = merge_dicts(docs_encode_kwargs, default_docs_encode_kwargs)
-        self.model_kwargs = merge_dicts(model_kwargs, {"device": get_device()})
+        self.query_encode_kwargs = merge_dicts(fix_json_dict(query_encode_kwargs), default_query_encode_kwargs)
+        self.docs_encode_kwargs = merge_dicts(fix_json_dict(docs_encode_kwargs), default_docs_encode_kwargs)
+        self.model_kwargs = merge_dicts(fix_json_dict(model_kwargs), {"device": get_device()})
         self.model_name = model_name
         self.inference_endpoint = embeddings_inference_endpoint
         self.api_key = embeddings_inference_endpoint_api_key
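These encode and model kwargs come from JSON entered in the admin panel, where booleans are easily saved as the strings "True"/"False". Non-empty strings are truthy in Python, so a literal "False" would behave like True downstream. Wrapping the kwargs in fix_json_dict (added in the next hunk) normalizes the values before they are merged with the defaults. A small illustration with made-up values:

query_encode_kwargs = {"show_progress_bar": "False", "normalize_embeddings": "True"}

fixed = fix_json_dict(query_encode_kwargs)
# fixed == {"show_progress_bar": False, "normalize_embeddings": True}

query_encode_kwargs = merge_dicts(fixed, {"show_progress_bar": False, "normalize_embeddings": True})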
@@ -101,6 +101,15 @@ def merge_dicts(priority_dict: dict, default_dict: dict):
     return merged_dict


+def fix_json_dict(json_dict: dict) -> dict:
+    for k, v in json_dict.items():
+        if v == "True" or v == "False":
+            json_dict[k] = v == "True"
+        if isinstance(v, dict):
+            json_dict[k] = fix_json_dict(v)
+    return json_dict
+
+
 def get_file_type(file_type: str, file_content: bytes) -> tuple[str, str]:
     "Get file type from file mime type"
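As a quick check of the new helper (values are illustrative only), it converts the boolean strings, recurses into nested dicts, and updates the dict in place:

raw = {"normalize_embeddings": "True", "model_kwargs": {"trust_remote_code": "False"}}
fixed = fix_json_dict(raw)
assert fixed == {"normalize_embeddings": True, "model_kwargs": {"trust_remote_code": False}}
assert fixed is raw  # same dict object, mutated in place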