From 352d2930ee89645371027464d9d27b47ea200219 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Wed, 29 Sep 2021 20:47:58 -0700 Subject: [PATCH] Use multiple threads to generate model embeddings. Other minor formating --- src/search_type/asymmetric.py | 1 + src/search_type/symmetric_ledger.py | 1 + src/tests/test_main.py | 3 +++ 3 files changed, 5 insertions(+) diff --git a/src/search_type/asymmetric.py b/src/search_type/asymmetric.py index e5ca2858..fd5c688e 100644 --- a/src/search_type/asymmetric.py +++ b/src/search_type/asymmetric.py @@ -22,6 +22,7 @@ from utils.config import AsymmetricSearchModel def initialize_model(): "Initialize model for assymetric semantic search. That is, where query smaller than results" + torch.set_num_threads(4) bi_encoder = SentenceTransformer('sentence-transformers/msmarco-MiniLM-L-6-v3') # The bi-encoder encodes all entries to use for semantic search top_k = 100 # Number of entries we want to retrieve with the bi-encoder cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2') # The cross-encoder re-ranks the results to improve quality diff --git a/src/search_type/symmetric_ledger.py b/src/search_type/symmetric_ledger.py index d19f82eb..a5e0c04c 100644 --- a/src/search_type/symmetric_ledger.py +++ b/src/search_type/symmetric_ledger.py @@ -19,6 +19,7 @@ from processor.ledger.beancount_to_jsonl import beancount_to_jsonl def initialize_model(): "Initialize model for symetric semantic search. That is, where query of similar size to results" + torch.set_num_threads(4) bi_encoder = SentenceTransformer('sentence-transformers/paraphrase-MiniLM-L6-v2') # The encoder encodes all entries to use for semantic search top_k = 100 # Number of entries we want to retrieve with the bi-encoder cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2') # The cross-encoder re-ranks the results to improve quality diff --git a/src/tests/test_main.py b/src/tests/test_main.py index d8221568..0d2a8bee 100644 --- a/src/tests/test_main.py +++ b/src/tests/test_main.py @@ -28,6 +28,7 @@ def test_search_with_invalid_search_type(): assert response.status_code == 422 +# ---------------------------------------------------------------------------------------------------- def test_search_with_valid_search_type(): # Arrange for search_type in ["notes", "ledger", "music", "image"]: @@ -37,6 +38,7 @@ def test_search_with_valid_search_type(): assert response.status_code == 200 +# ---------------------------------------------------------------------------------------------------- def test_regenerate_with_invalid_search_type(): # Act response = client.get(f"/regenerate?t=invalid_search_type") @@ -45,6 +47,7 @@ def test_regenerate_with_invalid_search_type(): assert response.status_code == 422 +# ---------------------------------------------------------------------------------------------------- def test_regenerate_with_valid_search_type(): # Arrange for search_type in ["notes", "ledger", "music", "image"]: