Only get text search results above confidence threshold via API

- During the migration, the confidence score stopped being used. It
  was passed down from the API to a point in the call chain and then
  went unused

- Remove score thresholding for images, as the image search confidence
  score differs from the text search model's distance score

- Default distance threshold of 0.15 was determined experimentally by
  manually inspecting search results vs. distance for a few queries

- Use distance instead of confidence as the metric for search result
  quality. We had previously moved text search from a confidence score
  to a distance metric.

  Now cross-encoder and image search scores are also converted to a
  distance metric, so results sort consistently across content types
  (see the sketch after these notes)
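
The score-to-distance conversion and thresholding described above can be sketched roughly as follows. This is a minimal illustration, not the actual Khoj code: the helper names (to_distance, filter_and_sort) and the sample hits are made up for the example, and the 0.15 default mirrors the experimentally chosen threshold.

    from typing import Dict, List

    def to_distance(score: float, metric: str = "cosine") -> float:
        """Map a similarity/confidence score to a distance, where lower is better."""
        if metric == "cosine":
            # Cosine distance is 1 - cosine similarity
            return 1.0 - score
        # For unbounded scores (cross-encoder, image search) negating preserves the ordering
        return -1.0 * score

    def filter_and_sort(hits: List[Dict], max_distance: float = 0.15) -> List[Dict]:
        """Drop hits beyond max_distance, then sort best (smallest distance) first."""
        kept = [hit for hit in hits if hit["distance"] <= max_distance]
        return sorted(kept, key=lambda hit: hit["distance"])

    hits = [
        {"entry": "note-a", "distance": to_distance(0.92)},  # similarity 0.92 -> distance 0.08, kept
        {"entry": "note-b", "distance": to_distance(0.80)},  # similarity 0.80 -> distance 0.20, dropped
    ]
    print(filter_and_sort(hits, max_distance=0.15))

Because every content type now reports a distance, a single ascending sort is enough to rank mixed results, which is what the rerank_and_sort_results change in the diff below relies on.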
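On the API side, a client could exercise the new parameters like this. The sketch assumes the routers shown in this diff are mounted under /api and that a local Khoj server is listening on its default port; the base URL, query strings, and variable names are illustrative only, while the max_distance and d parameter names come from the diff.

    import requests

    KHOJ_URL = "http://localhost:42110"  # assumed local Khoj server address

    # Search: only return entries whose distance from the query is <= 0.15
    search = requests.get(
        f"{KHOJ_URL}/api/search",
        params={"q": "notes on vector databases", "n": 5, "max_distance": 0.15},
    )
    print(search.json())

    # Chat: `d` caps the distance of notes pulled in as references (defaults to 0.15)
    chat = requests.get(
        f"{KHOJ_URL}/api/chat",
        params={"q": "What did I write about pgvector?", "d": 0.1},
    )
    print(chat.json())

Omitting max_distance on the search endpoint falls back to math.inf, which effectively disables the filter.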
Debanjum Singh Solanky 2023-11-11 03:30:35 -08:00
parent e44e6df221
commit 941c7f23a3
4 changed files with 37 additions and 23 deletions


@@ -1,3 +1,4 @@
import math
from typing import Optional, Type, TypeVar, List
from datetime import date, datetime, timedelta
import secrets
@@ -437,12 +438,19 @@ class EntryAdapters:
@staticmethod
def search_with_embeddings(
user: KhojUser, embeddings: Tensor, max_results: int = 10, file_type_filter: str = None, raw_query: str = None
user: KhojUser,
embeddings: Tensor,
max_results: int = 10,
file_type_filter: str = None,
raw_query: str = None,
max_distance: float = math.inf,
):
relevant_entries = EntryAdapters.apply_filters(user, raw_query, file_type_filter)
relevant_entries = relevant_entries.filter(user=user).annotate(
distance=CosineDistance("embeddings", embeddings)
)
relevant_entries = relevant_entries.filter(distance__lte=max_distance)
if file_type_filter:
relevant_entries = relevant_entries.filter(file_type=file_type_filter)
relevant_entries = relevant_entries.order_by("distance")


@@ -356,7 +356,7 @@ async def search(
n: Optional[int] = 5,
t: Optional[SearchType] = SearchType.All,
r: Optional[bool] = False,
score_threshold: Optional[Union[float, None]] = None,
max_distance: Optional[Union[float, None]] = None,
dedupe: Optional[bool] = True,
client: Optional[str] = None,
user_agent: Optional[str] = Header(None),
@@ -375,12 +375,12 @@ async def search(
# initialize variables
user_query = q.strip()
results_count = n or 5
score_threshold = score_threshold if score_threshold is not None else -math.inf
max_distance = max_distance if max_distance is not None else math.inf
search_futures: List[concurrent.futures.Future] = []
# return cached results, if available
if user:
query_cache_key = f"{user_query}-{n}-{t}-{r}-{score_threshold}-{dedupe}"
query_cache_key = f"{user_query}-{n}-{t}-{r}-{max_distance}-{dedupe}"
if query_cache_key in state.query_cache[user.uuid]:
logger.debug(f"Return response from query cache")
return state.query_cache[user.uuid][query_cache_key]
@@ -418,7 +418,7 @@ async def search(
t,
question_embedding=encoded_asymmetric_query,
rank_results=r or False,
score_threshold=score_threshold,
max_distance=max_distance,
)
]
@@ -431,7 +431,6 @@ async def search(
results_count,
state.search_models.image_search,
state.content_index.image,
score_threshold=score_threshold,
)
]
@@ -454,11 +453,10 @@ async def search(
# Collate results
results += text_search.collate_results(hits, dedupe=dedupe)
if r:
results = text_search.rerank_and_sort_results(results, query=defiltered_query)[:results_count]
else:
# Sort results across all content types and take top results
results = sorted(results, key=lambda x: float(x.score))[:results_count]
results = text_search.rerank_and_sort_results(results, query=defiltered_query, rank_results=r)[
:results_count
]
# Cache results
if user:
@@ -583,6 +581,7 @@ async def chat(
request: Request,
q: str,
n: Optional[int] = 5,
d: Optional[float] = 0.15,
client: Optional[str] = None,
stream: Optional[bool] = False,
user_agent: Optional[str] = Header(None),
@@ -599,7 +598,7 @@ async def chat(
meta_log = (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
request, meta_log, q, (n or 5), conversation_command
request, meta_log, q, (n or 5), (d or math.inf), conversation_command
)
if conversation_command == ConversationCommand.Default and is_none_or_empty(compiled_references):
@@ -663,6 +662,7 @@ async def extract_references_and_questions(
meta_log: dict,
q: str,
n: int,
d: float,
conversation_type: ConversationCommand = ConversationCommand.Default,
):
user = request.user.object if request.user.is_authenticated else None
@@ -723,7 +723,7 @@ async def extract_references_and_questions(
request=request,
n=n_items,
r=True,
score_threshold=-5.0,
max_distance=d,
dedupe=False,
)
)


@@ -146,7 +146,7 @@ def extract_metadata(image_name):
async def query(
raw_query, count, search_model: ImageSearchModel, content: ImageContent, score_threshold: float = -math.inf
raw_query, count, search_model: ImageSearchModel, content: ImageContent, score_threshold: float = math.inf
):
# Set query to image content if query is of form file:/path/to/file.png
if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file():
@@ -167,7 +167,8 @@ async def query(
# Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings.
with timer("Search Time", logger):
image_hits = {
result["corpus_id"]: {"image_score": result["score"], "score": result["score"]}
# Map scores to distance metric by multiplying by -1
result["corpus_id"]: {"image_score": -1 * result["score"], "score": -1 * result["score"]}
for result in util.semantic_search(query_embedding, content.image_embeddings, top_k=count)[0]
}
@@ -204,7 +205,7 @@ async def query(
]
# Filter results by score threshold
hits = [hit for hit in hits if hit["image_score"] >= score_threshold]
hits = [hit for hit in hits if hit["image_score"] <= score_threshold]
# Sort the images based on their combined metadata, image scores
return sorted(hits, key=lambda hit: hit["score"], reverse=True)


@@ -105,7 +105,7 @@ async def query(
type: SearchType = SearchType.All,
question_embedding: Union[torch.Tensor, None] = None,
rank_results: bool = False,
score_threshold: float = -math.inf,
max_distance: float = math.inf,
) -> Tuple[List[dict], List[Entry]]:
"Search for entries that answer the query"
@@ -127,6 +127,7 @@ async def query(
max_results=top_k,
file_type_filter=file_type,
raw_query=raw_query,
max_distance=max_distance,
).all()
hits = await sync_to_async(list)(hits) # type: ignore[call-arg]
@@ -177,12 +178,16 @@ def deduplicated_search_responses(hits: List[SearchResponse]):
)
def rerank_and_sort_results(hits, query):
def rerank_and_sort_results(hits, query, rank_results):
# If we have more than one result and reranking is enabled
rank_results = rank_results and len(list(hits)) > 1
# Score all retrieved entries using the cross-encoder
hits = cross_encoder_score(query, hits)
if rank_results:
hits = cross_encoder_score(query, hits)
# Sort results by cross-encoder score followed by bi-encoder score
hits = sort_results(rank_results=True, hits=hits)
hits = sort_results(rank_results=rank_results, hits=hits)
return hits
@@ -217,9 +222,9 @@ def cross_encoder_score(query: str, hits: List[SearchResponse]) -> List[SearchRe
with timer("Cross-Encoder Predict Time", logger, state.device):
cross_scores = state.cross_encoder_model.predict(query, hits)
# Store cross-encoder scores in results dictionary for ranking
# Convert cross-encoder scores to distances and pass in hits for reranking
for idx in range(len(cross_scores)):
hits[idx]["cross_score"] = cross_scores[idx]
hits[idx]["cross_score"] = -1 * cross_scores[idx]
return hits
@@ -227,7 +232,7 @@ def cross_encoder_score(query: str, hits: List[SearchResponse]) -> List[SearchRe
def sort_results(rank_results: bool, hits: List[dict]) -> List[dict]:
"""Order results by cross-encoder score followed by bi-encoder score"""
with timer("Rank Time", logger, state.device):
hits.sort(key=lambda x: x["score"], reverse=True) # sort by bi-encoder score
hits.sort(key=lambda x: x["score"]) # sort by bi-encoder score
if rank_results:
hits.sort(key=lambda x: x["cross_score"], reverse=True) # sort by cross-encoder score
hits.sort(key=lambda x: x["cross_score"]) # sort by cross-encoder score
return hits