From cde11a2331a3b728f20d132ffb25b5c6c4fa959d Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Wed, 29 Sep 2021 19:18:33 -0700 Subject: [PATCH] Wrap search type enablement status in a search settings class - Cleaner, more idiomatic usage of a global variable - Simplifies mocking when testing client in pytest as setting wrapped in object rather than a simple type. So passed around by reference --- src/main.py | 30 ++++++++++++++---------------- src/utils/config.py | 11 +++++++++++ 2 files changed, 25 insertions(+), 16 deletions(-) diff --git a/src/main.py b/src/main.py index d041d312..069dc903 100644 --- a/src/main.py +++ b/src/main.py @@ -11,9 +11,10 @@ from fastapi import FastAPI from search_type import asymmetric, symmetric_ledger, image_search from utils.helpers import get_from_dict from utils.cli import cli -from utils.config import SearchType +from utils.config import SearchType, SearchSettings +search_settings = SearchSettings() app = FastAPI() @@ -26,7 +27,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None): user_query = q results_count = n - if (t == SearchType.Notes or t == None) and notes_search_enabled: + if (t == SearchType.Notes or t == None) and search_settings.notes_search_enabled: # query notes hits = asymmetric.query_notes( user_query, @@ -39,7 +40,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None): # collate and return results return asymmetric.collate_results(hits, entries, results_count) - if (t == SearchType.Music or t == None) and music_search_enabled: + if (t == SearchType.Music or t == None) and search_settings.music_search_enabled: # query music library hits = asymmetric.query_notes( user_query, @@ -52,7 +53,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None): # collate and return results return asymmetric.collate_results(hits, songs, results_count) - if (t == SearchType.Ledger or t == None) and ledger_search_enabled: + if (t == SearchType.Ledger or t == None) and search_settings.ledger_search_enabled: # query transactions hits = symmetric_ledger.query_transactions( user_query, @@ -64,7 +65,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None): # collate and return results return symmetric_ledger.collate_results(hits, transactions, results_count) - if (t == SearchType.Image or t == None) and image_search_enabled: + if (t == SearchType.Image or t == None) and search_settings.image_search_enabled: # query transactions hits = image_search.query_images( user_query, @@ -87,7 +88,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None): @app.get('/regenerate') def regenerate(t: Optional[SearchType] = None): - if (t == SearchType.Notes or t == None) and notes_search_enabled: + if (t == SearchType.Notes or t == None) and search_settings.notes_search_enabled: # Extract Entries, Generate Embeddings global corpus_embeddings global entries @@ -99,7 +100,7 @@ def regenerate(t: Optional[SearchType] = None): regenerate=True, verbose=args.verbose) - if (t == SearchType.Music or t == None) and music_search_enabled: + if (t == SearchType.Music or t == None) and search_settings.music_search_enabled: # Extract Entries, Generate Song Embeddings global song_embeddings global songs @@ -111,7 +112,7 @@ def regenerate(t: Optional[SearchType] = None): regenerate=True, verbose=args.verbose) - if (t == SearchType.Ledger or t == None) and ledger_search_enabled: + if (t == SearchType.Ledger or t == None) and search_settings.ledger_search_enabled: # Extract Entries, Generate Embeddings global transaction_embeddings global transactions @@ -123,7 +124,7 @@ def regenerate(t: Optional[SearchType] = None): regenerate=True, verbose=args.verbose) - if (t == SearchType.Image or t == None) and image_search_enabled: + if (t == SearchType.Image or t == None) and search_settings.image_search_enabled: # Extract Images, Generate Embeddings global image_embeddings global image_metadata_embeddings @@ -143,9 +144,8 @@ if __name__ == '__main__': # Initialize Org Notes Search org_config = get_from_dict(args.config, 'content-type', 'org') - notes_search_enabled = False if org_config and ('input-files' in org_config or 'input-filter' in org_config): - notes_search_enabled = True + search_settings.notes_search_enabled = True entries, corpus_embeddings, bi_encoder, cross_encoder, top_k = asymmetric.setup( org_config['input-files'], org_config['input-filter'], @@ -158,7 +158,7 @@ if __name__ == '__main__': song_config = get_from_dict(args.config, 'content-type', 'music') music_search_enabled = False if song_config and ('input-files' in song_config or 'input-filter' in song_config): - music_search_enabled = True + search_settings.music_search_enabled = True songs, song_embeddings, song_encoder, song_cross_encoder, song_top_k = asymmetric.setup( song_config['input-files'], song_config['input-filter'], @@ -169,9 +169,8 @@ if __name__ == '__main__': # Initialize Ledger Search ledger_config = get_from_dict(args.config, 'content-type', 'ledger') - ledger_search_enabled = False if ledger_config and ('input-files' in ledger_config or 'input-filter' in ledger_config): - ledger_search_enabled = True + search_settings.ledger_search_enabled = True transactions, transaction_embeddings, symmetric_encoder, symmetric_cross_encoder, _ = symmetric_ledger.setup( ledger_config['input-files'], ledger_config['input-filter'], @@ -182,9 +181,8 @@ if __name__ == '__main__': # Initialize Image Search image_config = get_from_dict(args.config, 'content-type', 'image') - image_search_enabled = False if image_config and 'input-directory' in image_config: - image_search_enabled = True + search_settings.image_search_enabled = True image_names, image_embeddings, image_metadata_embeddings, image_encoder = image_search.setup( pathlib.Path(image_config['input-directory']), pathlib.Path(image_config['embeddings-file']), diff --git a/src/utils/config.py b/src/utils/config.py index 8f42a396..eeecc6a6 100644 --- a/src/utils/config.py +++ b/src/utils/config.py @@ -1,4 +1,6 @@ +# System Packages from enum import Enum +from dataclasses import dataclass class SearchType(str, Enum): @@ -7,3 +9,12 @@ class SearchType(str, Enum): Music = "music" Image = "image" + +@dataclass +class SearchSettings(): + notes_search_enabled: bool = False + ledger_search_enabled: bool = False + music_search_enabled: bool = False + image_search_enabled: bool = False + +