mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 23:48:56 +01:00
Handle filter-only queries: short-circuit and return filtered results
- For queries that contain only filters, short-circuit and return the filtered results. There is no need to run semantic search or re-ranking.
- Add a client test for filter-only queries, and URL-quote the query strings in the client tests.
This commit is contained in:
parent
afc84de234
commit
1bfe9c4ef2
2 changed files with 29 additions and 5 deletions
|
@ -112,6 +112,11 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False):
|
||||||
if entries is None or len(entries) == 0:
|
if entries is None or len(entries) == 0:
|
||||||
return [], []
|
return [], []
|
||||||
|
|
||||||
|
# If query only had filters it'll be empty now. So short-circuit and return results.
|
||||||
|
if query.strip() == "":
|
||||||
|
hits = [{"corpus_id": id, "score": 1.0} for id, _ in enumerate(entries)]
|
||||||
|
return hits, entries
|
||||||
|
|
||||||
# Encode the query using the bi-encoder
|
# Encode the query using the bi-encoder
|
||||||
start = time.time()
|
start = time.time()
|
||||||
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
|
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
|
||||||
|
|
|
@ -1,18 +1,20 @@
|
||||||
# Standard Modules
|
# Standard Modules
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from urllib.parse import quote
|
||||||
|
|
||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from src.main import app
|
from src.main import app
|
||||||
from src.utils.config import SearchType
|
|
||||||
from src.utils.state import model, config
|
from src.utils.state import model, config
|
||||||
from src.search_type import text_search, image_search
|
from src.search_type import text_search, image_search
|
||||||
from src.utils.rawconfig import ContentConfig, SearchConfig
|
from src.utils.rawconfig import ContentConfig, SearchConfig
|
||||||
from src.processor.org_mode.org_to_jsonl import org_to_jsonl
|
from src.processor.org_mode.org_to_jsonl import org_to_jsonl
|
||||||
from src.search_filter.word_filter import WordFilter
|
from src.search_filter.word_filter import WordFilter
|
||||||
|
from src.search_filter.file_filter import FileFilter
|
||||||
|
|
||||||
|
|
||||||
# Arrange
|
# Arrange
|
||||||
|
@ -23,7 +25,7 @@ client = TestClient(app)
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
def test_search_with_invalid_content_type():
|
def test_search_with_invalid_content_type():
|
||||||
# Arrange
|
# Arrange
|
||||||
user_query = "How to call Khoj from Emacs?"
|
user_query = quote("How to call Khoj from Emacs?")
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
response = client.get(f"/search?q={user_query}&t=invalid_content_type")
|
response = client.get(f"/search?q={user_query}&t=invalid_content_type")
|
||||||
|
@ -117,7 +119,7 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig
|
||||||
def test_notes_search(content_config: ContentConfig, search_config: SearchConfig):
|
def test_notes_search(content_config: ContentConfig, search_config: SearchConfig):
|
||||||
# Arrange
|
# Arrange
|
||||||
model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
|
model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
|
||||||
user_query = "How to git install application?"
|
user_query = quote("How to git install application?")
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
response = client.get(f"/search?q={user_query}&n=1&t=org&r=true")
|
response = client.get(f"/search?q={user_query}&n=1&t=org&r=true")
|
||||||
|
@ -129,12 +131,29 @@ def test_notes_search(content_config: ContentConfig, search_config: SearchConfig
|
||||||
assert "git clone" in search_result
|
assert "git clone" in search_result
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------------------------------------
|
||||||
|
def test_notes_search_with_only_filters(content_config: ContentConfig, search_config: SearchConfig):
|
||||||
|
# Arrange
|
||||||
|
filters = [WordFilter(), FileFilter()]
|
||||||
|
model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters)
|
||||||
|
user_query = quote('+"Emacs" file:"*.org"')
|
||||||
|
|
||||||
|
# Act
|
||||||
|
response = client.get(f"/search?q={user_query}&n=1&t=org")
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert response.status_code == 200
|
||||||
|
# assert actual_data contains word "Emacs"
|
||||||
|
search_result = response.json()[0]["entry"]
|
||||||
|
assert "Emacs" in search_result
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
def test_notes_search_with_include_filter(content_config: ContentConfig, search_config: SearchConfig):
|
def test_notes_search_with_include_filter(content_config: ContentConfig, search_config: SearchConfig):
|
||||||
# Arrange
|
# Arrange
|
||||||
filters = [WordFilter()]
|
filters = [WordFilter()]
|
||||||
model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters)
|
model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters)
|
||||||
user_query = 'How to git install application? +"Emacs"'
|
user_query = quote('How to git install application? +"Emacs"')
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
response = client.get(f"/search?q={user_query}&n=1&t=org")
|
response = client.get(f"/search?q={user_query}&n=1&t=org")
|
||||||
|
@ -151,7 +170,7 @@ def test_notes_search_with_exclude_filter(content_config: ContentConfig, search_
|
||||||
# Arrange
|
# Arrange
|
||||||
filters = [WordFilter()]
|
filters = [WordFilter()]
|
||||||
model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters)
|
model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters)
|
||||||
user_query = 'How to git install application? -"clone"'
|
user_query = quote('How to git install application? -"clone"')
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
response = client.get(f"/search?q={user_query}&n=1&t=org")
|
response = client.get(f"/search?q={user_query}&n=1&t=org")
|
||||||
|
|
Loading…
Reference in a new issue