mirror of
https://github.com/khoj-ai/khoj.git
synced 2025-02-17 08:04:21 +00:00
Merge pull request #896 from khoj-ai/features/add-support-for-custom-confidence
Add support for custom search model-specific thresholds
This commit is contained in:
commit
af4e9988c4
5 changed files with 29 additions and 6 deletions
|
@ -0,0 +1,17 @@
|
|||
# Generated by Django 5.0.7 on 2024-08-24 18:19
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
    """Add a per-model confidence threshold to SearchModelConfig.

    Introduces ``bi_encoder_confidence_threshold`` (float, default 0.18) so
    each search model can carry its own bi-encoder relevance cutoff instead
    of a hard-coded value.
    """

    dependencies = [("database", "0058_alter_chatmodeloptions_chat_model")]

    operations = [
        migrations.AddField(
            # Target the existing search model configuration table.
            model_name="searchmodelconfig",
            name="bi_encoder_confidence_threshold",
            # 0.18 matches the previously hard-coded distance threshold.
            field=models.FloatField(default=0.18),
        )
    ]
|
|
@ -270,6 +270,8 @@ class SearchModelConfig(BaseModel):
|
|||
cross_encoder_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||
# Inference server API Key to use for embeddings inference. Cross-encoder model should be hosted on this server
|
||||
cross_encoder_inference_endpoint_api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||
# The confidence threshold of the bi_encoder model to consider the embeddings as relevant
|
||||
bi_encoder_confidence_threshold = models.FloatField(default=0.18)
|
||||
|
||||
|
||||
class TextToImageModelConfig(BaseModel):
|
||||
|
|
|
@ -83,7 +83,7 @@ async def search(
|
|||
n=n,
|
||||
t=t,
|
||||
r=r,
|
||||
max_distance=max_distance,
|
||||
max_distance=max_distance or math.inf,
|
||||
dedupe=dedupe,
|
||||
)
|
||||
|
||||
|
@ -117,7 +117,6 @@ async def execute_search(
|
|||
# initialize variables
|
||||
user_query = q.strip()
|
||||
results_count = n or 5
|
||||
max_distance = max_distance or math.inf
|
||||
search_futures: List[concurrent.futures.Future] = []
|
||||
|
||||
# return cached results, if available
|
||||
|
|
|
@ -524,7 +524,7 @@ async def chat(
|
|||
common: CommonQueryParams,
|
||||
q: str,
|
||||
n: int = 7,
|
||||
d: float = 0.18,
|
||||
d: float = None,
|
||||
stream: Optional[bool] = False,
|
||||
title: Optional[str] = None,
|
||||
conversation_id: Optional[int] = None,
|
||||
|
@ -764,7 +764,7 @@ async def chat(
|
|||
meta_log,
|
||||
q,
|
||||
(n or 7),
|
||||
(d or 0.18),
|
||||
d,
|
||||
conversation_id,
|
||||
conversation_commands,
|
||||
location,
|
||||
|
|
|
@ -100,18 +100,23 @@ async def query(
|
|||
raw_query: str,
|
||||
type: SearchType = SearchType.All,
|
||||
question_embedding: Union[torch.Tensor, None] = None,
|
||||
max_distance: float = math.inf,
|
||||
max_distance: float = None,
|
||||
) -> Tuple[List[dict], List[Entry]]:
|
||||
"Search for entries that answer the query"
|
||||
|
||||
file_type = search_type_to_embeddings_type[type.value]
|
||||
|
||||
query = raw_query
|
||||
search_model = await sync_to_async(get_user_search_model_or_default)(user)
|
||||
if not max_distance:
|
||||
if search_model.bi_encoder_confidence_threshold:
|
||||
max_distance = search_model.bi_encoder_confidence_threshold
|
||||
else:
|
||||
max_distance = math.inf
|
||||
|
||||
# Encode the query using the bi-encoder
|
||||
if question_embedding is None:
|
||||
with timer("Query Encode Time", logger, state.device):
|
||||
search_model = await sync_to_async(get_user_search_model_or_default)(user)
|
||||
question_embedding = state.embeddings_model[search_model.name].embed_query(query)
|
||||
|
||||
# Find relevant entries for the query
|
||||
|
|
Loading…
Add table
Reference in a new issue