Use multiple threads to generate model embeddings. Other minor formatting

This commit is contained in:
Debanjum Singh Solanky 2021-09-29 20:47:58 -07:00
parent e22e0b41e3
commit 352d2930ee
3 changed files with 5 additions and 0 deletions

View file

@ -22,6 +22,7 @@ from utils.config import AsymmetricSearchModel
def initialize_model(): def initialize_model():
"Initialize model for assymetric semantic search. That is, where query smaller than results" "Initialize model for assymetric semantic search. That is, where query smaller than results"
torch.set_num_threads(4)
bi_encoder = SentenceTransformer('sentence-transformers/msmarco-MiniLM-L-6-v3') # The bi-encoder encodes all entries to use for semantic search bi_encoder = SentenceTransformer('sentence-transformers/msmarco-MiniLM-L-6-v3') # The bi-encoder encodes all entries to use for semantic search
top_k = 100 # Number of entries we want to retrieve with the bi-encoder top_k = 100 # Number of entries we want to retrieve with the bi-encoder
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2') # The cross-encoder re-ranks the results to improve quality cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2') # The cross-encoder re-ranks the results to improve quality

View file

@ -19,6 +19,7 @@ from processor.ledger.beancount_to_jsonl import beancount_to_jsonl
def initialize_model(): def initialize_model():
"Initialize model for symetric semantic search. That is, where query of similar size to results" "Initialize model for symetric semantic search. That is, where query of similar size to results"
torch.set_num_threads(4)
bi_encoder = SentenceTransformer('sentence-transformers/paraphrase-MiniLM-L6-v2') # The encoder encodes all entries to use for semantic search bi_encoder = SentenceTransformer('sentence-transformers/paraphrase-MiniLM-L6-v2') # The encoder encodes all entries to use for semantic search
top_k = 100 # Number of entries we want to retrieve with the bi-encoder top_k = 100 # Number of entries we want to retrieve with the bi-encoder
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2') # The cross-encoder re-ranks the results to improve quality cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2') # The cross-encoder re-ranks the results to improve quality

View file

@ -28,6 +28,7 @@ def test_search_with_invalid_search_type():
assert response.status_code == 422 assert response.status_code == 422
# ----------------------------------------------------------------------------------------------------
def test_search_with_valid_search_type(): def test_search_with_valid_search_type():
# Arrange # Arrange
for search_type in ["notes", "ledger", "music", "image"]: for search_type in ["notes", "ledger", "music", "image"]:
@ -37,6 +38,7 @@ def test_search_with_valid_search_type():
assert response.status_code == 200 assert response.status_code == 200
# ----------------------------------------------------------------------------------------------------
def test_regenerate_with_invalid_search_type(): def test_regenerate_with_invalid_search_type():
# Act # Act
response = client.get(f"/regenerate?t=invalid_search_type") response = client.get(f"/regenerate?t=invalid_search_type")
@ -45,6 +47,7 @@ def test_regenerate_with_invalid_search_type():
assert response.status_code == 422 assert response.status_code == 422
# ----------------------------------------------------------------------------------------------------
def test_regenerate_with_valid_search_type(): def test_regenerate_with_valid_search_type():
# Arrange # Arrange
for search_type in ["notes", "ledger", "music", "image"]: for search_type in ["notes", "ledger", "music", "image"]: