diff --git a/src/utils/cli.py b/src/utils/cli.py index aa58af8d..505628a5 100644 --- a/src/utils/cli.py +++ b/src/utils/cli.py @@ -77,14 +77,22 @@ default_config = { }, 'search-type': { - 'asymmetric': + 'symmetric': + { + 'encoder': "sentence-transformers/paraphrase-MiniLM-L6-v2", + 'cross-encoder': "cross-encoder/ms-marco-MiniLM-L-6-v2", + 'model_directory': None + }, + 'asymmetric': { 'encoder': "sentence-transformers/msmarco-MiniLM-L-6-v3", - 'cross-encoder': "cross-encoder/ms-marco-MiniLM-L-6-v2" + 'cross-encoder': "cross-encoder/ms-marco-MiniLM-L-6-v2", + 'model_directory': None }, 'image': { - 'encoder': "clip-ViT-B-32" + 'encoder': "clip-ViT-B-32", + 'model_directory': None }, }, 'processor': diff --git a/src/utils/helpers.py b/src/utils/helpers.py index 27bb4c3f..ab4cec7c 100644 --- a/src/utils/helpers.py +++ b/src/utils/helpers.py @@ -39,14 +39,15 @@ def merge_dicts(priority_dict, default_dict): def load_model(model_name, model_dir, model_type): "Load model from disk or huggingface" # Construct model path - model_path = join(model_dir, model_name.replace("/", "_")) + model_path = join(model_dir, model_name.replace("/", "_")) if model_dir is not None else None # Load model from model_path if it exists there - if resolve_absolute_path(model_path).exists(): + if model_path is not None and resolve_absolute_path(model_path).exists(): model = model_type(get_absolute_path(model_path)) # Else load the model from the model_name else: model = model_type(model_name) - model.save(model_path) + if model_path is not None: + model.save(model_path) return model \ No newline at end of file