mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 23:48:56 +01:00
Allow custom inference endpoint for the crossencoder model (#616)
* Add support for custom inference endpoints for the cross-encoder model. Since there isn't a good out-of-the-box solution, a custom model/handler was deployed via Hugging Face to support this use case. * Use langchain_community for the PDF and OpenAI chat modules. * Add an explicit stipulation that, for now, the API endpoint for cross-encoder inference must be a Hugging Face endpoint.
This commit is contained in:
parent
08012c71b1
commit
e9e49ea098
6 changed files with 54 additions and 4 deletions
|
@ -156,7 +156,15 @@ def configure_server(
|
|||
)
|
||||
}
|
||||
)
|
||||
state.cross_encoder_model.update({model.name: CrossEncoderModel(model.cross_encoder)})
|
||||
state.cross_encoder_model.update(
|
||||
{
|
||||
model.name: CrossEncoderModel(
|
||||
model.cross_encoder,
|
||||
model.cross_encoder_inference_endpoint,
|
||||
model.cross_encoder_inference_endpoint_api_key,
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
state.SearchType = configure_search_types()
|
||||
state.search_models = configure_search(state.search_models, state.config.search_type)
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
# Generated by Django 4.2.7 on 2024-01-17 04:21
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
    """Add nullable inference-endpoint and API-key columns for the cross encoder."""

    dependencies = [
        ("database", "0025_searchmodelconfig_embeddings_inference_endpoint_and_more"),
    ]

    # Both new columns have the same shape: optional 200-char strings,
    # so build the two AddField operations from a shared template.
    operations = [
        migrations.AddField(
            model_name="searchmodelconfig",
            name=column,
            field=models.CharField(blank=True, default=None, max_length=200, null=True),
        )
        for column in (
            "cross_encoder_inference_endpoint",
            "cross_encoder_inference_endpoint_api_key",
        )
    ]
|
|
@ -112,6 +112,8 @@ class SearchModelConfig(BaseModel):
|
|||
cross_encoder = models.CharField(max_length=200, default="cross-encoder/ms-marco-MiniLM-L-6-v2")
|
||||
embeddings_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||
embeddings_inference_endpoint_api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||
cross_encoder_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||
cross_encoder_inference_endpoint_api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||
|
||||
|
||||
class TextToImageModelConfig(BaseModel):
|
||||
|
|
|
@ -4,7 +4,7 @@ import os
|
|||
from datetime import datetime
|
||||
from typing import List, Tuple
|
||||
|
||||
from langchain.document_loaders import PyMuPDFLoader
|
||||
from langchain_community.document_loaders import PyMuPDFLoader
|
||||
|
||||
from khoj.database.models import Entry as DbEntry
|
||||
from khoj.database.models import KhojUser
|
||||
|
|
|
@ -6,7 +6,7 @@ from typing import Any
|
|||
import openai
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain_community.chat_models import ChatOpenAI
|
||||
from tenacity import (
|
||||
before_sleep_log,
|
||||
retry,
|
||||
|
|
|
@ -68,11 +68,29 @@ class EmbeddingsModel:
|
|||
|
||||
|
||||
class CrossEncoderModel:
    """Re-rank search hits against a query with a cross-encoder.

    Scores (query, passage) pairs either locally with a sentence-transformers
    ``CrossEncoder`` or remotely via a custom Hugging Face inference endpoint
    when one is configured.
    """

    def __init__(
        self,
        model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
        cross_encoder_inference_endpoint: "str | None" = None,
        cross_encoder_inference_endpoint_api_key: "str | None" = None,
    ):
        # String annotations keep the optional types lazy, so no typing import
        # or 3.10+ syntax is required at runtime.
        self.model_name = model_name
        # NOTE(review): the local model is loaded even when a remote endpoint
        # is configured — it serves as the fallback path in predict().
        self.cross_encoder_model = CrossEncoder(model_name=self.model_name, device=get_device())
        self.inference_endpoint = cross_encoder_inference_endpoint
        self.api_key = cross_encoder_inference_endpoint_api_key

    def predict(self, query, hits: List[SearchResponse], key: str = "compiled"):
        """Return relevance scores for each hit's `additional[key]` text vs `query`.

        Uses the remote endpoint only when both the endpoint and API key are
        set and the endpoint URL contains "huggingface" (only Hugging Face
        endpoints are supported for now); otherwise runs the local model.
        """
        if (
            self.api_key is not None
            and self.inference_endpoint is not None
            and "huggingface" in self.inference_endpoint
        ):
            payload = {"inputs": {"query": query, "passages": [hit.additional[key] for hit in hits]}}
            headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
            # Timeout so a stuck endpoint cannot hang the search request;
            # raise_for_status surfaces HTTP errors directly instead of a
            # confusing KeyError on the missing "scores" field.
            response = requests.post(self.inference_endpoint, json=payload, headers=headers, timeout=30)
            response.raise_for_status()
            return response.json()["scores"]

        cross_inp = [[query, hit.additional[key]] for hit in hits]
        cross_scores = self.cross_encoder_model.predict(cross_inp, activation_fct=nn.Sigmoid())
        return cross_scores
|
||||
|
|
Loading…
Reference in a new issue