Add Search Config for Symmetric Model. Save Model to Disk

This commit is contained in:
Debanjum Singh Solanky 2022-01-14 16:46:56 -05:00
parent b63026d97c
commit 934ec233b0
4 changed files with 33 additions and 9 deletions

View file

@ -24,6 +24,11 @@ content-type:
embeddings-file: "tests/data/.song_embeddings.pt" embeddings-file: "tests/data/.song_embeddings.pt"
search-type: search-type:
symmetric:
encoder: "sentence-transformers/paraphrase-MiniLM-L6-v2"
cross-encoder: "cross-encoder/ms-marco-MiniLM-L-6-v2"
model_directory: "tests/data/.symmetric"
asymmetric: asymmetric:
encoder: "sentence-transformers/msmarco-MiniLM-L-6-v3" encoder: "sentence-transformers/msmarco-MiniLM-L-6-v3"
cross-encoder: "cross-encoder/ms-marco-MiniLM-L-6-v2" cross-encoder: "cross-encoder/ms-marco-MiniLM-L-6-v2"

View file

@ -140,7 +140,7 @@ def initialize_search(config: FullConfig, regenerate: bool, t: SearchType = None
# Initialize Ledger Search # Initialize Ledger Search
if (t == SearchType.Ledger or t == None) and config.content_type.ledger: if (t == SearchType.Ledger or t == None) and config.content_type.ledger:
# Extract Entries, Generate Ledger Embeddings # Extract Entries, Generate Ledger Embeddings
model.ledger_search = symmetric_ledger.setup(config.content_type.ledger, regenerate=regenerate, verbose=verbose) model.ledger_search = symmetric_ledger.setup(config.content_type.ledger, search_config=config.search_type.symmetric, regenerate=regenerate, verbose=verbose)
# Initialize Image Search # Initialize Image Search
if (t == SearchType.Image or t == None) and config.content_type.image: if (t == SearchType.Image or t == None) and config.content_type.image:

View file

@ -10,18 +10,31 @@ import torch
from sentence_transformers import SentenceTransformer, CrossEncoder, util from sentence_transformers import SentenceTransformer, CrossEncoder, util
# Internal Packages # Internal Packages
from src.utils.helpers import get_absolute_path, resolve_absolute_path from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_model
from src.processor.ledger.beancount_to_jsonl import beancount_to_jsonl from src.processor.ledger.beancount_to_jsonl import beancount_to_jsonl
from src.utils.config import TextSearchModel from src.utils.config import TextSearchModel
from src.utils.rawconfig import TextSearchConfig from src.utils.rawconfig import SymmetricConfig, TextSearchConfig
def initialize_model(): def initialize_model(search_config: SymmetricConfig):
"Initialize model for symmetric semantic search. That is, where query of similar size to results" "Initialize model for symmetric semantic search. That is, where query of similar size to results"
torch.set_num_threads(4) torch.set_num_threads(4)
bi_encoder = SentenceTransformer('sentence-transformers/paraphrase-MiniLM-L6-v2') # The encoder encodes all entries to use for semantic search
top_k = 30 # Number of entries we want to retrieve with the bi-encoder # Number of entries we want to retrieve with the bi-encoder
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2') # The cross-encoder re-ranks the results to improve quality top_k = 30
# The bi-encoder encodes all entries to use for semantic search
bi_encoder = load_model(
model_dir = search_config.model_directory,
model_name = search_config.encoder,
model_type = SentenceTransformer)
# The cross-encoder re-ranks the results to improve quality
cross_encoder = load_model(
model_dir = search_config.model_directory,
model_name = search_config.cross_encoder,
model_type = CrossEncoder)
return bi_encoder, cross_encoder, top_k return bi_encoder, cross_encoder, top_k
@ -141,9 +154,9 @@ def collate_results(hits, entries, count=5):
in hits[0:count]] in hits[0:count]]
def setup(config: TextSearchConfig, regenerate: bool, verbose: bool) -> TextSearchModel: def setup(config: TextSearchConfig, search_config: SymmetricConfig, regenerate: bool, verbose: bool) -> TextSearchModel:
# Initialize Model # Initialize Model
bi_encoder, cross_encoder, top_k = initialize_model() bi_encoder, cross_encoder, top_k = initialize_model(search_config)
# Map notes in Org-Mode files to (compressed) JSONL formatted file # Map notes in Org-Mode files to (compressed) JSONL formatted file
if not resolve_absolute_path(config.compressed_jsonl).exists() or regenerate: if not resolve_absolute_path(config.compressed_jsonl).exists() or regenerate:

View file

@ -37,6 +37,11 @@ class ContentTypeConfig(ConfigBase):
image: Optional[ImageSearchConfig] image: Optional[ImageSearchConfig]
music: Optional[TextSearchConfig] music: Optional[TextSearchConfig]
class SymmetricConfig(ConfigBase):
encoder: Optional[str]
cross_encoder: Optional[str]
model_directory: Optional[Path]
class AsymmetricConfig(ConfigBase): class AsymmetricConfig(ConfigBase):
encoder: Optional[str] encoder: Optional[str]
cross_encoder: Optional[str] cross_encoder: Optional[str]
@ -47,6 +52,7 @@ class ImageSearchTypeConfig(ConfigBase):
class SearchTypeConfig(ConfigBase): class SearchTypeConfig(ConfigBase):
asymmetric: Optional[AsymmetricConfig] asymmetric: Optional[AsymmetricConfig]
symmetric: Optional[SymmetricConfig]
image: Optional[ImageSearchTypeConfig] image: Optional[ImageSearchTypeConfig]
class ConversationProcessorConfig(ConfigBase): class ConversationProcessorConfig(ConfigBase):