Merge pull request #896 from khoj-ai/features/add-support-for-custom-confidence

Add support for custom search model-specific thresholds
This commit is contained in:
sabaimran 2024-08-24 20:32:41 -07:00 committed by GitHub
commit af4e9988c4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 29 additions and 6 deletions

View file

@@ -0,0 +1,17 @@
# Generated by Django 5.0.7 on 2024-08-24 18:19
from django.db import migrations, models
class Migration(migrations.Migration):
    """Add `bi_encoder_confidence_threshold` to SearchModelConfig.

    Introduces a per-search-model distance cutoff (default 0.18) used by the
    search path to decide whether bi-encoder embedding matches are relevant.
    """

    # Must run after the chat-model alteration migration it was generated on top of.
    dependencies = [
        ("database", "0058_alter_chatmodeloptions_chat_model"),
    ]

    operations = [
        migrations.AddField(
            model_name="searchmodelconfig",
            name="bi_encoder_confidence_threshold",
            # 0.18 matches the previously hard-coded default threshold in the API layer.
            field=models.FloatField(default=0.18),
        ),
    ]

View file

@@ -270,6 +270,8 @@ class SearchModelConfig(BaseModel):
cross_encoder_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
# Inference server API Key to use for embeddings inference. Cross-encoder model should be hosted on this server
cross_encoder_inference_endpoint_api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
# The confidence threshold of the bi_encoder model to consider the embeddings as relevant
bi_encoder_confidence_threshold = models.FloatField(default=0.18)
class TextToImageModelConfig(BaseModel):

View file

@@ -83,7 +83,7 @@ async def search(
n=n,
t=t,
r=r,
max_distance=max_distance,
max_distance=max_distance or math.inf,
dedupe=dedupe,
)
@@ -117,7 +117,6 @@ async def execute_search(
# initialize variables
user_query = q.strip()
results_count = n or 5
max_distance = max_distance or math.inf
search_futures: List[concurrent.futures.Future] = []
# return cached results, if available

View file

@@ -524,7 +524,7 @@ async def chat(
common: CommonQueryParams,
q: str,
n: int = 7,
d: float = 0.18,
d: float = None,
stream: Optional[bool] = False,
title: Optional[str] = None,
conversation_id: Optional[int] = None,
@@ -764,7 +764,7 @@ async def chat(
meta_log,
q,
(n or 7),
(d or 0.18),
d,
conversation_id,
conversation_commands,
location,

View file

@@ -100,18 +100,23 @@ async def query(
raw_query: str,
type: SearchType = SearchType.All,
question_embedding: Union[torch.Tensor, None] = None,
max_distance: float = math.inf,
max_distance: float = None,
) -> Tuple[List[dict], List[Entry]]:
"Search for entries that answer the query"
file_type = search_type_to_embeddings_type[type.value]
query = raw_query
search_model = await sync_to_async(get_user_search_model_or_default)(user)
if not max_distance:
if search_model.bi_encoder_confidence_threshold:
max_distance = search_model.bi_encoder_confidence_threshold
else:
max_distance = math.inf
# Encode the query using the bi-encoder
if question_embedding is None:
with timer("Query Encode Time", logger, state.device):
search_model = await sync_to_async(get_user_search_model_or_default)(user)
question_embedding = state.embeddings_model[search_model.name].embed_query(query)
# Find relevant entries for the query