mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 15:38:55 +01:00
Add explicit filters to asymmetric search
User can filter results to ones which include, exclude specified words To show entries which include, exclude specific words, user should prepend a '+', '-' before the word. E.g "+hello -bye"
This commit is contained in:
parent
91a2c598fe
commit
660e6c3937
1 changed files with 34 additions and 1 deletions
|
@ -6,6 +6,7 @@ import time
|
|||
import gzip
|
||||
import os
|
||||
import sys
|
||||
import re
|
||||
import torch
|
||||
import argparse
|
||||
import pathlib
|
||||
|
@ -56,8 +57,13 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, verbose=False):
|
|||
return corpus_embeddings
|
||||
|
||||
|
||||
def query_notes(query, corpus_embeddings, entries, bi_encoder, cross_encoder, topk=100):
|
||||
def query_notes(raw_query, corpus_embeddings, entries, bi_encoder, cross_encoder, top_k=100):
|
||||
"Search all notes for entries that answer the query"
|
||||
# Separate natural query from explicit required, blocked words filters
|
||||
query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")])
|
||||
required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")])
|
||||
blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])
|
||||
|
||||
# Encode the query using the bi-encoder
|
||||
question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
|
||||
|
||||
|
@ -65,6 +71,11 @@ def query_notes(query, corpus_embeddings, entries, bi_encoder, cross_encoder, to
|
|||
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)
|
||||
hits = hits[0] # Get the hits for the first query
|
||||
|
||||
# Filter results using explicit filters
|
||||
hits = explicit_filter(hits, entries, required_words, blocked_words)
|
||||
if hits is None or len(hits) == 0:
|
||||
return hits
|
||||
|
||||
# Score all retrieved entries using the cross-encoder
|
||||
cross_inp = [[query, entries[hit['corpus_id']]] for hit in hits]
|
||||
cross_scores = cross_encoder.predict(cross_inp)
|
||||
|
@ -76,6 +87,28 @@ def query_notes(query, corpus_embeddings, entries, bi_encoder, cross_encoder, to
|
|||
# Order results by cross encoder score followed by biencoder score
|
||||
hits.sort(key=lambda x: x['score'], reverse=True) # sort by biencoder score
|
||||
hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross encoder score
|
||||
|
||||
return hits
|
||||
|
||||
|
||||
def explicit_filter(hits, entries, required_words, blocked_words):
|
||||
hits_by_word_set = [(set(word.lower()
|
||||
for word
|
||||
in re.split(
|
||||
',|\.| |\]|\[\(|\)|\{|\}',
|
||||
entries[hit['corpus_id']])
|
||||
if word != ""),
|
||||
hit)
|
||||
for hit in hits]
|
||||
|
||||
if len(required_words) == 0 and len(blocked_words) == 0:
|
||||
return hits
|
||||
if len(required_words) > 0:
|
||||
return [hit for (words_in_entry, hit) in hits_by_word_set
|
||||
if required_words.intersection(words_in_entry) and not blocked_words.intersection(words_in_entry)]
|
||||
if len(blocked_words) > 0:
|
||||
return [hit for (words_in_entry, hit) in hits_by_word_set
|
||||
if not blocked_words.intersection(words_in_entry)]
|
||||
return hits
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue