mirror of
https://github.com/khoj-ai/khoj.git
synced 2025-02-17 08:04:21 +00:00
Add model_config for cross-encoder model (#885) from aam-at/feature/crossencoder_model_config
Add `model_config` for the cross-encoder model, so the server admin can use models that require the `trust_remote_code` argument to run locally.
This commit is contained in:
commit
0b568e204e
4 changed files with 23 additions and 1 deletions
|
@ -263,6 +263,7 @@ def configure_server(
|
|||
model.cross_encoder,
|
||||
model.cross_encoder_inference_endpoint,
|
||||
model.cross_encoder_inference_endpoint_api_key,
|
||||
model_kwargs=model.cross_encoder_model_config,
|
||||
)
|
||||
}
|
||||
)
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
# Generated by Django 5.0.7 on 2024-08-07 09:12
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
# Schema migration that adds a single field to an existing model.
|
||||
# Must be applied after migration 0055_alter_agent_style_icon of the "database" app.
dependencies = [
|
||||
("database", "0055_alter_agent_style_icon"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
# Add the JSON field cross_encoder_model_config to the searchmodelconfig model.
# The field may be left blank and uses the `dict` callable as its default.
migrations.AddField(
|
||||
model_name="searchmodelconfig",
|
||||
name="cross_encoder_model_config",
|
||||
field=models.JSONField(blank=True, default=dict),
|
||||
),
|
||||
]
|
|
@ -259,6 +259,8 @@ class SearchModelConfig(BaseModel):
|
|||
bi_encoder_docs_encode_config = models.JSONField(default=dict, blank=True)
|
||||
# Cross-encoder model of sentence-transformer type to load from HuggingFace
|
||||
cross_encoder = models.CharField(max_length=200, default="mixedbread-ai/mxbai-rerank-xsmall-v1")
|
||||
# Config passed to the cross-encoder model constructor. E.g. device="cuda:0", trust_remote_code=True etc.
|
||||
cross_encoder_model_config = models.JSONField(default=dict, blank=True)
|
||||
# Inference server API endpoint to use for embeddings inference. Bi-encoder model should be hosted on this server
|
||||
embeddings_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||
# Inference server API Key to use for embeddings inference. Bi-encoder model should be hosted on this server
|
||||
|
|
|
@ -95,11 +95,13 @@ class CrossEncoderModel:
|
|||
model_name: str = "mixedbread-ai/mxbai-rerank-xsmall-v1",
|
||||
cross_encoder_inference_endpoint: str = None,
|
||||
cross_encoder_inference_endpoint_api_key: str = None,
|
||||
model_kwargs: dict = {},
|
||||
):
|
||||
self.model_name = model_name
|
||||
self.cross_encoder_model = CrossEncoder(model_name=self.model_name, device=get_device())
|
||||
self.inference_endpoint = cross_encoder_inference_endpoint
|
||||
self.api_key = cross_encoder_inference_endpoint_api_key
|
||||
self.model_kwargs = merge_dicts(model_kwargs, {"device": get_device()})
|
||||
self.cross_encoder_model = CrossEncoder(model_name=self.model_name, **self.model_kwargs)
|
||||
|
||||
def inference_server_enabled(self) -> bool:
# True only when both the API key and the inference endpoint are set (neither is None).
|
||||
return self.api_key is not None and self.inference_endpoint is not None
|
||||
|
|
Loading…
Add table
Reference in a new issue