mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-24 07:55:07 +01:00
Make type of encoder to use for embeddings configurable via khoj.yml
- Previously `model_type' was set in the setup of each `search_type' - All encoders were of type `SentenceTransformer' - All cross_encoders were of type `CrossEncoder' - Now `encoder-type' can be configured via the new `encoder_type' field in `TextSearchConfig' under `search-type` in `khoj.yml`. - All the specified `encoder-type' class needs is an `encode' method that takes entries and returns embedding vectors
This commit is contained in:
parent
fa92adcf0d
commit
2fe37a090f
4 changed files with 15 additions and 5 deletions
|
@ -36,7 +36,7 @@ def initialize_model(search_config: ImageSearchConfig):
|
||||||
encoder = load_model(
|
encoder = load_model(
|
||||||
model_dir = search_config.model_directory,
|
model_dir = search_config.model_directory,
|
||||||
model_name = search_config.encoder,
|
model_name = search_config.encoder,
|
||||||
model_type = SentenceTransformer)
|
model_type = search_config.encoder_type or SentenceTransformer)
|
||||||
|
|
||||||
return encoder
|
return encoder
|
||||||
|
|
||||||
|
|
|
@ -37,7 +37,7 @@ def initialize_model(search_config: TextSearchConfig):
|
||||||
bi_encoder = load_model(
|
bi_encoder = load_model(
|
||||||
model_dir = search_config.model_directory,
|
model_dir = search_config.model_directory,
|
||||||
model_name = search_config.encoder,
|
model_name = search_config.encoder,
|
||||||
model_type = SentenceTransformer,
|
model_type = search_config.encoder_type or SentenceTransformer,
|
||||||
device=f'{state.device}')
|
device=f'{state.device}')
|
||||||
|
|
||||||
# The cross-encoder re-ranks the results to improve quality
|
# The cross-encoder re-ranks the results to improve quality
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
# Standard Packages
|
# Standard Packages
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from importlib import import_module
|
||||||
import sys
|
import sys
|
||||||
from os.path import join
|
from os.path import join
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
@ -44,17 +45,18 @@ def merge_dicts(priority_dict: dict, default_dict: dict):
|
||||||
return merged_dict
|
return merged_dict
|
||||||
|
|
||||||
|
|
||||||
def load_model(model_name, model_dir, model_type, device:str=None):
|
def load_model(model_name: str, model_type, model_dir=None, device:str=None):
|
||||||
"Load model from disk or huggingface"
|
"Load model from disk or huggingface"
|
||||||
# Construct model path
|
# Construct model path
|
||||||
model_path = join(model_dir, model_name.replace("/", "_")) if model_dir is not None else None
|
model_path = join(model_dir, model_name.replace("/", "_")) if model_dir is not None else None
|
||||||
|
|
||||||
# Load model from model_path if it exists there
|
# Load model from model_path if it exists there
|
||||||
|
model_type_class = get_class_by_name(model_type) if isinstance(model_type, str) else model_type
|
||||||
if model_path is not None and resolve_absolute_path(model_path).exists():
|
if model_path is not None and resolve_absolute_path(model_path).exists():
|
||||||
model = model_type(get_absolute_path(model_path), device=device)
|
model = model_type_class(get_absolute_path(model_path), device=device)
|
||||||
# Else load the model from the model_name
|
# Else load the model from the model_name
|
||||||
else:
|
else:
|
||||||
model = model_type(model_name, device=device)
|
model = model_type_class(model_name, device=device)
|
||||||
if model_path is not None:
|
if model_path is not None:
|
||||||
model.save(model_path)
|
model.save(model_path)
|
||||||
|
|
||||||
|
@ -66,6 +68,12 @@ def is_pyinstaller_app():
|
||||||
return getattr(sys, 'frozen', False) and hasattr(sys, '_MEIPASS')
|
return getattr(sys, 'frozen', False) and hasattr(sys, '_MEIPASS')
|
||||||
|
|
||||||
|
|
||||||
|
def get_class_by_name(name: str) -> object:
|
||||||
|
"Returns the class object from name string"
|
||||||
|
module_name, class_name = name.rsplit('.', 1)
|
||||||
|
return getattr(import_module(module_name), class_name)
|
||||||
|
|
||||||
|
|
||||||
class LRU(OrderedDict):
|
class LRU(OrderedDict):
|
||||||
def __init__(self, *args, capacity=128, **kwargs):
|
def __init__(self, *args, capacity=128, **kwargs):
|
||||||
self.capacity = capacity
|
self.capacity = capacity
|
||||||
|
|
|
@ -50,10 +50,12 @@ class ContentConfig(ConfigBase):
|
||||||
class TextSearchConfig(ConfigBase):
|
class TextSearchConfig(ConfigBase):
|
||||||
encoder: str
|
encoder: str
|
||||||
cross_encoder: str
|
cross_encoder: str
|
||||||
|
encoder_type: Optional[str]
|
||||||
model_directory: Optional[Path]
|
model_directory: Optional[Path]
|
||||||
|
|
||||||
class ImageSearchConfig(ConfigBase):
|
class ImageSearchConfig(ConfigBase):
|
||||||
encoder: str
|
encoder: str
|
||||||
|
encoder_type: Optional[str]
|
||||||
model_directory: Optional[Path]
|
model_directory: Optional[Path]
|
||||||
|
|
||||||
class SearchConfig(ConfigBase):
|
class SearchConfig(ConfigBase):
|
||||||
|
|
Loading…
Reference in a new issue