Add model_config for cross-encoder model (#885) from aam-at/feature/crossencoder_model_config

Add `model_config` for the cross-encoder model, so the server admin can use models which require the `trust_remote_code` argument to run locally
This commit is contained in:
Debanjum 2024-08-16 07:32:19 -07:00 committed by GitHub
commit 0b568e204e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 23 additions and 1 deletions

View file

@@ -263,6 +263,7 @@ def configure_server(
model.cross_encoder,
model.cross_encoder_inference_endpoint,
model.cross_encoder_inference_endpoint_api_key,
model_kwargs=model.cross_encoder_model_config,
)
}
)

View file

@@ -0,0 +1,17 @@
# Generated by Django 5.0.7 on 2024-08-07 09:12
from django.db import migrations, models
class Migration(migrations.Migration):
    """Add `cross_encoder_model_config` to SearchModelConfig.

    The new JSON field stores constructor kwargs for the cross-encoder
    model (e.g. trust_remote_code=True) so admins can configure models
    that need extra loading arguments.
    """

    # Must apply after the prior schema migration in this app.
    dependencies = [("database", "0055_alter_agent_style_icon")]

    operations = [
        migrations.AddField(
            model_name="searchmodelconfig",
            name="cross_encoder_model_config",
            # Empty dict by default; optional in admin forms.
            field=models.JSONField(default=dict, blank=True),
        ),
    ]

View file

@@ -259,6 +259,8 @@ class SearchModelConfig(BaseModel):
bi_encoder_docs_encode_config = models.JSONField(default=dict, blank=True)
# Cross-encoder model of sentence-transformer type to load from HuggingFace
cross_encoder = models.CharField(max_length=200, default="mixedbread-ai/mxbai-rerank-xsmall-v1")
# Config passed to the cross-encoder model constructor. E.g. device="cuda:0", trust_remote_code=True etc.
cross_encoder_model_config = models.JSONField(default=dict, blank=True)
# Inference server API endpoint to use for embeddings inference. Bi-encoder model should be hosted on this server
embeddings_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
# Inference server API Key to use for embeddings inference. Bi-encoder model should be hosted on this server

View file

@@ -95,11 +95,13 @@ class CrossEncoderModel:
model_name: str = "mixedbread-ai/mxbai-rerank-xsmall-v1",
cross_encoder_inference_endpoint: str = None,
cross_encoder_inference_endpoint_api_key: str = None,
model_kwargs: dict = {},
):
self.model_name = model_name
self.cross_encoder_model = CrossEncoder(model_name=self.model_name, device=get_device())
self.inference_endpoint = cross_encoder_inference_endpoint
self.api_key = cross_encoder_inference_endpoint_api_key
self.model_kwargs = merge_dicts(model_kwargs, {"device": get_device()})
self.cross_encoder_model = CrossEncoder(model_name=self.model_name, **self.model_kwargs)
def inference_server_enabled(self) -> bool:
    """Return True when a remote inference endpoint is fully configured."""
    # Remote reranking requires both the endpoint URL and its API key.
    has_endpoint = self.inference_endpoint is not None
    has_key = self.api_key is not None
    return has_key and has_endpoint