Allow custom inference endpoint for the crossencoder model (#616)

* Add support for custom inference endpoints for the cross encoder model
- Since there isn't a good out-of-the-box solution, I've deployed a custom model/handler via Hugging Face to support this use case.
* Use langchain.community for pdf, openai chat modules
* Add an explicit stipulation that the api endpoint for crossencoder inference should be for huggingface for now
This commit is contained in:
sabaimran 2024-01-17 20:32:12 -08:00 committed by GitHub
parent 08012c71b1
commit e9e49ea098
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 54 additions and 4 deletions

View file

@@ -156,7 +156,15 @@ def configure_server(
)
}
)
state.cross_encoder_model.update({model.name: CrossEncoderModel(model.cross_encoder)})
state.cross_encoder_model.update(
{
model.name: CrossEncoderModel(
model.cross_encoder,
model.cross_encoder_inference_endpoint,
model.cross_encoder_inference_endpoint_api_key,
)
}
)
state.SearchType = configure_search_types()
state.search_models = configure_search(state.search_models, state.config.search_type)

View file

@@ -0,0 +1,22 @@
# Generated by Django 4.2.7 on 2024-01-17 04:21
from django.db import migrations, models
class Migration(migrations.Migration):
    """Add optional cross-encoder inference endpoint columns to SearchModelConfig.

    Auto-generated by Django (4.2.7); adds a remote endpoint URL and its API
    key so cross-encoder scoring can be delegated to a hosted service instead
    of the locally loaded model.
    """

    dependencies = [
        # Must apply after the migration that added the analogous
        # embeddings_inference_endpoint fields to the same model.
        ("database", "0025_searchmodelconfig_embeddings_inference_endpoint_and_more"),
    ]

    operations = [
        # URL of a remote cross-encoder inference endpoint; nullable so
        # existing rows keep running the cross-encoder locally.
        migrations.AddField(
            model_name="searchmodelconfig",
            name="cross_encoder_inference_endpoint",
            field=models.CharField(blank=True, default=None, max_length=200, null=True),
        ),
        # API key used to authenticate against that endpoint.
        migrations.AddField(
            model_name="searchmodelconfig",
            name="cross_encoder_inference_endpoint_api_key",
            field=models.CharField(blank=True, default=None, max_length=200, null=True),
        ),
    ]

View file

@@ -112,6 +112,8 @@ class SearchModelConfig(BaseModel):
# Name of the sentence-transformers cross-encoder model used for re-ranking.
cross_encoder = models.CharField(max_length=200, default="cross-encoder/ms-marco-MiniLM-L-6-v2")
# Optional remote endpoint + API key for embeddings inference
# (None presumably means "run the embeddings model locally" — see configure_server usage).
embeddings_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
embeddings_inference_endpoint_api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
# Optional remote endpoint + API key for cross-encoder inference; when unset,
# CrossEncoderModel.predict scores with the locally loaded model.
cross_encoder_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
cross_encoder_inference_endpoint_api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
class TextToImageModelConfig(BaseModel):

View file

@@ -4,7 +4,7 @@ import os
from datetime import datetime
from typing import List, Tuple
from langchain.document_loaders import PyMuPDFLoader
from langchain_community.document_loaders import PyMuPDFLoader
from khoj.database.models import Entry as DbEntry
from khoj.database.models import KhojUser

View file

@@ -6,7 +6,7 @@ from typing import Any
import openai
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chat_models import ChatOpenAI
from langchain_community.chat_models import ChatOpenAI
from tenacity import (
before_sleep_log,
retry,

View file

@@ -68,11 +68,29 @@ class EmbeddingsModel:
class CrossEncoderModel:
    """Cross-encoder for re-ranking search hits against a query.

    Scores (query, passage) pairs either via a remote Hugging Face inference
    endpoint (when both an endpoint URL containing "huggingface" and an API
    key are configured) or with a locally loaded sentence-transformers
    CrossEncoder model.
    """

    def __init__(
        self,
        model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
        cross_encoder_inference_endpoint: str = None,
        cross_encoder_inference_endpoint_api_key: str = None,
    ):
        """Load the local cross-encoder and record optional remote endpoint config.

        Args:
            model_name: sentence-transformers cross-encoder model identifier.
            cross_encoder_inference_endpoint: optional URL of a remote
                inference endpoint; only Hugging Face endpoints are honored.
            cross_encoder_inference_endpoint_api_key: bearer token for the
                remote endpoint.

        NOTE(review): the local model is loaded even when a remote endpoint is
        configured — presumably as a fallback; confirm before lazy-loading it.
        """
        self.model_name = model_name
        self.cross_encoder_model = CrossEncoder(model_name=self.model_name, device=get_device())
        self.inference_endpoint = cross_encoder_inference_endpoint
        self.api_key = cross_encoder_inference_endpoint_api_key

    def predict(self, query, hits: List[SearchResponse], key: str = "compiled"):
        """Return relevance scores for each hit's `key` field against `query`.

        Returns a sequence aligned with `hits`: a list of floats when the
        remote endpoint is used, otherwise the array returned by the local
        CrossEncoder (sigmoid-activated scores).
        """
        # Only Hugging Face-hosted endpoints are supported for now, hence the
        # explicit substring check on the endpoint URL.
        if (
            self.api_key is not None
            and self.inference_endpoint is not None
            and "huggingface" in self.inference_endpoint
        ):
            payload = {"inputs": {"query": query, "passages": [hit.additional[key] for hit in hits]}}
            headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
            # Bound the request so a stalled endpoint cannot hang search
            # indefinitely (requests has no default timeout).
            response = requests.post(self.inference_endpoint, json=payload, headers=headers, timeout=30)
            # Fail loudly on HTTP errors instead of masking them later as a
            # confusing KeyError on "scores".
            response.raise_for_status()
            return response.json()["scores"]

        cross_inp = [[query, hit.additional[key]] for hit in hits]
        cross_scores = self.cross_encoder_model.predict(cross_inp, activation_fct=nn.Sigmoid())
        return cross_scores