diff --git a/config/sample_config.yml b/config/sample_config.yml index e743a399..5c8a9f71 100644 --- a/config/sample_config.yml +++ b/config/sample_config.yml @@ -28,7 +28,7 @@ content-type: search-type: symmetric: - encoder: "sentence-transformers/paraphrase-MiniLM-L6-v2" + encoder: "sentence-transformers/all-MiniLM-L6-v2" cross-encoder: "cross-encoder/ms-marco-MiniLM-L-6-v2" model_directory: "/data/models/symmetric" diff --git a/src/search_type/symmetric.py b/src/search_type/symmetric.py index 6d8ffeea..2e26e7c4 100644 --- a/src/search_type/symmetric.py +++ b/src/search_type/symmetric.py @@ -59,7 +59,7 @@ if __name__ == '__main__': parser.add_argument('--dataset', type=str, default="./.dataset", help="Path to dataset to generate index from") parser.add_argument('--column', type=str, default="DATA", help="Name of dataset column to index") parser.add_argument('--num_results', type=int, default=10, help="Number of most suitable matches to show") - parser.add_argument('--model_name', type=str, default='paraphrase-distilroberta-base-v1', help="Specify name of the SentenceTransformer model to use for encoding") + parser.add_argument('--model_name', type=str, default='all-MiniLM-L6-v2', help="Specify name of the SentenceTransformer model to use for encoding") args = parser.parse_args() model = SentenceTransformer(args.model_name) diff --git a/src/utils/cli.py b/src/utils/cli.py index a2d19f1c..1c1ec9dc 100644 --- a/src/utils/cli.py +++ b/src/utils/cli.py @@ -79,7 +79,7 @@ default_config = { { 'symmetric': { - 'encoder': "sentence-transformers/paraphrase-MiniLM-L6-v2", + 'encoder': "sentence-transformers/all-MiniLM-L6-v2", 'cross-encoder': "cross-encoder/ms-marco-MiniLM-L-6-v2", 'model_directory': None }, diff --git a/tests/conftest.py b/tests/conftest.py index a0b07a63..34da236e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,7 +14,7 @@ def search_config(tmp_path_factory): search_config = SearchConfig() search_config.asymmetric = SymmetricSearchConfig( - encoder = "sentence-transformers/paraphrase-MiniLM-L6-v2", + encoder = "sentence-transformers/all-MiniLM-L6-v2", cross_encoder = "cross-encoder/ms-marco-MiniLM-L-6-v2", model_directory = model_dir )