From 81ce0cacc366bd0bdaa10942f394beb23a81d061 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Wed, 29 Sep 2021 19:02:55 -0700 Subject: [PATCH] Only allow supported search types to /search, /regenerate APIs - Use a SearchType to limit types that can be passed by user - FastAPI automatically validates type passed in query param - Available type options show up in Swagger UI, FastAPI docs - controller code looks neater instead of doing string comparisons for type - Test invalid, valid search types via pytest --- src/main.py | 21 +++++++++++---------- src/tests/test_main.py | 38 ++++++++++++++++++++++++++++++++++++++ src/utils/config.py | 9 +++++++++ 3 files changed, 58 insertions(+), 10 deletions(-) create mode 100644 src/utils/config.py diff --git a/src/main.py b/src/main.py index d0d78bd0..d041d312 100644 --- a/src/main.py +++ b/src/main.py @@ -11,13 +11,14 @@ from fastapi import FastAPI from search_type import asymmetric, symmetric_ledger, image_search from utils.helpers import get_from_dict from utils.cli import cli +from utils.config import SearchType app = FastAPI() @app.get('/search') -def search(q: str, n: Optional[int] = 5, t: Optional[str] = None): +def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None): if q is None or q == '': print(f'No query param (q) passed in API call to initiate search') return {} @@ -25,7 +26,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[str] = None): user_query = q results_count = n - if (t == 'notes' or t == None) and notes_search_enabled: + if (t == SearchType.Notes or t == None) and notes_search_enabled: # query notes hits = asymmetric.query_notes( user_query, @@ -38,7 +39,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[str] = None): # collate and return results return asymmetric.collate_results(hits, entries, results_count) - if (t == 'music' or t == None) and music_search_enabled: + if (t == SearchType.Music or t == None) and music_search_enabled: # query music library hits = asymmetric.query_notes( user_query, @@ -51,7 +52,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[str] = None): # collate and return results return asymmetric.collate_results(hits, songs, results_count) - if (t == 'ledger' or t == None) and ledger_search_enabled: + if (t == SearchType.Ledger or t == None) and ledger_search_enabled: # query transactions hits = symmetric_ledger.query_transactions( user_query, @@ -63,7 +64,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[str] = None): # collate and return results return symmetric_ledger.collate_results(hits, transactions, results_count) - if (t == 'image' or t == None) and image_search_enabled: + if (t == SearchType.Image or t == None) and image_search_enabled: # query transactions hits = image_search.query_images( user_query, @@ -85,8 +86,8 @@ def search(q: str, n: Optional[int] = 5, t: Optional[str] = None): @app.get('/regenerate') -def regenerate(t: Optional[str] = None): - if (t == 'notes' or t == None) and notes_search_enabled: +def regenerate(t: Optional[SearchType] = None): + if (t == SearchType.Notes or t == None) and notes_search_enabled: # Extract Entries, Generate Embeddings global corpus_embeddings global entries @@ -98,7 +99,7 @@ def regenerate(t: Optional[str] = None): regenerate=True, verbose=args.verbose) - if (t == 'music' or t == None) and music_search_enabled: + if (t == SearchType.Music or t == None) and music_search_enabled: # Extract Entries, Generate Song Embeddings global song_embeddings global songs @@ -110,7 +111,7 @@ def regenerate(t: Optional[str] = None): regenerate=True, verbose=args.verbose) - if (t == 'ledger' or t == None) and ledger_search_enabled: + if (t == SearchType.Ledger or t == None) and ledger_search_enabled: # Extract Entries, Generate Embeddings global transaction_embeddings global transactions @@ -122,7 +123,7 @@ def regenerate(t: Optional[str] = None): regenerate=True, verbose=args.verbose) - if (t == 'image' or t == None) and image_search_enabled: + if (t == SearchType.Image or t == None) and image_search_enabled: # Extract Images, Generate Embeddings global image_embeddings global image_metadata_embeddings diff --git a/src/tests/test_main.py b/src/tests/test_main.py index 059d9190..764d9b22 100644 --- a/src/tests/test_main.py +++ b/src/tests/test_main.py @@ -16,6 +16,44 @@ client = TestClient(app) # Test +# ---------------------------------------------------------------------------------------------------- +def test_search_with_invalid_search_type(): + # Arrange + user_query = "How to call semantic search from Emacs?" + + # Act + response = client.get(f"/search?q={user_query}&t=invalid_search_type") + + # Assert + assert response.status_code == 422 + + +def test_search_with_valid_search_type(): + # Arrange + for search_type in ["notes", "ledger", "music", "image"]: + # Act + response = client.get(f"/search?q=random&t={search_type}") + # Assert + assert response.status_code == 200 + + +def test_regenerate_with_invalid_search_type(): + # Act + response = client.get(f"/regenerate?t=invalid_search_type") + + # Assert + assert response.status_code == 422 + + +def test_regenerate_with_valid_search_type(): + # Arrange + for search_type in ["notes", "ledger", "music", "image"]: + # Act + response = client.get(f"/regenerate?t={search_type}") + # Assert + assert response.status_code == 200 + + # ---------------------------------------------------------------------------------------------------- def test_asymmetric_setup(): # Arrange diff --git a/src/utils/config.py b/src/utils/config.py new file mode 100644 index 00000000..8f42a396 --- /dev/null +++ b/src/utils/config.py @@ -0,0 +1,9 @@ +from enum import Enum + + +class SearchType(str, Enum): + Notes = "notes" + Ledger = "ledger" + Music = "music" + Image = "image" +