mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-27 17:35:07 +01:00
Split text_search.query logic into separate methods for modularity
The query method had become too big. Extract out filter, score, sort and deduplicate logic used by text_search.query into separate methods. This should improve readability of code.
This commit is contained in:
parent
8dc6ee8b6c
commit
afcfc3cd62
1 changed file with 71 additions and 46 deletions
|
@ -85,34 +85,15 @@ def compute_embeddings(entries_with_ids: list[tuple[int, Entry]], bi_encoder, em
|
||||||
return corpus_embeddings
|
return corpus_embeddings
|
||||||
|
|
||||||
|
|
||||||
def query(raw_query: str, model: TextSearchModel, rank_results=False):
|
def query(raw_query: str, model: TextSearchModel, rank_results: bool = False):
|
||||||
"Search for entries that answer the query"
|
"Search for entries that answer the query"
|
||||||
query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings
|
query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings
|
||||||
|
|
||||||
# Filter query, entries and embeddings before semantic search
|
# Filter query, entries and embeddings before semantic search
|
||||||
start_filter = time.time()
|
query, entries, corpus_embeddings = apply_filters(query, entries, corpus_embeddings, model.filters)
|
||||||
included_entry_indices = set(range(len(entries)))
|
# If no entries left after filtering, return empty results
|
||||||
filters_in_query = [filter for filter in model.filters if filter.can_filter(query)]
|
|
||||||
for filter in filters_in_query:
|
|
||||||
query, included_entry_indices_by_filter = filter.apply(query, entries)
|
|
||||||
included_entry_indices.intersection_update(included_entry_indices_by_filter)
|
|
||||||
|
|
||||||
# Get entries (and associated embeddings) satisfying all filters
|
|
||||||
if not included_entry_indices:
|
|
||||||
return [], []
|
|
||||||
else:
|
|
||||||
start = time.time()
|
|
||||||
entries = [entries[id] for id in included_entry_indices]
|
|
||||||
corpus_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(list(included_entry_indices), device=state.device))
|
|
||||||
end = time.time()
|
|
||||||
logger.debug(f"Keep entries satisfying all filters: {end - start} seconds")
|
|
||||||
|
|
||||||
end_filter = time.time()
|
|
||||||
logger.debug(f"Total Filter Time: {end_filter - start_filter:.3f} seconds on device: {state.device}")
|
|
||||||
|
|
||||||
if entries is None or len(entries) == 0:
|
if entries is None or len(entries) == 0:
|
||||||
return [], []
|
return [], []
|
||||||
|
|
||||||
# If query only had filters it'll be empty now. So short-circuit and return results.
|
# If query only had filters it'll be empty now. So short-circuit and return results.
|
||||||
if query.strip() == "":
|
if query.strip() == "":
|
||||||
hits = [{"corpus_id": id, "score": 1.0} for id, _ in enumerate(entries)]
|
hits = [{"corpus_id": id, "score": 1.0} for id, _ in enumerate(entries)]
|
||||||
|
@ -133,34 +114,13 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False):
|
||||||
|
|
||||||
# Score all retrieved entries using the cross-encoder
|
# Score all retrieved entries using the cross-encoder
|
||||||
if rank_results:
|
if rank_results:
|
||||||
start = time.time()
|
hits = cross_encoder_score(model.cross_encoder, query, entries, hits)
|
||||||
cross_inp = [[query, entries[hit['corpus_id']].compiled] for hit in hits]
|
|
||||||
cross_scores = model.cross_encoder.predict(cross_inp)
|
|
||||||
end = time.time()
|
|
||||||
logger.debug(f"Cross-Encoder Predict Time: {end - start:.3f} seconds on device: {state.device}")
|
|
||||||
|
|
||||||
# Store cross-encoder scores in results dictionary for ranking
|
|
||||||
for idx in range(len(cross_scores)):
|
|
||||||
hits[idx]['cross-score'] = cross_scores[idx]
|
|
||||||
|
|
||||||
# Order results by cross-encoder score followed by bi-encoder score
|
# Order results by cross-encoder score followed by bi-encoder score
|
||||||
start = time.time()
|
hits = sort_results(rank_results, hits)
|
||||||
hits.sort(key=lambda x: x['score'], reverse=True) # sort by bi-encoder score
|
|
||||||
if rank_results:
|
|
||||||
hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score
|
|
||||||
end = time.time()
|
|
||||||
logger.debug(f"Rank Time: {end - start:.3f} seconds on device: {state.device}")
|
|
||||||
|
|
||||||
# Deduplicate entries by raw entry text before showing to users
|
# Deduplicate entries by raw entry text before showing to users
|
||||||
# Compiled entries are split by max tokens supported by ML models.
|
hits = deduplicate_results(entries, hits)
|
||||||
# This can result in duplicate hits, entries shown to user.
|
|
||||||
start = time.time()
|
|
||||||
seen, original_hits_count = set(), len(hits)
|
|
||||||
hits = [hit for hit in hits
|
|
||||||
if entries[hit['corpus_id']].raw not in seen and not seen.add(entries[hit['corpus_id']].raw)]
|
|
||||||
duplicate_hits = original_hits_count - len(hits)
|
|
||||||
end = time.time()
|
|
||||||
logger.debug(f"Deduplication Time: {end - start:.3f} seconds. Removed {duplicate_hits} duplicates")
|
|
||||||
|
|
||||||
return hits, entries
|
return hits, entries
|
||||||
|
|
||||||
|
@ -219,3 +179,68 @@ def setup(text_to_jsonl: Type[TextToJsonl], config: TextContentConfig, search_co
|
||||||
filter.load(entries, regenerate=regenerate)
|
filter.load(entries, regenerate=regenerate)
|
||||||
|
|
||||||
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k)
|
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_filters(query: str, entries: list[Entry], corpus_embeddings: torch.Tensor, filters: list[BaseFilter]) -> tuple[str, list[Entry], torch.Tensor]:
|
||||||
|
'''Filter query, entries and embeddings before semantic search'''
|
||||||
|
start_filter = time.time()
|
||||||
|
included_entry_indices = set(range(len(entries)))
|
||||||
|
filters_in_query = [filter for filter in filters if filter.can_filter(query)]
|
||||||
|
for filter in filters_in_query:
|
||||||
|
query, included_entry_indices_by_filter = filter.apply(query, entries)
|
||||||
|
included_entry_indices.intersection_update(included_entry_indices_by_filter)
|
||||||
|
|
||||||
|
# Get entries (and associated embeddings) satisfying all filters
|
||||||
|
if not included_entry_indices:
|
||||||
|
return '', [], torch.tensor([], device=state.device)
|
||||||
|
else:
|
||||||
|
start = time.time()
|
||||||
|
entries = [entries[id] for id in included_entry_indices]
|
||||||
|
corpus_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(list(included_entry_indices), device=state.device))
|
||||||
|
end = time.time()
|
||||||
|
logger.debug(f"Keep entries satisfying all filters: {end - start} seconds")
|
||||||
|
|
||||||
|
end_filter = time.time()
|
||||||
|
logger.debug(f"Total Filter Time: {end_filter - start_filter:.3f} seconds on device: {state.device}")
|
||||||
|
|
||||||
|
return query, entries, corpus_embeddings
|
||||||
|
|
||||||
|
|
||||||
|
def cross_encoder_score(cross_encoder: CrossEncoder, query: str, entries: list[Entry], hits: list[dict]) -> list[dict]:
|
||||||
|
'''Score all retrieved entries using the cross-encoder'''
|
||||||
|
start = time.time()
|
||||||
|
cross_inp = [[query, entries[hit['corpus_id']].compiled] for hit in hits]
|
||||||
|
cross_scores = cross_encoder.predict(cross_inp)
|
||||||
|
end = time.time()
|
||||||
|
logger.debug(f"Cross-Encoder Predict Time: {end - start:.3f} seconds on device: {state.device}")
|
||||||
|
|
||||||
|
# Store cross-encoder scores in results dictionary for ranking
|
||||||
|
for idx in range(len(cross_scores)):
|
||||||
|
hits[idx]['cross-score'] = cross_scores[idx]
|
||||||
|
|
||||||
|
return hits
|
||||||
|
|
||||||
|
|
||||||
|
def sort_results(rank_results: bool, hits: list[dict]) -> list[dict]:
|
||||||
|
'''Order results by cross-encoder score followed by bi-encoder score'''
|
||||||
|
start = time.time()
|
||||||
|
hits.sort(key=lambda x: x['score'], reverse=True) # sort by bi-encoder score
|
||||||
|
if rank_results:
|
||||||
|
hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score
|
||||||
|
end = time.time()
|
||||||
|
logger.debug(f"Rank Time: {end - start:.3f} seconds on device: {state.device}")
|
||||||
|
return hits
|
||||||
|
|
||||||
|
|
||||||
|
def deduplicate_results(entries: list[Entry], hits: list[dict]) -> list[dict]:
|
||||||
|
'''Deduplicate entries by raw entry text before showing to users
|
||||||
|
Compiled entries are split by max tokens supported by ML models.
|
||||||
|
This can result in duplicate hits, entries shown to user.'''
|
||||||
|
start = time.time()
|
||||||
|
seen, original_hits_count = set(), len(hits)
|
||||||
|
hits = [hit for hit in hits
|
||||||
|
if entries[hit['corpus_id']].raw not in seen and not seen.add(entries[hit['corpus_id']].raw)]
|
||||||
|
duplicate_hits = original_hits_count - len(hits)
|
||||||
|
end = time.time()
|
||||||
|
logger.debug(f"Deduplication Time: {end - start:.3f} seconds. Removed {duplicate_hits} duplicates")
|
||||||
|
return hits
|
||||||
|
|
Loading…
Reference in a new issue