Add configurable filter support to Symmetric Ledger Search
This commit is contained in:
parent 50658453cd
commit 0e979587e0

2 changed files with 17 additions and 40 deletions
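
This change removes the hard-coded explicit word filter from symmetric ledger search and replaces it with a configurable pipeline: query now accepts a filters list of callables, each of which receives and returns the triple (query, entries, corpus_embeddings) and is applied before semantic search. Because a filter may prune entries, query also returns the filtered entry list alongside the hits, and the search endpoint passes that list to collate_results instead of reading model.ledger_search.entries.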
@@ -74,10 +74,10 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
 
     if (t == SearchType.Ledger or t == None) and model.ledger_search:
         # query transactions
-        hits = symmetric_ledger.query(user_query, model.ledger_search)
+        hits, entries = symmetric_ledger.query(user_query, model.ledger_search)
 
         # collate and return results
-        return symmetric_ledger.collate_results(hits, model.ledger_search.entries, results_count)
+        return symmetric_ledger.collate_results(hits, entries, results_count)
 
     if (t == SearchType.Image or t == None) and model.image_search:
         # query transactions
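
Returning entries from query keeps each hit['corpus_id'] valid: once a filter drops entries, the indices refer to the filtered list, so collating against the unfiltered model.ledger_search.entries would resolve hits to the wrong entries.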
@@ -1,9 +1,7 @@
 # Standard Packages
-import json
-import gzip
-import re
 import argparse
 import pathlib
+from copy import deepcopy
 
 # External Packages
 import torch
@@ -62,27 +60,27 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, v
     return corpus_embeddings
 
 
-def query(raw_query, model: TextSearchModel):
+def query(raw_query, model: TextSearchModel, filters=[]):
     "Search all notes for entries that answer the query"
-    # Separate natural query from explicit required, blocked words filters
-    query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")])
-    required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")])
-    blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])
+    # Copy original embeddings, entries to filter them for query
+    query = raw_query
+    corpus_embeddings = deepcopy(model.corpus_embeddings)
+    entries = deepcopy(model.entries)
+
+    # Filter query, entries and embeddings before semantic search
+    for filter in filters:
+        query, entries, corpus_embeddings = filter(query, entries, corpus_embeddings)
+    if entries is None or len(entries) == 0:
+        return [], []
 
     # Encode the query using the bi-encoder
     question_embedding = model.bi_encoder.encode(query, convert_to_tensor=True)
 
     # Find relevant entries for the query
-    hits = util.semantic_search(question_embedding, model.corpus_embeddings, top_k=model.top_k)
-    hits = hits[0] # Get the hits for the first query
-
-    # Filter results using explicit filters
-    hits = explicit_filter(hits, model.entries, required_words, blocked_words)
-    if hits is None or len(hits) == 0:
-        return hits
+    hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k)[0]
 
     # Score all retrieved entries using the cross-encoder
-    cross_inp = [[query, model.entries[hit['corpus_id']]] for hit in hits]
+    cross_inp = [[query, entries[hit['corpus_id']]] for hit in hits]
     cross_scores = model.cross_encoder.predict(cross_inp)
 
     # Store cross-encoder scores in results dictionary for ranking
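
The commit does not ship a sample filter, but the removed explicit +word/-word logic shows the shape one would take under the new interface. A minimal sketch, assuming entries are plain strings and corpus_embeddings is a torch tensor (consistent with how query uses them above); the name explicit_word_filter is hypothetical, not part of this commit:

    import re

    def explicit_word_filter(query, entries, corpus_embeddings):
        "Keep entries matching '+word'/'-word' terms in the query, then strip those terms."
        required = {word[1:].lower() for word in query.split() if word.startswith("+")}
        blocked = {word[1:].lower() for word in query.split() if word.startswith("-")}
        query = " ".join(word for word in query.split() if not word.startswith(("+", "-")))
        if not required and not blocked:
            return query, entries, corpus_embeddings

        # Keep an entry if it contains any required word (as in the removed
        # explicit_filter) and no blocked word
        kept = []
        for idx, entry in enumerate(entries):
            words = {word.lower() for word in re.split(r"\W+", entry) if word}
            if required and not required.intersection(words):
                continue
            if blocked and blocked.intersection(words):
                continue
            kept.append(idx)

        # Prune entries and their embedding rows in lockstep so corpus_ids stay aligned
        return query, [entries[idx] for idx in kept], corpus_embeddings[kept]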
@@ -93,28 +91,7 @@ def query(raw_query, model: TextSearchModel):
     hits.sort(key=lambda x: x['score'], reverse=True) # sort by biencoder score
     hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross encoder score
 
-    return hits
-
-
-def explicit_filter(hits, entries, required_words, blocked_words):
-    hits_by_word_set = [(set(word.lower()
-                             for word
-                             in re.split(
-                                 r',|\.| |\]|\[\(|\)|\{|\}',
-                                 entries[hit['corpus_id']])
-                             if word != ""),
-                         hit)
-                        for hit in hits]
-
-    if len(required_words) == 0 and len(blocked_words) == 0:
-        return hits
-    if len(required_words) > 0:
-        return [hit for (words_in_entry, hit) in hits_by_word_set
-                if required_words.intersection(words_in_entry) and not blocked_words.intersection(words_in_entry)]
-    if len(blocked_words) > 0:
-        return [hit for (words_in_entry, hit) in hits_by_word_set
-                if not blocked_words.intersection(words_in_entry)]
-    return hits
+    return hits, entries
 
 
 def render_results(hits, entries, count=5, display_biencoder_results=False):
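
Call sites then opt in per query. A hedged usage sketch from within the same module, using the hypothetical filter above and the collate_results call shape from this diff:

    # Filters apply in list order before semantic search
    hits, entries = query("coffee +groceries -rent", model, filters=[explicit_word_filter])
    results = collate_results(hits, entries, 5)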