mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-27 17:35:07 +01:00
Only allow supported search types to /search, /regenerate APIs
- Use a SearchType to limit types that can be passed by user - FastAPI automatically validates type passed in query param - Available type options show up in Swagger UI, FastAPI docs - controller code looks neater instead of doing string comparisons for type - Test invalid, valid search types via pytest
This commit is contained in:
parent
150593c776
commit
81ce0cacc3
3 changed files with 58 additions and 10 deletions
21
src/main.py
21
src/main.py
|
@ -11,13 +11,14 @@ from fastapi import FastAPI
|
|||
from search_type import asymmetric, symmetric_ledger, image_search
|
||||
from utils.helpers import get_from_dict
|
||||
from utils.cli import cli
|
||||
from utils.config import SearchType
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
@app.get('/search')
|
||||
def search(q: str, n: Optional[int] = 5, t: Optional[str] = None):
|
||||
def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
|
||||
if q is None or q == '':
|
||||
print(f'No query param (q) passed in API call to initiate search')
|
||||
return {}
|
||||
|
@ -25,7 +26,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[str] = None):
|
|||
user_query = q
|
||||
results_count = n
|
||||
|
||||
if (t == 'notes' or t == None) and notes_search_enabled:
|
||||
if (t == SearchType.Notes or t == None) and notes_search_enabled:
|
||||
# query notes
|
||||
hits = asymmetric.query_notes(
|
||||
user_query,
|
||||
|
@ -38,7 +39,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[str] = None):
|
|||
# collate and return results
|
||||
return asymmetric.collate_results(hits, entries, results_count)
|
||||
|
||||
if (t == 'music' or t == None) and music_search_enabled:
|
||||
if (t == SearchType.Music or t == None) and music_search_enabled:
|
||||
# query music library
|
||||
hits = asymmetric.query_notes(
|
||||
user_query,
|
||||
|
@ -51,7 +52,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[str] = None):
|
|||
# collate and return results
|
||||
return asymmetric.collate_results(hits, songs, results_count)
|
||||
|
||||
if (t == 'ledger' or t == None) and ledger_search_enabled:
|
||||
if (t == SearchType.Ledger or t == None) and ledger_search_enabled:
|
||||
# query transactions
|
||||
hits = symmetric_ledger.query_transactions(
|
||||
user_query,
|
||||
|
@ -63,7 +64,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[str] = None):
|
|||
# collate and return results
|
||||
return symmetric_ledger.collate_results(hits, transactions, results_count)
|
||||
|
||||
if (t == 'image' or t == None) and image_search_enabled:
|
||||
if (t == SearchType.Image or t == None) and image_search_enabled:
|
||||
# query transactions
|
||||
hits = image_search.query_images(
|
||||
user_query,
|
||||
|
@ -85,8 +86,8 @@ def search(q: str, n: Optional[int] = 5, t: Optional[str] = None):
|
|||
|
||||
|
||||
@app.get('/regenerate')
|
||||
def regenerate(t: Optional[str] = None):
|
||||
if (t == 'notes' or t == None) and notes_search_enabled:
|
||||
def regenerate(t: Optional[SearchType] = None):
|
||||
if (t == SearchType.Notes or t == None) and notes_search_enabled:
|
||||
# Extract Entries, Generate Embeddings
|
||||
global corpus_embeddings
|
||||
global entries
|
||||
|
@ -98,7 +99,7 @@ def regenerate(t: Optional[str] = None):
|
|||
regenerate=True,
|
||||
verbose=args.verbose)
|
||||
|
||||
if (t == 'music' or t == None) and music_search_enabled:
|
||||
if (t == SearchType.Music or t == None) and music_search_enabled:
|
||||
# Extract Entries, Generate Song Embeddings
|
||||
global song_embeddings
|
||||
global songs
|
||||
|
@ -110,7 +111,7 @@ def regenerate(t: Optional[str] = None):
|
|||
regenerate=True,
|
||||
verbose=args.verbose)
|
||||
|
||||
if (t == 'ledger' or t == None) and ledger_search_enabled:
|
||||
if (t == SearchType.Ledger or t == None) and ledger_search_enabled:
|
||||
# Extract Entries, Generate Embeddings
|
||||
global transaction_embeddings
|
||||
global transactions
|
||||
|
@ -122,7 +123,7 @@ def regenerate(t: Optional[str] = None):
|
|||
regenerate=True,
|
||||
verbose=args.verbose)
|
||||
|
||||
if (t == 'image' or t == None) and image_search_enabled:
|
||||
if (t == SearchType.Image or t == None) and image_search_enabled:
|
||||
# Extract Images, Generate Embeddings
|
||||
global image_embeddings
|
||||
global image_metadata_embeddings
|
||||
|
|
|
@ -16,6 +16,44 @@ client = TestClient(app)
|
|||
|
||||
|
||||
# Test
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_search_with_invalid_search_type():
|
||||
# Arrange
|
||||
user_query = "How to call semantic search from Emacs?"
|
||||
|
||||
# Act
|
||||
response = client.get(f"/search?q={user_query}&t=invalid_search_type")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def test_search_with_valid_search_type():
|
||||
# Arrange
|
||||
for search_type in ["notes", "ledger", "music", "image"]:
|
||||
# Act
|
||||
response = client.get(f"/search?q=random&t={search_type}")
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_regenerate_with_invalid_search_type():
|
||||
# Act
|
||||
response = client.get(f"/regenerate?t=invalid_search_type")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def test_regenerate_with_valid_search_type():
|
||||
# Arrange
|
||||
for search_type in ["notes", "ledger", "music", "image"]:
|
||||
# Act
|
||||
response = client.get(f"/regenerate?t={search_type}")
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_asymmetric_setup():
|
||||
# Arrange
|
||||
|
|
9
src/utils/config.py
Normal file
9
src/utils/config.py
Normal file
|
@ -0,0 +1,9 @@
|
|||
from enum import Enum
|
||||
|
||||
|
||||
class SearchType(str, Enum):
|
||||
Notes = "notes"
|
||||
Ledger = "ledger"
|
||||
Music = "music"
|
||||
Image = "image"
|
||||
|
Loading…
Reference in a new issue