Split text_search.query logic into separate methods for modularity

The query method had become too big.

Extract out filter, score, sort and deduplicate logic used by
text_search.query into separate methods.

This should improve readabilty of code.
This commit is contained in:
Debanjum Singh Solanky 2023-01-09 17:26:41 -03:00
parent 8dc6ee8b6c
commit afcfc3cd62

View file

@ -85,34 +85,15 @@ def compute_embeddings(entries_with_ids: list[tuple[int, Entry]], bi_encoder, em
return corpus_embeddings
def query(raw_query: str, model: TextSearchModel, rank_results=False):
def query(raw_query: str, model: TextSearchModel, rank_results: bool = False):
"Search for entries that answer the query"
query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings
# Filter query, entries and embeddings before semantic search
start_filter = time.time()
included_entry_indices = set(range(len(entries)))
filters_in_query = [filter for filter in model.filters if filter.can_filter(query)]
for filter in filters_in_query:
query, included_entry_indices_by_filter = filter.apply(query, entries)
included_entry_indices.intersection_update(included_entry_indices_by_filter)
# Get entries (and associated embeddings) satisfying all filters
if not included_entry_indices:
return [], []
else:
start = time.time()
entries = [entries[id] for id in included_entry_indices]
corpus_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(list(included_entry_indices), device=state.device))
end = time.time()
logger.debug(f"Keep entries satisfying all filters: {end - start} seconds")
end_filter = time.time()
logger.debug(f"Total Filter Time: {end_filter - start_filter:.3f} seconds on device: {state.device}")
query, entries, corpus_embeddings = apply_filters(query, entries, corpus_embeddings, model.filters)
# If no entries left after filtering, return empty results
if entries is None or len(entries) == 0:
return [], []
# If query only had filters it'll be empty now. So short-circuit and return results.
if query.strip() == "":
hits = [{"corpus_id": id, "score": 1.0} for id, _ in enumerate(entries)]
@ -133,34 +114,13 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False):
# Score all retrieved entries using the cross-encoder
if rank_results:
start = time.time()
cross_inp = [[query, entries[hit['corpus_id']].compiled] for hit in hits]
cross_scores = model.cross_encoder.predict(cross_inp)
end = time.time()
logger.debug(f"Cross-Encoder Predict Time: {end - start:.3f} seconds on device: {state.device}")
# Store cross-encoder scores in results dictionary for ranking
for idx in range(len(cross_scores)):
hits[idx]['cross-score'] = cross_scores[idx]
hits = cross_encoder_score(model.cross_encoder, query, entries, hits)
# Order results by cross-encoder score followed by bi-encoder score
start = time.time()
hits.sort(key=lambda x: x['score'], reverse=True) # sort by bi-encoder score
if rank_results:
hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score
end = time.time()
logger.debug(f"Rank Time: {end - start:.3f} seconds on device: {state.device}")
hits = sort_results(rank_results, hits)
# Deduplicate entries by raw entry text before showing to users
# Compiled entries are split by max tokens supported by ML models.
# This can result in duplicate hits, entries shown to user.
start = time.time()
seen, original_hits_count = set(), len(hits)
hits = [hit for hit in hits
if entries[hit['corpus_id']].raw not in seen and not seen.add(entries[hit['corpus_id']].raw)]
duplicate_hits = original_hits_count - len(hits)
end = time.time()
logger.debug(f"Deduplication Time: {end - start:.3f} seconds. Removed {duplicate_hits} duplicates")
hits = deduplicate_results(entries, hits)
return hits, entries
@ -219,3 +179,68 @@ def setup(text_to_jsonl: Type[TextToJsonl], config: TextContentConfig, search_co
filter.load(entries, regenerate=regenerate)
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k)
def apply_filters(query: str, entries: list[Entry], corpus_embeddings: torch.Tensor, filters: list[BaseFilter]) -> tuple[str, list[Entry], torch.Tensor]:
'''Filter query, entries and embeddings before semantic search'''
start_filter = time.time()
included_entry_indices = set(range(len(entries)))
filters_in_query = [filter for filter in filters if filter.can_filter(query)]
for filter in filters_in_query:
query, included_entry_indices_by_filter = filter.apply(query, entries)
included_entry_indices.intersection_update(included_entry_indices_by_filter)
# Get entries (and associated embeddings) satisfying all filters
if not included_entry_indices:
return '', [], torch.tensor([], device=state.device)
else:
start = time.time()
entries = [entries[id] for id in included_entry_indices]
corpus_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(list(included_entry_indices), device=state.device))
end = time.time()
logger.debug(f"Keep entries satisfying all filters: {end - start} seconds")
end_filter = time.time()
logger.debug(f"Total Filter Time: {end_filter - start_filter:.3f} seconds on device: {state.device}")
return query, entries, corpus_embeddings
def cross_encoder_score(cross_encoder: CrossEncoder, query: str, entries: list[Entry], hits: list[dict]) -> list[dict]:
'''Score all retrieved entries using the cross-encoder'''
start = time.time()
cross_inp = [[query, entries[hit['corpus_id']].compiled] for hit in hits]
cross_scores = cross_encoder.predict(cross_inp)
end = time.time()
logger.debug(f"Cross-Encoder Predict Time: {end - start:.3f} seconds on device: {state.device}")
# Store cross-encoder scores in results dictionary for ranking
for idx in range(len(cross_scores)):
hits[idx]['cross-score'] = cross_scores[idx]
return hits
def sort_results(rank_results: bool, hits: list[dict]) -> list[dict]:
'''Order results by cross-encoder score followed by bi-encoder score'''
start = time.time()
hits.sort(key=lambda x: x['score'], reverse=True) # sort by bi-encoder score
if rank_results:
hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score
end = time.time()
logger.debug(f"Rank Time: {end - start:.3f} seconds on device: {state.device}")
return hits
def deduplicate_results(entries: list[Entry], hits: list[dict]) -> list[dict]:
'''Deduplicate entries by raw entry text before showing to users
Compiled entries are split by max tokens supported by ML models.
This can result in duplicate hits, entries shown to user.'''
start = time.time()
seen, original_hits_count = set(), len(hits)
hits = [hit for hit in hits
if entries[hit['corpus_id']].raw not in seen and not seen.add(entries[hit['corpus_id']].raw)]
duplicate_hits = original_hits_count - len(hits)
end = time.time()
logger.debug(f"Deduplication Time: {end - start:.3f} seconds. Removed {duplicate_hits} duplicates")
return hits