mirror of https://github.com/khoj-ai/khoj.git
Add Search Config for Symmetric Model. Save Model to Disk
parent b63026d97c
commit 934ec233b0
4 changed files with 33 additions and 9 deletions
@@ -24,6 +24,11 @@ content-type:
     embeddings-file: "tests/data/.song_embeddings.pt"
 
 search-type:
+  symmetric:
+    encoder: "sentence-transformers/paraphrase-MiniLM-L6-v2"
+    cross-encoder: "cross-encoder/ms-marco-MiniLM-L-6-v2"
+    model_directory: "tests/data/.symmetric"
+
   asymmetric:
     encoder: "sentence-transformers/msmarco-MiniLM-L-6-v3"
     cross-encoder: "cross-encoder/ms-marco-MiniLM-L-6-v2"
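For orientation, here is a minimal sketch of what the new symmetric block provides once this YAML is loaded. The file path and the plain-dict access are illustrative assumptions; in khoj itself this section is parsed into the SymmetricConfig model added further down in this commit.

import yaml

# Hypothetical path to the test config edited above
with open("tests/data/config.yml") as config_file:
    raw_config = yaml.safe_load(config_file)

symmetric = raw_config["search-type"]["symmetric"]
print(symmetric["encoder"])          # bi-encoder that embeds ledger entries and queries
print(symmetric["cross-encoder"])    # cross-encoder that re-ranks the bi-encoder hits
print(symmetric["model_directory"])  # directory where the downloaded models are saved to disk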
@@ -140,7 +140,7 @@ def initialize_search(config: FullConfig, regenerate: bool, t: SearchType = None
     # Initialize Ledger Search
     if (t == SearchType.Ledger or t == None) and config.content_type.ledger:
         # Extract Entries, Generate Ledger Embeddings
-        model.ledger_search = symmetric_ledger.setup(config.content_type.ledger, regenerate=regenerate, verbose=verbose)
+        model.ledger_search = symmetric_ledger.setup(config.content_type.ledger, search_config=config.search_type.symmetric, regenerate=regenerate, verbose=verbose)
 
     # Initialize Image Search
     if (t == SearchType.Image or t == None) and config.content_type.image:
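The symmetric search settings now travel from the parsed FullConfig into the ledger search setup alongside the ledger content settings. A sketch of that wiring follows; the module path src.search_type.symmetric_ledger and the wrapper function name are assumptions for illustration, since khoj's actual call sits inside initialize_search as shown above.

# Illustrative wrapper (hypothetical name); khoj's real call is inside initialize_search
from src.search_type import symmetric_ledger

def initialize_ledger_search(config, regenerate=False, verbose=False):
    # config.content_type.ledger   -> beancount input files plus jsonl/embeddings paths
    # config.search_type.symmetric -> which encoder/cross-encoder to load and where to cache them
    return symmetric_ledger.setup(
        config.content_type.ledger,
        search_config=config.search_type.symmetric,
        regenerate=regenerate,
        verbose=verbose)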
@@ -10,18 +10,31 @@ import torch
 from sentence_transformers import SentenceTransformer, CrossEncoder, util
 
 # Internal Packages
-from src.utils.helpers import get_absolute_path, resolve_absolute_path
+from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_model
 from src.processor.ledger.beancount_to_jsonl import beancount_to_jsonl
 from src.utils.config import TextSearchModel
-from src.utils.rawconfig import TextSearchConfig
+from src.utils.rawconfig import SymmetricConfig, TextSearchConfig
 
 
-def initialize_model():
+def initialize_model(search_config: SymmetricConfig):
     "Initialize model for symmetric semantic search. That is, where query of similar size to results"
     torch.set_num_threads(4)
-    bi_encoder = SentenceTransformer('sentence-transformers/paraphrase-MiniLM-L6-v2')  # The encoder encodes all entries to use for semantic search
-    top_k = 30  # Number of entries we want to retrieve with the bi-encoder
-    cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')  # The cross-encoder re-ranks the results to improve quality
+
+    # Number of entries we want to retrieve with the bi-encoder
+    top_k = 30
+
+    # The bi-encoder encodes all entries to use for semantic search
+    bi_encoder = load_model(
+        model_dir  = search_config.model_directory,
+        model_name = search_config.encoder,
+        model_type = SentenceTransformer)
+
+    # The cross-encoder re-ranks the results to improve quality
+    cross_encoder = load_model(
+        model_dir  = search_config.model_directory,
+        model_name = search_config.cross_encoder,
+        model_type = CrossEncoder)
 
     return bi_encoder, cross_encoder, top_k
 
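The helper load_model that the new initialize_model relies on is not shown in this diff. Given the commit title ("Save Model to Disk") and the arguments passed to it, a plausible sketch is below; the exact signature and on-disk layout are assumptions.

from pathlib import Path

def load_model(model_dir, model_name: str, model_type):
    "Load model from disk if available, else download it by name and save it to disk"
    # One sub-directory per model, so the bi- and cross-encoder can share model_dir (assumption)
    model_path = Path(model_dir).expanduser() / model_name.replace("/", "_") if model_dir else None
    if model_path and model_path.exists():
        # SentenceTransformer and CrossEncoder both accept a local directory path
        return model_type(str(model_path))
    model = model_type(model_name)      # download the pretrained model by name
    if model_path:
        model.save(str(model_path))     # persist it so later runs skip the download
    return model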
@@ -141,9 +154,9 @@ def collate_results(hits, entries, count=5):
         in hits[0:count]]
 
 
-def setup(config: TextSearchConfig, regenerate: bool, verbose: bool) -> TextSearchModel:
+def setup(config: TextSearchConfig, search_config: SymmetricConfig, regenerate: bool, verbose: bool) -> TextSearchModel:
     # Initialize Model
-    bi_encoder, cross_encoder, top_k = initialize_model()
+    bi_encoder, cross_encoder, top_k = initialize_model(search_config)
 
     # Map notes in Org-Mode files to (compressed) JSONL formatted file
     if not resolve_absolute_path(config.compressed_jsonl).exists() or regenerate:
@@ -37,6 +37,11 @@ class ContentTypeConfig(ConfigBase):
     image: Optional[ImageSearchConfig]
     music: Optional[TextSearchConfig]
 
+class SymmetricConfig(ConfigBase):
+    encoder: Optional[str]
+    cross_encoder: Optional[str]
+    model_directory: Optional[Path]
+
 class AsymmetricConfig(ConfigBase):
     encoder: Optional[str]
     cross_encoder: Optional[str]
@@ -47,6 +52,7 @@ class ImageSearchTypeConfig(ConfigBase):
 
 class SearchTypeConfig(ConfigBase):
     asymmetric: Optional[AsymmetricConfig]
+    symmetric: Optional[SymmetricConfig]
     image: Optional[ImageSearchTypeConfig]
 
 class ConversationProcessorConfig(ConfigBase):
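To make the shape of the new schema concrete, here is a self-contained sketch using plain pydantic stand-ins for the classes touched above. khoj's own ConfigBase (and whatever key aliasing it defines) is deliberately not reproduced, so treat this purely as an illustration of the nesting.

from pathlib import Path
from typing import Optional
from pydantic import BaseModel

# Stand-ins mirroring the classes added above, not khoj's actual ConfigBase subclasses
class SymmetricConfig(BaseModel):
    encoder: Optional[str] = None
    cross_encoder: Optional[str] = None
    model_directory: Optional[Path] = None

class SearchTypeConfig(BaseModel):
    symmetric: Optional[SymmetricConfig] = None

search_type = SearchTypeConfig(symmetric=SymmetricConfig(
    encoder="sentence-transformers/paraphrase-MiniLM-L6-v2",
    cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2",
    model_directory="tests/data/.symmetric"))

# pydantic coerces the string into a Path, which is what initialize_model hands to
# load_model as the model_dir for both the bi-encoder and the cross-encoder
print(search_type.symmetric.model_directory)  # tests/data/.symmetric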