mirror of
https://github.com/khoj-ai/khoj.git
synced 2025-02-17 16:14:21 +00:00
Allow custom inference endpoint for the crossencoder model (#616)
* Add support for custom inference endpoints for the cross-encoder model
  - Since there isn't a good out-of-the-box solution, I've deployed a custom model/handler via Hugging Face to support this use case.
* Use langchain_community for the PDF and OpenAI chat modules
* Add an explicit stipulation that the API endpoint for cross-encoder inference must be a Hugging Face endpoint for now
This commit is contained in:
parent
08012c71b1
commit
e9e49ea098
6 changed files with 54 additions and 4 deletions
|
@ -156,7 +156,15 @@ def configure_server(
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
state.cross_encoder_model.update({model.name: CrossEncoderModel(model.cross_encoder)})
|
state.cross_encoder_model.update(
|
||||||
|
{
|
||||||
|
model.name: CrossEncoderModel(
|
||||||
|
model.cross_encoder,
|
||||||
|
model.cross_encoder_inference_endpoint,
|
||||||
|
model.cross_encoder_inference_endpoint_api_key,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
state.SearchType = configure_search_types()
|
state.SearchType = configure_search_types()
|
||||||
state.search_models = configure_search(state.search_models, state.config.search_type)
|
state.search_models = configure_search(state.search_models, state.config.search_type)
|
||||||
|
|
|
@ -0,0 +1,22 @@
|
||||||
|
# Generated by Django 4.2.7 on 2024-01-17 04:21
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("database", "0025_searchmodelconfig_embeddings_inference_endpoint_and_more"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="searchmodelconfig",
|
||||||
|
name="cross_encoder_inference_endpoint",
|
||||||
|
field=models.CharField(blank=True, default=None, max_length=200, null=True),
|
||||||
|
),
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="searchmodelconfig",
|
||||||
|
name="cross_encoder_inference_endpoint_api_key",
|
||||||
|
field=models.CharField(blank=True, default=None, max_length=200, null=True),
|
||||||
|
),
|
||||||
|
]
|
|
@ -112,6 +112,8 @@ class SearchModelConfig(BaseModel):
|
||||||
cross_encoder = models.CharField(max_length=200, default="cross-encoder/ms-marco-MiniLM-L-6-v2")
|
cross_encoder = models.CharField(max_length=200, default="cross-encoder/ms-marco-MiniLM-L-6-v2")
|
||||||
embeddings_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
|
embeddings_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||||
embeddings_inference_endpoint_api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
|
embeddings_inference_endpoint_api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||||
|
cross_encoder_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||||
|
cross_encoder_inference_endpoint_api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||||
|
|
||||||
|
|
||||||
class TextToImageModelConfig(BaseModel):
|
class TextToImageModelConfig(BaseModel):
|
||||||
|
|
|
@ -4,7 +4,7 @@ import os
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
from langchain.document_loaders import PyMuPDFLoader
|
from langchain_community.document_loaders import PyMuPDFLoader
|
||||||
|
|
||||||
from khoj.database.models import Entry as DbEntry
|
from khoj.database.models import Entry as DbEntry
|
||||||
from khoj.database.models import KhojUser
|
from khoj.database.models import KhojUser
|
||||||
|
|
|
@ -6,7 +6,7 @@ from typing import Any
|
||||||
import openai
|
import openai
|
||||||
from langchain.callbacks.base import BaseCallbackManager
|
from langchain.callbacks.base import BaseCallbackManager
|
||||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||||
from langchain.chat_models import ChatOpenAI
|
from langchain_community.chat_models import ChatOpenAI
|
||||||
from tenacity import (
|
from tenacity import (
|
||||||
before_sleep_log,
|
before_sleep_log,
|
||||||
retry,
|
retry,
|
||||||
|
|
|
@ -68,11 +68,29 @@ class EmbeddingsModel:
|
||||||
|
|
||||||
|
|
||||||
class CrossEncoderModel:
|
class CrossEncoderModel:
|
||||||
def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
|
||||||
|
cross_encoder_inference_endpoint: str = None,
|
||||||
|
cross_encoder_inference_endpoint_api_key: str = None,
|
||||||
|
):
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.cross_encoder_model = CrossEncoder(model_name=self.model_name, device=get_device())
|
self.cross_encoder_model = CrossEncoder(model_name=self.model_name, device=get_device())
|
||||||
|
self.inference_endpoint = cross_encoder_inference_endpoint
|
||||||
|
self.api_key = cross_encoder_inference_endpoint_api_key
|
||||||
|
|
||||||
def predict(self, query, hits: List[SearchResponse], key: str = "compiled"):
|
def predict(self, query, hits: List[SearchResponse], key: str = "compiled"):
|
||||||
|
if (
|
||||||
|
self.api_key is not None
|
||||||
|
and self.inference_endpoint is not None
|
||||||
|
and "huggingface" in self.inference_endpoint
|
||||||
|
):
|
||||||
|
target_url = f"{self.inference_endpoint}"
|
||||||
|
payload = {"inputs": {"query": query, "passages": [hit.additional[key] for hit in hits]}}
|
||||||
|
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
|
||||||
|
response = requests.post(target_url, json=payload, headers=headers)
|
||||||
|
return response.json()["scores"]
|
||||||
|
|
||||||
cross_inp = [[query, hit.additional[key]] for hit in hits]
|
cross_inp = [[query, hit.additional[key]] for hit in hits]
|
||||||
cross_scores = self.cross_encoder_model.predict(cross_inp, activation_fct=nn.Sigmoid())
|
cross_scores = self.cross_encoder_model.predict(cross_inp, activation_fct=nn.Sigmoid())
|
||||||
return cross_scores
|
return cross_scores
|
||||||
|
|
Loading…
Add table
Reference in a new issue