Make type of encoder to use for embeddings configurable via khoj.yml

- Previously `model_type' was set in the setup of each `search_type'
  - All encoders were of type `SentenceTransformer'
  - All cross_encoders were of type `CrossEncoder'

- Now `encoder-type' can be configured via the new `encoder_type' field
  in `TextSearchConfig' under `search-type` in `khoj.yml`.

- All the specified `encoder-type' class needs is an `encode' method
  that takes entries and returns embedding vectors
This commit is contained in:
Debanjum Singh Solanky 2023-01-06 15:58:03 -03:00
parent fa92adcf0d
commit 2fe37a090f
4 changed files with 15 additions and 5 deletions

View file

@ -36,7 +36,7 @@ def initialize_model(search_config: ImageSearchConfig):
encoder = load_model( encoder = load_model(
model_dir = search_config.model_directory, model_dir = search_config.model_directory,
model_name = search_config.encoder, model_name = search_config.encoder,
model_type = SentenceTransformer) model_type = search_config.encoder_type or SentenceTransformer)
return encoder return encoder

View file

@ -37,7 +37,7 @@ def initialize_model(search_config: TextSearchConfig):
bi_encoder = load_model( bi_encoder = load_model(
model_dir = search_config.model_directory, model_dir = search_config.model_directory,
model_name = search_config.encoder, model_name = search_config.encoder,
model_type = SentenceTransformer, model_type = search_config.encoder_type or SentenceTransformer,
device=f'{state.device}') device=f'{state.device}')
# The cross-encoder re-ranks the results to improve quality # The cross-encoder re-ranks the results to improve quality

View file

@ -1,5 +1,6 @@
# Standard Packages # Standard Packages
from pathlib import Path from pathlib import Path
from importlib import import_module
import sys import sys
from os.path import join from os.path import join
from collections import OrderedDict from collections import OrderedDict
@ -44,17 +45,18 @@ def merge_dicts(priority_dict: dict, default_dict: dict):
return merged_dict return merged_dict
def load_model(model_name, model_dir, model_type, device:str=None): def load_model(model_name: str, model_type, model_dir=None, device:str=None):
"Load model from disk or huggingface" "Load model from disk or huggingface"
# Construct model path # Construct model path
model_path = join(model_dir, model_name.replace("/", "_")) if model_dir is not None else None model_path = join(model_dir, model_name.replace("/", "_")) if model_dir is not None else None
# Load model from model_path if it exists there # Load model from model_path if it exists there
model_type_class = get_class_by_name(model_type) if isinstance(model_type, str) else model_type
if model_path is not None and resolve_absolute_path(model_path).exists(): if model_path is not None and resolve_absolute_path(model_path).exists():
model = model_type(get_absolute_path(model_path), device=device) model = model_type_class(get_absolute_path(model_path), device=device)
# Else load the model from the model_name # Else load the model from the model_name
else: else:
model = model_type(model_name, device=device) model = model_type_class(model_name, device=device)
if model_path is not None: if model_path is not None:
model.save(model_path) model.save(model_path)
@ -66,6 +68,12 @@ def is_pyinstaller_app():
return getattr(sys, 'frozen', False) and hasattr(sys, '_MEIPASS') return getattr(sys, 'frozen', False) and hasattr(sys, '_MEIPASS')
def get_class_by_name(name: str) -> object:
"Returns the class object from name string"
module_name, class_name = name.rsplit('.', 1)
return getattr(import_module(module_name), class_name)
class LRU(OrderedDict): class LRU(OrderedDict):
def __init__(self, *args, capacity=128, **kwargs): def __init__(self, *args, capacity=128, **kwargs):
self.capacity = capacity self.capacity = capacity

View file

@ -50,10 +50,12 @@ class ContentConfig(ConfigBase):
class TextSearchConfig(ConfigBase): class TextSearchConfig(ConfigBase):
encoder: str encoder: str
cross_encoder: str cross_encoder: str
encoder_type: Optional[str]
model_directory: Optional[Path] model_directory: Optional[Path]
class ImageSearchConfig(ConfigBase): class ImageSearchConfig(ConfigBase):
encoder: str encoder: str
encoder_type: Optional[str]
model_directory: Optional[Path] model_directory: Optional[Path]
class SearchConfig(ConfigBase): class SearchConfig(ConfigBase):