Only allow supported search types to /search, /regenerate APIs

- Use a SearchType to limit types that can be passed by user
- FastAPI automatically validates type passed in query param
- Available type options show up in Swagger UI, FastAPI docs
- controller code looks neater instead of doing string comparisons for type
- Test invalid, valid search types via pytest
This commit is contained in:
Debanjum Singh Solanky 2021-09-29 19:02:55 -07:00
parent 150593c776
commit 81ce0cacc3
3 changed files with 58 additions and 10 deletions

View file

@ -11,13 +11,14 @@ from fastapi import FastAPI
from search_type import asymmetric, symmetric_ledger, image_search
from utils.helpers import get_from_dict
from utils.cli import cli
from utils.config import SearchType
app = FastAPI()
@app.get('/search')
def search(q: str, n: Optional[int] = 5, t: Optional[str] = None):
def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
if q is None or q == '':
print(f'No query param (q) passed in API call to initiate search')
return {}
@ -25,7 +26,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[str] = None):
user_query = q
results_count = n
if (t == 'notes' or t == None) and notes_search_enabled:
if (t == SearchType.Notes or t == None) and notes_search_enabled:
# query notes
hits = asymmetric.query_notes(
user_query,
@ -38,7 +39,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[str] = None):
# collate and return results
return asymmetric.collate_results(hits, entries, results_count)
if (t == 'music' or t == None) and music_search_enabled:
if (t == SearchType.Music or t == None) and music_search_enabled:
# query music library
hits = asymmetric.query_notes(
user_query,
@ -51,7 +52,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[str] = None):
# collate and return results
return asymmetric.collate_results(hits, songs, results_count)
if (t == 'ledger' or t == None) and ledger_search_enabled:
if (t == SearchType.Ledger or t == None) and ledger_search_enabled:
# query transactions
hits = symmetric_ledger.query_transactions(
user_query,
@ -63,7 +64,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[str] = None):
# collate and return results
return symmetric_ledger.collate_results(hits, transactions, results_count)
if (t == 'image' or t == None) and image_search_enabled:
if (t == SearchType.Image or t == None) and image_search_enabled:
# query transactions
hits = image_search.query_images(
user_query,
@ -85,8 +86,8 @@ def search(q: str, n: Optional[int] = 5, t: Optional[str] = None):
@app.get('/regenerate')
def regenerate(t: Optional[str] = None):
if (t == 'notes' or t == None) and notes_search_enabled:
def regenerate(t: Optional[SearchType] = None):
if (t == SearchType.Notes or t == None) and notes_search_enabled:
# Extract Entries, Generate Embeddings
global corpus_embeddings
global entries
@ -98,7 +99,7 @@ def regenerate(t: Optional[str] = None):
regenerate=True,
verbose=args.verbose)
if (t == 'music' or t == None) and music_search_enabled:
if (t == SearchType.Music or t == None) and music_search_enabled:
# Extract Entries, Generate Song Embeddings
global song_embeddings
global songs
@ -110,7 +111,7 @@ def regenerate(t: Optional[str] = None):
regenerate=True,
verbose=args.verbose)
if (t == 'ledger' or t == None) and ledger_search_enabled:
if (t == SearchType.Ledger or t == None) and ledger_search_enabled:
# Extract Entries, Generate Embeddings
global transaction_embeddings
global transactions
@ -122,7 +123,7 @@ def regenerate(t: Optional[str] = None):
regenerate=True,
verbose=args.verbose)
if (t == 'image' or t == None) and image_search_enabled:
if (t == SearchType.Image or t == None) and image_search_enabled:
# Extract Images, Generate Embeddings
global image_embeddings
global image_metadata_embeddings

View file

@ -16,6 +16,44 @@ client = TestClient(app)
# Test
# ----------------------------------------------------------------------------------------------------
def test_search_with_invalid_search_type():
# Arrange
user_query = "How to call semantic search from Emacs?"
# Act
response = client.get(f"/search?q={user_query}&t=invalid_search_type")
# Assert
assert response.status_code == 422
def test_search_with_valid_search_type():
# Arrange
for search_type in ["notes", "ledger", "music", "image"]:
# Act
response = client.get(f"/search?q=random&t={search_type}")
# Assert
assert response.status_code == 200
def test_regenerate_with_invalid_search_type():
# Act
response = client.get(f"/regenerate?t=invalid_search_type")
# Assert
assert response.status_code == 422
def test_regenerate_with_valid_search_type():
# Arrange
for search_type in ["notes", "ledger", "music", "image"]:
# Act
response = client.get(f"/regenerate?t={search_type}")
# Assert
assert response.status_code == 200
# ----------------------------------------------------------------------------------------------------
def test_asymmetric_setup():
# Arrange

9
src/utils/config.py Normal file
View file

@ -0,0 +1,9 @@
from enum import Enum
class SearchType(str, Enum):
Notes = "notes"
Ledger = "ledger"
Music = "music"
Image = "image"