mirror of
https://github.com/khoj-ai/khoj.git
synced 2025-02-17 08:04:21 +00:00
Load model from HuggingFace if model_directory unset in config YAML
- Do not save/load the model to/from disk when model_directory unset in config.yml - Add symmetric search default config to cli.py
This commit is contained in:
parent
510faa1904
commit
c64e0c2965
2 changed files with 15 additions and 6 deletions
|
@ -77,14 +77,22 @@ default_config = {
|
||||||
},
|
},
|
||||||
'search-type':
|
'search-type':
|
||||||
{
|
{
|
||||||
'asymmetric':
|
'symmetric':
|
||||||
|
{
|
||||||
|
'encoder': "sentence-transformers/paraphrase-MiniLM-L6-v2",
|
||||||
|
'cross-encoder': "cross-encoder/ms-marco-MiniLM-L-6-v2",
|
||||||
|
'model_directory': None
|
||||||
|
},
|
||||||
|
'asymmetric':
|
||||||
{
|
{
|
||||||
'encoder': "sentence-transformers/msmarco-MiniLM-L-6-v3",
|
'encoder': "sentence-transformers/msmarco-MiniLM-L-6-v3",
|
||||||
'cross-encoder': "cross-encoder/ms-marco-MiniLM-L-6-v2"
|
'cross-encoder': "cross-encoder/ms-marco-MiniLM-L-6-v2",
|
||||||
|
'model_directory': None
|
||||||
},
|
},
|
||||||
'image':
|
'image':
|
||||||
{
|
{
|
||||||
'encoder': "clip-ViT-B-32"
|
'encoder': "clip-ViT-B-32",
|
||||||
|
'model_directory': None
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
'processor':
|
'processor':
|
||||||
|
|
|
@ -39,14 +39,15 @@ def merge_dicts(priority_dict, default_dict):
|
||||||
def load_model(model_name, model_dir, model_type):
|
def load_model(model_name, model_dir, model_type):
|
||||||
"Load model from disk or huggingface"
|
"Load model from disk or huggingface"
|
||||||
# Construct model path
|
# Construct model path
|
||||||
model_path = join(model_dir, model_name.replace("/", "_"))
|
model_path = join(model_dir, model_name.replace("/", "_")) if model_dir is not None else None
|
||||||
|
|
||||||
# Load model from model_path if it exists there
|
# Load model from model_path if it exists there
|
||||||
if resolve_absolute_path(model_path).exists():
|
if model_path is not None and resolve_absolute_path(model_path).exists():
|
||||||
model = model_type(get_absolute_path(model_path))
|
model = model_type(get_absolute_path(model_path))
|
||||||
# Else load the model from the model_name
|
# Else load the model from the model_name
|
||||||
else:
|
else:
|
||||||
model = model_type(model_name)
|
model = model_type(model_name)
|
||||||
model.save(model_path)
|
if model_path is not None:
|
||||||
|
model.save(model_path)
|
||||||
|
|
||||||
return model
|
return model
|
Loading…
Add table
Reference in a new issue