mirror of
https://github.com/khoj-ai/khoj.git
synced 2025-02-17 08:04:21 +00:00
Only get text search results above confidence threshold via API
- During the migration, the confidence score stopped being used. It was being passed down from the API to some point, after which it went unused.
- Remove score thresholding for images, as the image search confidence score is different from the text search model's distance score.
- The default score threshold of 0.15 was determined experimentally by manually comparing search results against distance for a few queries.
- Use distance instead of confidence as the metric for search result quality. Previously we'd moved text search from a confidence score to a distance metric. Now also convert the cross-encoder and image search scores to a distance metric, for consistent sorting of results.
This commit is contained in:
parent
e44e6df221
commit
941c7f23a3
4 changed files with 37 additions and 23 deletions
|
@ -1,3 +1,4 @@
|
|||
import math
|
||||
from typing import Optional, Type, TypeVar, List
|
||||
from datetime import date, datetime, timedelta
|
||||
import secrets
|
||||
|
@ -437,12 +438,19 @@ class EntryAdapters:
|
|||
|
||||
@staticmethod
|
||||
def search_with_embeddings(
|
||||
user: KhojUser, embeddings: Tensor, max_results: int = 10, file_type_filter: str = None, raw_query: str = None
|
||||
user: KhojUser,
|
||||
embeddings: Tensor,
|
||||
max_results: int = 10,
|
||||
file_type_filter: str = None,
|
||||
raw_query: str = None,
|
||||
max_distance: float = math.inf,
|
||||
):
|
||||
relevant_entries = EntryAdapters.apply_filters(user, raw_query, file_type_filter)
|
||||
relevant_entries = relevant_entries.filter(user=user).annotate(
|
||||
distance=CosineDistance("embeddings", embeddings)
|
||||
)
|
||||
relevant_entries = relevant_entries.filter(distance__lte=max_distance)
|
||||
|
||||
if file_type_filter:
|
||||
relevant_entries = relevant_entries.filter(file_type=file_type_filter)
|
||||
relevant_entries = relevant_entries.order_by("distance")
|
||||
|
|
|
@ -356,7 +356,7 @@ async def search(
|
|||
n: Optional[int] = 5,
|
||||
t: Optional[SearchType] = SearchType.All,
|
||||
r: Optional[bool] = False,
|
||||
score_threshold: Optional[Union[float, None]] = None,
|
||||
max_distance: Optional[Union[float, None]] = None,
|
||||
dedupe: Optional[bool] = True,
|
||||
client: Optional[str] = None,
|
||||
user_agent: Optional[str] = Header(None),
|
||||
|
@ -375,12 +375,12 @@ async def search(
|
|||
# initialize variables
|
||||
user_query = q.strip()
|
||||
results_count = n or 5
|
||||
score_threshold = score_threshold if score_threshold is not None else -math.inf
|
||||
max_distance = max_distance if max_distance is not None else math.inf
|
||||
search_futures: List[concurrent.futures.Future] = []
|
||||
|
||||
# return cached results, if available
|
||||
if user:
|
||||
query_cache_key = f"{user_query}-{n}-{t}-{r}-{score_threshold}-{dedupe}"
|
||||
query_cache_key = f"{user_query}-{n}-{t}-{r}-{max_distance}-{dedupe}"
|
||||
if query_cache_key in state.query_cache[user.uuid]:
|
||||
logger.debug(f"Return response from query cache")
|
||||
return state.query_cache[user.uuid][query_cache_key]
|
||||
|
@ -418,7 +418,7 @@ async def search(
|
|||
t,
|
||||
question_embedding=encoded_asymmetric_query,
|
||||
rank_results=r or False,
|
||||
score_threshold=score_threshold,
|
||||
max_distance=max_distance,
|
||||
)
|
||||
]
|
||||
|
||||
|
@ -431,7 +431,6 @@ async def search(
|
|||
results_count,
|
||||
state.search_models.image_search,
|
||||
state.content_index.image,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
]
|
||||
|
||||
|
@ -454,11 +453,10 @@ async def search(
|
|||
# Collate results
|
||||
results += text_search.collate_results(hits, dedupe=dedupe)
|
||||
|
||||
if r:
|
||||
results = text_search.rerank_and_sort_results(results, query=defiltered_query)[:results_count]
|
||||
else:
|
||||
# Sort results across all content types and take top results
|
||||
results = sorted(results, key=lambda x: float(x.score))[:results_count]
|
||||
results = text_search.rerank_and_sort_results(results, query=defiltered_query, rank_results=r)[
|
||||
:results_count
|
||||
]
|
||||
|
||||
# Cache results
|
||||
if user:
|
||||
|
@ -583,6 +581,7 @@ async def chat(
|
|||
request: Request,
|
||||
q: str,
|
||||
n: Optional[int] = 5,
|
||||
d: Optional[float] = 0.15,
|
||||
client: Optional[str] = None,
|
||||
stream: Optional[bool] = False,
|
||||
user_agent: Optional[str] = Header(None),
|
||||
|
@ -599,7 +598,7 @@ async def chat(
|
|||
meta_log = (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log
|
||||
|
||||
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
|
||||
request, meta_log, q, (n or 5), conversation_command
|
||||
request, meta_log, q, (n or 5), (d or math.inf), conversation_command
|
||||
)
|
||||
|
||||
if conversation_command == ConversationCommand.Default and is_none_or_empty(compiled_references):
|
||||
|
@ -663,6 +662,7 @@ async def extract_references_and_questions(
|
|||
meta_log: dict,
|
||||
q: str,
|
||||
n: int,
|
||||
d: float,
|
||||
conversation_type: ConversationCommand = ConversationCommand.Default,
|
||||
):
|
||||
user = request.user.object if request.user.is_authenticated else None
|
||||
|
@ -723,7 +723,7 @@ async def extract_references_and_questions(
|
|||
request=request,
|
||||
n=n_items,
|
||||
r=True,
|
||||
score_threshold=-5.0,
|
||||
max_distance=d,
|
||||
dedupe=False,
|
||||
)
|
||||
)
|
||||
|
|
|
@ -146,7 +146,7 @@ def extract_metadata(image_name):
|
|||
|
||||
|
||||
async def query(
|
||||
raw_query, count, search_model: ImageSearchModel, content: ImageContent, score_threshold: float = -math.inf
|
||||
raw_query, count, search_model: ImageSearchModel, content: ImageContent, score_threshold: float = math.inf
|
||||
):
|
||||
# Set query to image content if query is of form file:/path/to/file.png
|
||||
if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file():
|
||||
|
@ -167,7 +167,8 @@ async def query(
|
|||
# Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings.
|
||||
with timer("Search Time", logger):
|
||||
image_hits = {
|
||||
result["corpus_id"]: {"image_score": result["score"], "score": result["score"]}
|
||||
# Map scores to distance metric by multiplying by -1
|
||||
result["corpus_id"]: {"image_score": -1 * result["score"], "score": -1 * result["score"]}
|
||||
for result in util.semantic_search(query_embedding, content.image_embeddings, top_k=count)[0]
|
||||
}
|
||||
|
||||
|
@ -204,7 +205,7 @@ async def query(
|
|||
]
|
||||
|
||||
# Filter results by score threshold
|
||||
hits = [hit for hit in hits if hit["image_score"] >= score_threshold]
|
||||
hits = [hit for hit in hits if hit["image_score"] <= score_threshold]
|
||||
|
||||
# Sort the images based on their combined metadata, image scores
|
||||
return sorted(hits, key=lambda hit: hit["score"], reverse=True)
|
||||
|
|
|
@ -105,7 +105,7 @@ async def query(
|
|||
type: SearchType = SearchType.All,
|
||||
question_embedding: Union[torch.Tensor, None] = None,
|
||||
rank_results: bool = False,
|
||||
score_threshold: float = -math.inf,
|
||||
max_distance: float = math.inf,
|
||||
) -> Tuple[List[dict], List[Entry]]:
|
||||
"Search for entries that answer the query"
|
||||
|
||||
|
@ -127,6 +127,7 @@ async def query(
|
|||
max_results=top_k,
|
||||
file_type_filter=file_type,
|
||||
raw_query=raw_query,
|
||||
max_distance=max_distance,
|
||||
).all()
|
||||
hits = await sync_to_async(list)(hits) # type: ignore[call-arg]
|
||||
|
||||
|
@ -177,12 +178,16 @@ def deduplicated_search_responses(hits: List[SearchResponse]):
|
|||
)
|
||||
|
||||
|
||||
def rerank_and_sort_results(hits, query):
|
||||
def rerank_and_sort_results(hits, query, rank_results):
|
||||
# If we have more than one result and reranking is enabled
|
||||
rank_results = rank_results and len(list(hits)) > 1
|
||||
|
||||
# Score all retrieved entries using the cross-encoder
|
||||
hits = cross_encoder_score(query, hits)
|
||||
if rank_results:
|
||||
hits = cross_encoder_score(query, hits)
|
||||
|
||||
# Sort results by cross-encoder score followed by bi-encoder score
|
||||
hits = sort_results(rank_results=True, hits=hits)
|
||||
hits = sort_results(rank_results=rank_results, hits=hits)
|
||||
|
||||
return hits
|
||||
|
||||
|
@ -217,9 +222,9 @@ def cross_encoder_score(query: str, hits: List[SearchResponse]) -> List[SearchRe
|
|||
with timer("Cross-Encoder Predict Time", logger, state.device):
|
||||
cross_scores = state.cross_encoder_model.predict(query, hits)
|
||||
|
||||
# Store cross-encoder scores in results dictionary for ranking
|
||||
# Convert cross-encoder scores to distances and pass in hits for reranking
|
||||
for idx in range(len(cross_scores)):
|
||||
hits[idx]["cross_score"] = cross_scores[idx]
|
||||
hits[idx]["cross_score"] = -1 * cross_scores[idx]
|
||||
|
||||
return hits
|
||||
|
||||
|
@ -227,7 +232,7 @@ def cross_encoder_score(query: str, hits: List[SearchResponse]) -> List[SearchRe
|
|||
def sort_results(rank_results: bool, hits: List[dict]) -> List[dict]:
|
||||
"""Order results by cross-encoder score followed by bi-encoder score"""
|
||||
with timer("Rank Time", logger, state.device):
|
||||
hits.sort(key=lambda x: x["score"], reverse=True) # sort by bi-encoder score
|
||||
hits.sort(key=lambda x: x["score"]) # sort by bi-encoder score
|
||||
if rank_results:
|
||||
hits.sort(key=lambda x: x["cross_score"], reverse=True) # sort by cross-encoder score
|
||||
hits.sort(key=lambda x: x["cross_score"]) # sort by cross-encoder score
|
||||
return hits
|
||||
|
|
Loading…
Add table
Reference in a new issue