diff --git a/sample_config.yml b/sample_config.yml index 2203a694..ee19c8ee 100644 --- a/sample_config.yml +++ b/sample_config.yml @@ -27,6 +27,7 @@ search-type: asymmetric: encoder: "sentence-transformers/msmarco-MiniLM-L-6-v3" cross-encoder: "cross-encoder/ms-marco-MiniLM-L-6-v2" + model_directory: "tests/data/.asymmetric" image: encoder: "clip-ViT-B-32" diff --git a/src/main.py b/src/main.py index dea7cf63..71397486 100644 --- a/src/main.py +++ b/src/main.py @@ -130,12 +130,12 @@ def initialize_search(config: FullConfig, regenerate: bool, t: SearchType = None # Initialize Org Notes Search if (t == SearchType.Notes or t == None) and config.content_type.org: # Extract Entries, Generate Notes Embeddings - model.notes_search = asymmetric.setup(config.content_type.org, regenerate=regenerate, verbose=verbose) + model.notes_search = asymmetric.setup(config.content_type.org, search_config=config.search_type.asymmetric, regenerate=regenerate, verbose=verbose) # Initialize Org Music Search if (t == SearchType.Music or t == None) and config.content_type.music: # Extract Entries, Generate Music Embeddings - model.music_search = asymmetric.setup(config.content_type.music, regenerate=regenerate, verbose=verbose) + model.music_search = asymmetric.setup(config.content_type.music, search_config=config.search_type.asymmetric, regenerate=regenerate, verbose=verbose) # Initialize Ledger Search if (t == SearchType.Ledger or t == None) and config.content_type.ledger: diff --git a/src/search_type/asymmetric.py b/src/search_type/asymmetric.py index 416cf7e2..670cc39e 100644 --- a/src/search_type/asymmetric.py +++ b/src/search_type/asymmetric.py @@ -12,18 +12,31 @@ import torch from sentence_transformers import SentenceTransformer, CrossEncoder, util # Internal Packages -from src.utils.helpers import get_absolute_path, resolve_absolute_path +from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_model from src.processor.org_mode.org_to_jsonl import org_to_jsonl from src.utils.config import TextSearchModel -from src.utils.rawconfig import TextSearchConfig +from src.utils.rawconfig import AsymmetricConfig, TextSearchConfig -def initialize_model(): +def initialize_model(search_config: AsymmetricConfig): "Initialize model for assymetric semantic search. That is, where query smaller than results" torch.set_num_threads(4) - bi_encoder = SentenceTransformer('sentence-transformers/msmarco-MiniLM-L-6-v3') # The bi-encoder encodes all entries to use for semantic search - top_k = 30 # Number of entries we want to retrieve with the bi-encoder - cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2') # The cross-encoder re-ranks the results to improve quality + + # Number of entries we want to retrieve with the bi-encoder + top_k = 30 + + # The bi-encoder encodes all entries to use for semantic search + bi_encoder = load_model( + model_dir = search_config.model_directory, + model_name = search_config.encoder, + model_type = SentenceTransformer) + + # The cross-encoder re-ranks the results to improve quality + cross_encoder = load_model( + model_dir = search_config.model_directory, + model_name = search_config.cross_encoder, + model_type = CrossEncoder) + return bi_encoder, cross_encoder, top_k @@ -149,9 +162,9 @@ def collate_results(hits, entries, count=5): in hits[0:count]] -def setup(config: TextSearchConfig, regenerate: bool, verbose: bool=False) -> TextSearchModel: +def setup(config: TextSearchConfig, search_config: AsymmetricConfig, regenerate: bool, verbose: bool=False) -> TextSearchModel: # Initialize Model - bi_encoder, cross_encoder, top_k = initialize_model() + bi_encoder, cross_encoder, top_k = initialize_model(search_config) # Map notes in Org-Mode files to (compressed) JSONL formatted file if not resolve_absolute_path(config.compressed_jsonl).exists() or regenerate: diff --git a/src/utils/helpers.py b/src/utils/helpers.py index e2a3b1fe..27bb4c3f 100644 --- a/src/utils/helpers.py +++ b/src/utils/helpers.py @@ -1,4 +1,6 @@ +# Standard Packages import pathlib +from os.path import join def is_none_or_empty(item): @@ -32,3 +34,19 @@ def merge_dicts(priority_dict, default_dict): if k not in priority_dict: merged_dict[k] = default_dict[k] return merged_dict + + +def load_model(model_name, model_dir, model_type): + "Load model from disk or huggingface" + # Construct model path + model_path = join(model_dir, model_name.replace("/", "_")) + + # Load model from model_path if it exists there + if resolve_absolute_path(model_path).exists(): + model = model_type(get_absolute_path(model_path)) + # Else load the model from the model_name + else: + model = model_type(model_name) + model.save(model_path) + + return model \ No newline at end of file diff --git a/src/utils/rawconfig.py b/src/utils/rawconfig.py index b28e9a2d..44e4cce1 100644 --- a/src/utils/rawconfig.py +++ b/src/utils/rawconfig.py @@ -40,6 +40,7 @@ class ContentTypeConfig(ConfigBase): class AsymmetricConfig(ConfigBase): encoder: Optional[str] cross_encoder: Optional[str] + model_directory: Optional[Path] class ImageSearchTypeConfig(ConfigBase): encoder: Optional[str]