mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 23:48:56 +01:00
Make type of encoder to use for embeddings configurable via khoj.yml
- Previously `model_type' was set in the setup of each `search_type' - All encoders were of type `SentenceTransformer' - All cross_encoders were of type `CrossEncoder' - Now `encoder-type' can be configured via the new `encoder_type' field in `TextSearchConfig' under `search-type` in `khoj.yml`. - All the specified `encoder-type' class needs is an `encode' method that takes entries and returns embedding vectors
This commit is contained in:
parent
fa92adcf0d
commit
2fe37a090f
4 changed files with 15 additions and 5 deletions
|
@ -36,7 +36,7 @@ def initialize_model(search_config: ImageSearchConfig):
|
|||
encoder = load_model(
|
||||
model_dir = search_config.model_directory,
|
||||
model_name = search_config.encoder,
|
||||
model_type = SentenceTransformer)
|
||||
model_type = search_config.encoder_type or SentenceTransformer)
|
||||
|
||||
return encoder
|
||||
|
||||
|
|
|
@ -37,7 +37,7 @@ def initialize_model(search_config: TextSearchConfig):
|
|||
bi_encoder = load_model(
|
||||
model_dir = search_config.model_directory,
|
||||
model_name = search_config.encoder,
|
||||
model_type = SentenceTransformer,
|
||||
model_type = search_config.encoder_type or SentenceTransformer,
|
||||
device=f'{state.device}')
|
||||
|
||||
# The cross-encoder re-ranks the results to improve quality
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# Standard Packages
|
||||
from pathlib import Path
|
||||
from importlib import import_module
|
||||
import sys
|
||||
from os.path import join
|
||||
from collections import OrderedDict
|
||||
|
@ -44,17 +45,18 @@ def merge_dicts(priority_dict: dict, default_dict: dict):
|
|||
return merged_dict
|
||||
|
||||
|
||||
def load_model(model_name, model_dir, model_type, device:str=None):
|
||||
def load_model(model_name: str, model_type, model_dir=None, device:str=None):
|
||||
"Load model from disk or huggingface"
|
||||
# Construct model path
|
||||
model_path = join(model_dir, model_name.replace("/", "_")) if model_dir is not None else None
|
||||
|
||||
# Load model from model_path if it exists there
|
||||
model_type_class = get_class_by_name(model_type) if isinstance(model_type, str) else model_type
|
||||
if model_path is not None and resolve_absolute_path(model_path).exists():
|
||||
model = model_type(get_absolute_path(model_path), device=device)
|
||||
model = model_type_class(get_absolute_path(model_path), device=device)
|
||||
# Else load the model from the model_name
|
||||
else:
|
||||
model = model_type(model_name, device=device)
|
||||
model = model_type_class(model_name, device=device)
|
||||
if model_path is not None:
|
||||
model.save(model_path)
|
||||
|
||||
|
@ -66,6 +68,12 @@ def is_pyinstaller_app():
|
|||
return getattr(sys, 'frozen', False) and hasattr(sys, '_MEIPASS')
|
||||
|
||||
|
||||
def get_class_by_name(name: str) -> object:
|
||||
"Returns the class object from name string"
|
||||
module_name, class_name = name.rsplit('.', 1)
|
||||
return getattr(import_module(module_name), class_name)
|
||||
|
||||
|
||||
class LRU(OrderedDict):
|
||||
def __init__(self, *args, capacity=128, **kwargs):
|
||||
self.capacity = capacity
|
||||
|
|
|
@ -50,10 +50,12 @@ class ContentConfig(ConfigBase):
|
|||
class TextSearchConfig(ConfigBase):
|
||||
encoder: str
|
||||
cross_encoder: str
|
||||
encoder_type: Optional[str]
|
||||
model_directory: Optional[Path]
|
||||
|
||||
class ImageSearchConfig(ConfigBase):
|
||||
encoder: str
|
||||
encoder_type: Optional[str]
|
||||
model_directory: Optional[Path]
|
||||
|
||||
class SearchConfig(ConfigBase):
|
||||
|
|
Loading…
Reference in a new issue