Split text_search.query logic into separate methods for modularity
The query method had become too big. Extract the filter, score, sort, and deduplicate logic used by text_search.query into separate methods. This should improve the readability of the code.
parent 8dc6ee8b6c
commit afcfc3cd62

1 changed file with 71 additions and 46 deletions
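After this change, the body of query reduces to a pipeline over the four extracted helpers. A condensed sketch of the post-commit control flow, assembled from the hunks below (the unchanged bi-encoder semantic-search step sits between the two query hunks and is not part of this diff, so it is elided as a comment):

    def query(raw_query: str, model: TextSearchModel, rank_results: bool = False):
        "Search for entries that answer the query"
        query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings

        # Filter before semantic search; bail out early if nothing survives
        query, entries, corpus_embeddings = apply_filters(query, entries, corpus_embeddings, model.filters)
        if entries is None or len(entries) == 0:
            return [], []

        # ... unchanged bi-encoder semantic search producing `hits` (outside this diff) ...

        if rank_results:
            hits = cross_encoder_score(model.cross_encoder, query, entries, hits)
        hits = sort_results(rank_results, hits)
        hits = deduplicate_results(entries, hits)
        return hits, entries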
@@ -85,34 +85,15 @@ def compute_embeddings(entries_with_ids: list[tuple[int, Entry]], bi_encoder, em
     return corpus_embeddings


-def query(raw_query: str, model: TextSearchModel, rank_results=False):
+def query(raw_query: str, model: TextSearchModel, rank_results: bool = False):
     "Search for entries that answer the query"
     query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings

-    # Filter query, entries and embeddings before semantic search
-    start_filter = time.time()
-    included_entry_indices = set(range(len(entries)))
-    filters_in_query = [filter for filter in model.filters if filter.can_filter(query)]
-    for filter in filters_in_query:
-        query, included_entry_indices_by_filter = filter.apply(query, entries)
-        included_entry_indices.intersection_update(included_entry_indices_by_filter)
-
-    # Get entries (and associated embeddings) satisfying all filters
-    if not included_entry_indices:
-        return [], []
-    else:
-        start = time.time()
-        entries = [entries[id] for id in included_entry_indices]
-        corpus_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(list(included_entry_indices), device=state.device))
-        end = time.time()
-        logger.debug(f"Keep entries satisfying all filters: {end - start} seconds")
-
-    end_filter = time.time()
-    logger.debug(f"Total Filter Time: {end_filter - start_filter:.3f} seconds on device: {state.device}")
-
+    query, entries, corpus_embeddings = apply_filters(query, entries, corpus_embeddings, model.filters)
+    # If no entries left after filtering, return empty results
+    if entries is None or len(entries) == 0:
+        return [], []

     # If query only had filters it'll be empty now. So short-circuit and return results.
     if query.strip() == "":
         hits = [{"corpus_id": id, "score": 1.0} for id, _ in enumerate(entries)]
@@ -133,34 +114,13 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False):

     # Score all retrieved entries using the cross-encoder
     if rank_results:
-        start = time.time()
-        cross_inp = [[query, entries[hit['corpus_id']].compiled] for hit in hits]
-        cross_scores = model.cross_encoder.predict(cross_inp)
-        end = time.time()
-        logger.debug(f"Cross-Encoder Predict Time: {end - start:.3f} seconds on device: {state.device}")
-
-        # Store cross-encoder scores in results dictionary for ranking
-        for idx in range(len(cross_scores)):
-            hits[idx]['cross-score'] = cross_scores[idx]
+        hits = cross_encoder_score(model.cross_encoder, query, entries, hits)

     # Order results by cross-encoder score followed by bi-encoder score
-    start = time.time()
-    hits.sort(key=lambda x: x['score'], reverse=True)  # sort by bi-encoder score
-    if rank_results:
-        hits.sort(key=lambda x: x['cross-score'], reverse=True)  # sort by cross-encoder score
-    end = time.time()
-    logger.debug(f"Rank Time: {end - start:.3f} seconds on device: {state.device}")
+    hits = sort_results(rank_results, hits)

     # Deduplicate entries by raw entry text before showing to users
-    # Compiled entries are split by max tokens supported by ML models.
-    # This can result in duplicate hits, entries shown to user.
-    start = time.time()
-    seen, original_hits_count = set(), len(hits)
-    hits = [hit for hit in hits
-            if entries[hit['corpus_id']].raw not in seen and not seen.add(entries[hit['corpus_id']].raw)]
-    duplicate_hits = original_hits_count - len(hits)
-    end = time.time()
-    logger.debug(f"Deduplication Time: {end - start:.3f} seconds. Removed {duplicate_hits} duplicates")
+    hits = deduplicate_results(entries, hits)

     return hits, entries

@@ -219,3 +179,68 @@ def setup(text_to_jsonl: Type[TextToJsonl], config: TextContentConfig, search_co
         filter.load(entries, regenerate=regenerate)

     return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k)
+
+
+def apply_filters(query: str, entries: list[Entry], corpus_embeddings: torch.Tensor, filters: list[BaseFilter]) -> tuple[str, list[Entry], torch.Tensor]:
+    '''Filter query, entries and embeddings before semantic search'''
+    start_filter = time.time()
+    included_entry_indices = set(range(len(entries)))
+    filters_in_query = [filter for filter in filters if filter.can_filter(query)]
+    for filter in filters_in_query:
+        query, included_entry_indices_by_filter = filter.apply(query, entries)
+        included_entry_indices.intersection_update(included_entry_indices_by_filter)
+
+    # Get entries (and associated embeddings) satisfying all filters
+    if not included_entry_indices:
+        return '', [], torch.tensor([], device=state.device)
+    else:
+        start = time.time()
+        entries = [entries[id] for id in included_entry_indices]
+        corpus_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(list(included_entry_indices), device=state.device))
+        end = time.time()
+        logger.debug(f"Keep entries satisfying all filters: {end - start} seconds")
+
+    end_filter = time.time()
+    logger.debug(f"Total Filter Time: {end_filter - start_filter:.3f} seconds on device: {state.device}")
+
+    return query, entries, corpus_embeddings
+
+
+def cross_encoder_score(cross_encoder: CrossEncoder, query: str, entries: list[Entry], hits: list[dict]) -> list[dict]:
+    '''Score all retrieved entries using the cross-encoder'''
+    start = time.time()
+    cross_inp = [[query, entries[hit['corpus_id']].compiled] for hit in hits]
+    cross_scores = cross_encoder.predict(cross_inp)
+    end = time.time()
+    logger.debug(f"Cross-Encoder Predict Time: {end - start:.3f} seconds on device: {state.device}")
+
+    # Store cross-encoder scores in results dictionary for ranking
+    for idx in range(len(cross_scores)):
+        hits[idx]['cross-score'] = cross_scores[idx]
+
+    return hits
+
+
+def sort_results(rank_results: bool, hits: list[dict]) -> list[dict]:
+    '''Order results by cross-encoder score followed by bi-encoder score'''
+    start = time.time()
+    hits.sort(key=lambda x: x['score'], reverse=True)  # sort by bi-encoder score
+    if rank_results:
+        hits.sort(key=lambda x: x['cross-score'], reverse=True)  # sort by cross-encoder score
+    end = time.time()
+    logger.debug(f"Rank Time: {end - start:.3f} seconds on device: {state.device}")
+    return hits
+
+
+def deduplicate_results(entries: list[Entry], hits: list[dict]) -> list[dict]:
+    '''Deduplicate entries by raw entry text before showing to users
+
+    Compiled entries are split by max tokens supported by ML models.
+    This can result in duplicate hits, entries shown to user.'''
+    start = time.time()
+    seen, original_hits_count = set(), len(hits)
+    hits = [hit for hit in hits
+            if entries[hit['corpus_id']].raw not in seen and not seen.add(entries[hit['corpus_id']].raw)]
+    duplicate_hits = original_hits_count - len(hits)
+    end = time.time()
+    logger.debug(f"Deduplication Time: {end - start:.3f} seconds. Removed {duplicate_hits} duplicates")
+    return hits
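A note on the filtering helper: apply_filters keeps the embedding rows of surviving entries with torch.index_select along dimension 0. A minimal standalone illustration with a toy tensor (shapes, values, and indices invented for the example):

    import torch

    corpus_embeddings = torch.arange(12, dtype=torch.float32).reshape(4, 3)  # 4 entries, 3-dim embeddings
    included_entry_indices = {0, 2}  # entries that passed every filter
    keep = torch.tensor(sorted(included_entry_indices))
    filtered = torch.index_select(corpus_embeddings, 0, keep)  # keep rows 0 and 2
    assert filtered.shape == (2, 3)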
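On sort_results: it performs two stable sorts, first by the bi-encoder 'score', then (when rank_results is set) by 'cross-score'. Since Python's list.sort is stable, hits that tie on cross-encoder score keep their bi-encoder order, which is what "cross-encoder score followed by bi-encoder score" means in practice. A compact check with toy hits:

    hits = [
        {"corpus_id": 0, "score": 0.2, "cross-score": 0.9},
        {"corpus_id": 1, "score": 0.8, "cross-score": 0.9},
        {"corpus_id": 2, "score": 0.5, "cross-score": 0.1},
    ]
    hits.sort(key=lambda x: x["score"], reverse=True)        # bi-encoder order: 1, 2, 0
    hits.sort(key=lambda x: x["cross-score"], reverse=True)  # stable: 1 stays before 0 on the tie
    assert [hit["corpus_id"] for hit in hits] == [1, 0, 2]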
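On deduplicate_results: the list comprehension leans on the fact that set.add returns None, so `raw not in seen and not seen.add(raw)` both filters out texts already seen and records new ones in a single pass, preserving hit order. A self-contained demonstration of the idiom on toy data:

    seen = set()
    raw_texts = ["alpha", "beta", "alpha", "gamma", "beta"]
    unique = [text for text in raw_texts if text not in seen and not seen.add(text)]
    assert unique == ["alpha", "beta", "gamma"]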