Add Search Config for Symmetric Model. Save Model to Disk

This commit is contained in:
Debanjum Singh Solanky 2022-01-14 16:46:56 -05:00
parent b63026d97c
commit 934ec233b0
4 changed files with 33 additions and 9 deletions

View file

@ -24,6 +24,11 @@ content-type:
embeddings-file: "tests/data/.song_embeddings.pt"
search-type:
symmetric:
encoder: "sentence-transformers/paraphrase-MiniLM-L6-v2"
cross-encoder: "cross-encoder/ms-marco-MiniLM-L-6-v2"
model_directory: "tests/data/.symmetric"
asymmetric:
encoder: "sentence-transformers/msmarco-MiniLM-L-6-v3"
cross-encoder: "cross-encoder/ms-marco-MiniLM-L-6-v2"

View file

@ -140,7 +140,7 @@ def initialize_search(config: FullConfig, regenerate: bool, t: SearchType = None
# Initialize Ledger Search
if (t == SearchType.Ledger or t == None) and config.content_type.ledger:
# Extract Entries, Generate Ledger Embeddings
model.ledger_search = symmetric_ledger.setup(config.content_type.ledger, regenerate=regenerate, verbose=verbose)
model.ledger_search = symmetric_ledger.setup(config.content_type.ledger, search_config=config.search_type.symmetric, regenerate=regenerate, verbose=verbose)
# Initialize Image Search
if (t == SearchType.Image or t == None) and config.content_type.image:

View file

@ -10,18 +10,31 @@ import torch
from sentence_transformers import SentenceTransformer, CrossEncoder, util
# Internal Packages
from src.utils.helpers import get_absolute_path, resolve_absolute_path
from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_model
from src.processor.ledger.beancount_to_jsonl import beancount_to_jsonl
from src.utils.config import TextSearchModel
from src.utils.rawconfig import TextSearchConfig
from src.utils.rawconfig import SymmetricConfig, TextSearchConfig
def initialize_model():
def initialize_model(search_config: SymmetricConfig):
"Initialize model for symmetric semantic search. That is, where query of similar size to results"
torch.set_num_threads(4)
bi_encoder = SentenceTransformer('sentence-transformers/paraphrase-MiniLM-L6-v2') # The encoder encodes all entries to use for semantic search
top_k = 30 # Number of entries we want to retrieve with the bi-encoder
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2') # The cross-encoder re-ranks the results to improve quality
# Number of entries we want to retrieve with the bi-encoder
top_k = 30
# The bi-encoder encodes all entries to use for semantic search
bi_encoder = load_model(
model_dir = search_config.model_directory,
model_name = search_config.encoder,
model_type = SentenceTransformer)
# The cross-encoder re-ranks the results to improve quality
cross_encoder = load_model(
model_dir = search_config.model_directory,
model_name = search_config.cross_encoder,
model_type = CrossEncoder)
return bi_encoder, cross_encoder, top_k
@ -141,9 +154,9 @@ def collate_results(hits, entries, count=5):
in hits[0:count]]
def setup(config: TextSearchConfig, regenerate: bool, verbose: bool) -> TextSearchModel:
def setup(config: TextSearchConfig, search_config: SymmetricConfig, regenerate: bool, verbose: bool) -> TextSearchModel:
# Initialize Model
bi_encoder, cross_encoder, top_k = initialize_model()
bi_encoder, cross_encoder, top_k = initialize_model(search_config)
# Map notes in Org-Mode files to (compressed) JSONL formatted file
if not resolve_absolute_path(config.compressed_jsonl).exists() or regenerate:

View file

@ -37,6 +37,11 @@ class ContentTypeConfig(ConfigBase):
image: Optional[ImageSearchConfig]
music: Optional[TextSearchConfig]
class SymmetricConfig(ConfigBase):
encoder: Optional[str]
cross_encoder: Optional[str]
model_directory: Optional[Path]
class AsymmetricConfig(ConfigBase):
encoder: Optional[str]
cross_encoder: Optional[str]
@ -47,6 +52,7 @@ class ImageSearchTypeConfig(ConfigBase):
class SearchTypeConfig(ConfigBase):
asymmetric: Optional[AsymmetricConfig]
symmetric: Optional[SymmetricConfig]
image: Optional[ImageSearchTypeConfig]
class ConversationProcessorConfig(ConfigBase):