diff --git a/src/main.py b/src/main.py index ce18b6a1..c55462a1 100644 --- a/src/main.py +++ b/src/main.py @@ -38,43 +38,26 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None): if (t == SearchType.Music or t == None) and search_settings.music_search_enabled: # query music library - hits = asymmetric.query_notes( - user_query, - song_embeddings, - songs, - song_encoder, - song_cross_encoder, - song_top_k) + hits = asymmetric.query_notes(user_query, model.music_search) # collate and return results - return asymmetric.collate_results(hits, songs, results_count) + return asymmetric.collate_results(hits, model.music_search.entries, results_count) if (t == SearchType.Ledger or t == None) and search_settings.ledger_search_enabled: # query transactions - hits = symmetric_ledger.query_transactions( - user_query, - transaction_embeddings, - transactions, - symmetric_encoder, - symmetric_cross_encoder) + hits = symmetric_ledger.query_transactions(user_query, model.ledger_search) # collate and return results - return symmetric_ledger.collate_results(hits, transactions, results_count) + return symmetric_ledger.collate_results(hits, model.ledger_search.entries, results_count) if (t == SearchType.Image or t == None) and search_settings.image_search_enabled: # query transactions - hits = image_search.query_images( - user_query, - image_embeddings, - image_metadata_embeddings, - image_encoder, - results_count, - args.verbose) + hits = image_search.query_images(user_query, model.image_search, args.verbose) # collate and return results return image_search.collate_results( hits, - image_names, + model.image_search.image_names, image_config['input-directory'], results_count) @@ -96,9 +79,7 @@ def regenerate(t: Optional[SearchType] = None): if (t == SearchType.Music or t == None) and search_settings.music_search_enabled: # Extract Entries, Generate Song Embeddings - global song_embeddings - global songs - songs, song_embeddings, _, _, _ = asymmetric.setup( + model.music_search = asymmetric.setup( song_config['input-files'], song_config['input-filter'], pathlib.Path(song_config['compressed-jsonl']), @@ -108,9 +89,7 @@ def regenerate(t: Optional[SearchType] = None): if (t == SearchType.Ledger or t == None) and search_settings.ledger_search_enabled: # Extract Entries, Generate Embeddings - global transaction_embeddings - global transactions - transactions, transaction_embeddings, _, _, _ = symmetric_ledger.setup( + model.ledger_search = symmetric_ledger.setup( ledger_config['input-files'], ledger_config['input-filter'], pathlib.Path(ledger_config['compressed-jsonl']), @@ -120,11 +99,7 @@ def regenerate(t: Optional[SearchType] = None): if (t == SearchType.Image or t == None) and search_settings.image_search_enabled: # Extract Images, Generate Embeddings - global image_embeddings - global image_metadata_embeddings - global image_names - - image_names, image_embeddings, image_metadata_embeddings, _ = image_search.setup( + model.image_search = image_search.setup( pathlib.Path(image_config['input-directory']), pathlib.Path(image_config['embeddings-file']), regenerate=True, @@ -153,7 +128,7 @@ if __name__ == '__main__': music_search_enabled = False if song_config and ('input-files' in song_config or 'input-filter' in song_config): search_settings.music_search_enabled = True - songs, song_embeddings, song_encoder, song_cross_encoder, song_top_k = asymmetric.setup( + model.music_search = asymmetric.setup( song_config['input-files'], song_config['input-filter'], pathlib.Path(song_config['compressed-jsonl']), @@ -165,7 +140,7 @@ if __name__ == '__main__': ledger_config = get_from_dict(args.config, 'content-type', 'ledger') if ledger_config and ('input-files' in ledger_config or 'input-filter' in ledger_config): search_settings.ledger_search_enabled = True - transactions, transaction_embeddings, symmetric_encoder, symmetric_cross_encoder, _ = symmetric_ledger.setup( + model.ledger_search = symmetric_ledger.setup( ledger_config['input-files'], ledger_config['input-filter'], pathlib.Path(ledger_config['compressed-jsonl']), @@ -177,7 +152,7 @@ if __name__ == '__main__': image_config = get_from_dict(args.config, 'content-type', 'image') if image_config and 'input-directory' in image_config: search_settings.image_search_enabled = True - image_names, image_embeddings, image_metadata_embeddings, image_encoder = image_search.setup( + model.image_search = image_search.setup( pathlib.Path(image_config['input-directory']), pathlib.Path(image_config['embeddings-file']), batch_size=image_config['batch-size'], diff --git a/src/utils/config.py b/src/utils/config.py index dc09397b..0f0bf960 100644 --- a/src/utils/config.py +++ b/src/utils/config.py @@ -27,6 +27,26 @@ class AsymmetricSearchModel(): self.top_k = top_k +class LedgerSearchModel(): + def __init__(self, transactions, transaction_embeddings, symmetric_encoder, symmetric_cross_encoder, top_k): + self.transactions = transactions + self.transaction_embeddings = transaction_embeddings + self.symmetric_encoder = symmetric_encoder + self.symmetric_cross_encoder = symmetric_cross_encoder + self.top_k = top_k + + +class ImageSearchModel(): + def __init__(self, image_names, image_embeddings, image_metadata_embeddings, image_encoder): + self.image_names = image_names + self.image_embeddings = image_embeddings + self.image_metadata_embeddings = image_metadata_embeddings + self.image_encoder = image_encoder + + @dataclass class SearchModels(): notes_search: AsymmetricSearchModel = None + ledger_search: LedgerSearchModel = None + music_search: AsymmetricSearchModel = None + image_search: ImageSearchModel = None