mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 15:38:55 +01:00
Add Search Config for Symmetric Model. Save Model to Disk
This commit is contained in:
parent
b63026d97c
commit
934ec233b0
4 changed files with 33 additions and 9 deletions
|
@ -24,6 +24,11 @@ content-type:
|
|||
embeddings-file: "tests/data/.song_embeddings.pt"
|
||||
|
||||
search-type:
|
||||
symmetric:
|
||||
encoder: "sentence-transformers/paraphrase-MiniLM-L6-v2"
|
||||
cross-encoder: "cross-encoder/ms-marco-MiniLM-L-6-v2"
|
||||
model_directory: "tests/data/.symmetric"
|
||||
|
||||
asymmetric:
|
||||
encoder: "sentence-transformers/msmarco-MiniLM-L-6-v3"
|
||||
cross-encoder: "cross-encoder/ms-marco-MiniLM-L-6-v2"
|
||||
|
|
|
@ -140,7 +140,7 @@ def initialize_search(config: FullConfig, regenerate: bool, t: SearchType = None
|
|||
# Initialize Ledger Search
|
||||
if (t == SearchType.Ledger or t == None) and config.content_type.ledger:
|
||||
# Extract Entries, Generate Ledger Embeddings
|
||||
model.ledger_search = symmetric_ledger.setup(config.content_type.ledger, regenerate=regenerate, verbose=verbose)
|
||||
model.ledger_search = symmetric_ledger.setup(config.content_type.ledger, search_config=config.search_type.symmetric, regenerate=regenerate, verbose=verbose)
|
||||
|
||||
# Initialize Image Search
|
||||
if (t == SearchType.Image or t == None) and config.content_type.image:
|
||||
|
|
|
@ -10,18 +10,31 @@ import torch
|
|||
from sentence_transformers import SentenceTransformer, CrossEncoder, util
|
||||
|
||||
# Internal Packages
|
||||
from src.utils.helpers import get_absolute_path, resolve_absolute_path
|
||||
from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_model
|
||||
from src.processor.ledger.beancount_to_jsonl import beancount_to_jsonl
|
||||
from src.utils.config import TextSearchModel
|
||||
from src.utils.rawconfig import TextSearchConfig
|
||||
from src.utils.rawconfig import SymmetricConfig, TextSearchConfig
|
||||
|
||||
|
||||
def initialize_model():
|
||||
def initialize_model(search_config: SymmetricConfig):
|
||||
"Initialize model for symmetric semantic search. That is, where query of similar size to results"
|
||||
torch.set_num_threads(4)
|
||||
bi_encoder = SentenceTransformer('sentence-transformers/paraphrase-MiniLM-L6-v2') # The encoder encodes all entries to use for semantic search
|
||||
top_k = 30 # Number of entries we want to retrieve with the bi-encoder
|
||||
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2') # The cross-encoder re-ranks the results to improve quality
|
||||
|
||||
# Number of entries we want to retrieve with the bi-encoder
|
||||
top_k = 30
|
||||
|
||||
# The bi-encoder encodes all entries to use for semantic search
|
||||
bi_encoder = load_model(
|
||||
model_dir = search_config.model_directory,
|
||||
model_name = search_config.encoder,
|
||||
model_type = SentenceTransformer)
|
||||
|
||||
# The cross-encoder re-ranks the results to improve quality
|
||||
cross_encoder = load_model(
|
||||
model_dir = search_config.model_directory,
|
||||
model_name = search_config.cross_encoder,
|
||||
model_type = CrossEncoder)
|
||||
|
||||
return bi_encoder, cross_encoder, top_k
|
||||
|
||||
|
||||
|
@ -141,9 +154,9 @@ def collate_results(hits, entries, count=5):
|
|||
in hits[0:count]]
|
||||
|
||||
|
||||
def setup(config: TextSearchConfig, regenerate: bool, verbose: bool) -> TextSearchModel:
|
||||
def setup(config: TextSearchConfig, search_config: SymmetricConfig, regenerate: bool, verbose: bool) -> TextSearchModel:
|
||||
# Initialize Model
|
||||
bi_encoder, cross_encoder, top_k = initialize_model()
|
||||
bi_encoder, cross_encoder, top_k = initialize_model(search_config)
|
||||
|
||||
# Map notes in Org-Mode files to (compressed) JSONL formatted file
|
||||
if not resolve_absolute_path(config.compressed_jsonl).exists() or regenerate:
|
||||
|
|
|
@ -37,6 +37,11 @@ class ContentTypeConfig(ConfigBase):
|
|||
image: Optional[ImageSearchConfig]
|
||||
music: Optional[TextSearchConfig]
|
||||
|
||||
class SymmetricConfig(ConfigBase):
|
||||
encoder: Optional[str]
|
||||
cross_encoder: Optional[str]
|
||||
model_directory: Optional[Path]
|
||||
|
||||
class AsymmetricConfig(ConfigBase):
|
||||
encoder: Optional[str]
|
||||
cross_encoder: Optional[str]
|
||||
|
@ -47,6 +52,7 @@ class ImageSearchTypeConfig(ConfigBase):
|
|||
|
||||
class SearchTypeConfig(ConfigBase):
|
||||
asymmetric: Optional[AsymmetricConfig]
|
||||
symmetric: Optional[SymmetricConfig]
|
||||
image: Optional[ImageSearchTypeConfig]
|
||||
|
||||
class ConversationProcessorConfig(ConfigBase):
|
||||
|
|
Loading…
Reference in a new issue