diff --git a/src/search_type/text_search.py b/src/search_type/text_search.py index 1929eb18..1ba1ed4d 100644 --- a/src/search_type/text_search.py +++ b/src/search_type/text_search.py @@ -85,34 +85,15 @@ def compute_embeddings(entries_with_ids: list[tuple[int, Entry]], bi_encoder, em return corpus_embeddings -def query(raw_query: str, model: TextSearchModel, rank_results=False): +def query(raw_query: str, model: TextSearchModel, rank_results: bool = False): "Search for entries that answer the query" query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings # Filter query, entries and embeddings before semantic search - start_filter = time.time() - included_entry_indices = set(range(len(entries))) - filters_in_query = [filter for filter in model.filters if filter.can_filter(query)] - for filter in filters_in_query: - query, included_entry_indices_by_filter = filter.apply(query, entries) - included_entry_indices.intersection_update(included_entry_indices_by_filter) - - # Get entries (and associated embeddings) satisfying all filters - if not included_entry_indices: - return [], [] - else: - start = time.time() - entries = [entries[id] for id in included_entry_indices] - corpus_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(list(included_entry_indices), device=state.device)) - end = time.time() - logger.debug(f"Keep entries satisfying all filters: {end - start} seconds") - - end_filter = time.time() - logger.debug(f"Total Filter Time: {end_filter - start_filter:.3f} seconds on device: {state.device}") - + query, entries, corpus_embeddings = apply_filters(query, entries, corpus_embeddings, model.filters) + # If no entries left after filtering, return empty results if entries is None or len(entries) == 0: return [], [] - # If query only had filters it'll be empty now. So short-circuit and return results. if query.strip() == "": hits = [{"corpus_id": id, "score": 1.0} for id, _ in enumerate(entries)] @@ -133,34 +114,13 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False): # Score all retrieved entries using the cross-encoder if rank_results: - start = time.time() - cross_inp = [[query, entries[hit['corpus_id']].compiled] for hit in hits] - cross_scores = model.cross_encoder.predict(cross_inp) - end = time.time() - logger.debug(f"Cross-Encoder Predict Time: {end - start:.3f} seconds on device: {state.device}") - - # Store cross-encoder scores in results dictionary for ranking - for idx in range(len(cross_scores)): - hits[idx]['cross-score'] = cross_scores[idx] + hits = cross_encoder_score(model.cross_encoder, query, entries, hits) # Order results by cross-encoder score followed by bi-encoder score - start = time.time() - hits.sort(key=lambda x: x['score'], reverse=True) # sort by bi-encoder score - if rank_results: - hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score - end = time.time() - logger.debug(f"Rank Time: {end - start:.3f} seconds on device: {state.device}") + hits = sort_results(rank_results, hits) # Deduplicate entries by raw entry text before showing to users - # Compiled entries are split by max tokens supported by ML models. - # This can result in duplicate hits, entries shown to user. - start = time.time() - seen, original_hits_count = set(), len(hits) - hits = [hit for hit in hits - if entries[hit['corpus_id']].raw not in seen and not seen.add(entries[hit['corpus_id']].raw)] - duplicate_hits = original_hits_count - len(hits) - end = time.time() - logger.debug(f"Deduplication Time: {end - start:.3f} seconds. Removed {duplicate_hits} duplicates") + hits = deduplicate_results(entries, hits) return hits, entries @@ -219,3 +179,68 @@ def setup(text_to_jsonl: Type[TextToJsonl], config: TextContentConfig, search_co filter.load(entries, regenerate=regenerate) return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k) + + +def apply_filters(query: str, entries: list[Entry], corpus_embeddings: torch.Tensor, filters: list[BaseFilter]) -> tuple[str, list[Entry], torch.Tensor]: + '''Filter query, entries and embeddings before semantic search''' + start_filter = time.time() + included_entry_indices = set(range(len(entries))) + filters_in_query = [filter for filter in filters if filter.can_filter(query)] + for filter in filters_in_query: + query, included_entry_indices_by_filter = filter.apply(query, entries) + included_entry_indices.intersection_update(included_entry_indices_by_filter) + + # Get entries (and associated embeddings) satisfying all filters + if not included_entry_indices: + return '', [], torch.tensor([], device=state.device) + else: + start = time.time() + entries = [entries[id] for id in included_entry_indices] + corpus_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(list(included_entry_indices), device=state.device)) + end = time.time() + logger.debug(f"Keep entries satisfying all filters: {end - start} seconds") + + end_filter = time.time() + logger.debug(f"Total Filter Time: {end_filter - start_filter:.3f} seconds on device: {state.device}") + + return query, entries, corpus_embeddings + + +def cross_encoder_score(cross_encoder: CrossEncoder, query: str, entries: list[Entry], hits: list[dict]) -> list[dict]: + '''Score all retrieved entries using the cross-encoder''' + start = time.time() + cross_inp = [[query, entries[hit['corpus_id']].compiled] for hit in hits] + cross_scores = cross_encoder.predict(cross_inp) + end = time.time() + logger.debug(f"Cross-Encoder Predict Time: {end - start:.3f} seconds on device: {state.device}") + + # Store cross-encoder scores in results dictionary for ranking + for idx in range(len(cross_scores)): + hits[idx]['cross-score'] = cross_scores[idx] + + return hits + + +def sort_results(rank_results: bool, hits: list[dict]) -> list[dict]: + '''Order results by cross-encoder score followed by bi-encoder score''' + start = time.time() + hits.sort(key=lambda x: x['score'], reverse=True) # sort by bi-encoder score + if rank_results: + hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score + end = time.time() + logger.debug(f"Rank Time: {end - start:.3f} seconds on device: {state.device}") + return hits + + +def deduplicate_results(entries: list[Entry], hits: list[dict]) -> list[dict]: + '''Deduplicate entries by raw entry text before showing to users + Compiled entries are split by max tokens supported by ML models. + This can result in duplicate hits, entries shown to user.''' + start = time.time() + seen, original_hits_count = set(), len(hits) + hits = [hit for hit in hits + if entries[hit['corpus_id']].raw not in seen and not seen.add(entries[hit['corpus_id']].raw)] + duplicate_hits = original_hits_count - len(hits) + end = time.time() + logger.debug(f"Deduplication Time: {end - start:.3f} seconds. Removed {duplicate_hits} duplicates") + return hits