Use type specific model for other search types too. Expose them via SearchModels

- Wrap Image, Music, Ledger search into the type of SearchModel they use
  Similar to what was done for notes model by wrapping it's config
  into an AsymmetricSearchModel.

- Use the uber wrapper class to expose all type specific search models
This commit is contained in:
Debanjum Singh Solanky 2021-09-29 21:09:42 -07:00
parent 352d2930ee
commit f4dd9cd117
2 changed files with 32 additions and 37 deletions

View file

@ -38,43 +38,26 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
if (t == SearchType.Music or t == None) and search_settings.music_search_enabled: if (t == SearchType.Music or t == None) and search_settings.music_search_enabled:
# query music library # query music library
hits = asymmetric.query_notes( hits = asymmetric.query_notes(user_query, model.music_search)
user_query,
song_embeddings,
songs,
song_encoder,
song_cross_encoder,
song_top_k)
# collate and return results # collate and return results
return asymmetric.collate_results(hits, songs, results_count) return asymmetric.collate_results(hits, model.music_search.entries, results_count)
if (t == SearchType.Ledger or t == None) and search_settings.ledger_search_enabled: if (t == SearchType.Ledger or t == None) and search_settings.ledger_search_enabled:
# query transactions # query transactions
hits = symmetric_ledger.query_transactions( hits = symmetric_ledger.query_transactions(user_query, model.ledger_search)
user_query,
transaction_embeddings,
transactions,
symmetric_encoder,
symmetric_cross_encoder)
# collate and return results # collate and return results
return symmetric_ledger.collate_results(hits, transactions, results_count) return symmetric_ledger.collate_results(hits, model.ledger_search.entries, results_count)
if (t == SearchType.Image or t == None) and search_settings.image_search_enabled: if (t == SearchType.Image or t == None) and search_settings.image_search_enabled:
# query transactions # query transactions
hits = image_search.query_images( hits = image_search.query_images(user_query, model.image_search, args.verbose)
user_query,
image_embeddings,
image_metadata_embeddings,
image_encoder,
results_count,
args.verbose)
# collate and return results # collate and return results
return image_search.collate_results( return image_search.collate_results(
hits, hits,
image_names, model.image_search.image_names,
image_config['input-directory'], image_config['input-directory'],
results_count) results_count)
@ -96,9 +79,7 @@ def regenerate(t: Optional[SearchType] = None):
if (t == SearchType.Music or t == None) and search_settings.music_search_enabled: if (t == SearchType.Music or t == None) and search_settings.music_search_enabled:
# Extract Entries, Generate Song Embeddings # Extract Entries, Generate Song Embeddings
global song_embeddings model.music_search = asymmetric.setup(
global songs
songs, song_embeddings, _, _, _ = asymmetric.setup(
song_config['input-files'], song_config['input-files'],
song_config['input-filter'], song_config['input-filter'],
pathlib.Path(song_config['compressed-jsonl']), pathlib.Path(song_config['compressed-jsonl']),
@ -108,9 +89,7 @@ def regenerate(t: Optional[SearchType] = None):
if (t == SearchType.Ledger or t == None) and search_settings.ledger_search_enabled: if (t == SearchType.Ledger or t == None) and search_settings.ledger_search_enabled:
# Extract Entries, Generate Embeddings # Extract Entries, Generate Embeddings
global transaction_embeddings model.ledger_search = symmetric_ledger.setup(
global transactions
transactions, transaction_embeddings, _, _, _ = symmetric_ledger.setup(
ledger_config['input-files'], ledger_config['input-files'],
ledger_config['input-filter'], ledger_config['input-filter'],
pathlib.Path(ledger_config['compressed-jsonl']), pathlib.Path(ledger_config['compressed-jsonl']),
@ -120,11 +99,7 @@ def regenerate(t: Optional[SearchType] = None):
if (t == SearchType.Image or t == None) and search_settings.image_search_enabled: if (t == SearchType.Image or t == None) and search_settings.image_search_enabled:
# Extract Images, Generate Embeddings # Extract Images, Generate Embeddings
global image_embeddings model.image_search = image_search.setup(
global image_metadata_embeddings
global image_names
image_names, image_embeddings, image_metadata_embeddings, _ = image_search.setup(
pathlib.Path(image_config['input-directory']), pathlib.Path(image_config['input-directory']),
pathlib.Path(image_config['embeddings-file']), pathlib.Path(image_config['embeddings-file']),
regenerate=True, regenerate=True,
@ -153,7 +128,7 @@ if __name__ == '__main__':
music_search_enabled = False music_search_enabled = False
if song_config and ('input-files' in song_config or 'input-filter' in song_config): if song_config and ('input-files' in song_config or 'input-filter' in song_config):
search_settings.music_search_enabled = True search_settings.music_search_enabled = True
songs, song_embeddings, song_encoder, song_cross_encoder, song_top_k = asymmetric.setup( model.music_search = asymmetric.setup(
song_config['input-files'], song_config['input-files'],
song_config['input-filter'], song_config['input-filter'],
pathlib.Path(song_config['compressed-jsonl']), pathlib.Path(song_config['compressed-jsonl']),
@ -165,7 +140,7 @@ if __name__ == '__main__':
ledger_config = get_from_dict(args.config, 'content-type', 'ledger') ledger_config = get_from_dict(args.config, 'content-type', 'ledger')
if ledger_config and ('input-files' in ledger_config or 'input-filter' in ledger_config): if ledger_config and ('input-files' in ledger_config or 'input-filter' in ledger_config):
search_settings.ledger_search_enabled = True search_settings.ledger_search_enabled = True
transactions, transaction_embeddings, symmetric_encoder, symmetric_cross_encoder, _ = symmetric_ledger.setup( model.ledger_search = symmetric_ledger.setup(
ledger_config['input-files'], ledger_config['input-files'],
ledger_config['input-filter'], ledger_config['input-filter'],
pathlib.Path(ledger_config['compressed-jsonl']), pathlib.Path(ledger_config['compressed-jsonl']),
@ -177,7 +152,7 @@ if __name__ == '__main__':
image_config = get_from_dict(args.config, 'content-type', 'image') image_config = get_from_dict(args.config, 'content-type', 'image')
if image_config and 'input-directory' in image_config: if image_config and 'input-directory' in image_config:
search_settings.image_search_enabled = True search_settings.image_search_enabled = True
image_names, image_embeddings, image_metadata_embeddings, image_encoder = image_search.setup( model.image_search = image_search.setup(
pathlib.Path(image_config['input-directory']), pathlib.Path(image_config['input-directory']),
pathlib.Path(image_config['embeddings-file']), pathlib.Path(image_config['embeddings-file']),
batch_size=image_config['batch-size'], batch_size=image_config['batch-size'],

View file

@ -27,6 +27,26 @@ class AsymmetricSearchModel():
self.top_k = top_k self.top_k = top_k
class LedgerSearchModel():
def __init__(self, transactions, transaction_embeddings, symmetric_encoder, symmetric_cross_encoder, top_k):
self.transactions = transactions
self.transaction_embeddings = transaction_embeddings
self.symmetric_encoder = symmetric_encoder
self.symmetric_cross_encoder = symmetric_cross_encoder
self.top_k = top_k
class ImageSearchModel():
def __init__(self, image_names, image_embeddings, image_metadata_embeddings, image_encoder):
self.image_names = image_names
self.image_embeddings = image_embeddings
self.image_metadata_embeddings = image_metadata_embeddings
self.image_encoder = image_encoder
@dataclass @dataclass
class SearchModels(): class SearchModels():
notes_search: AsymmetricSearchModel = None notes_search: AsymmetricSearchModel = None
ledger_search: LedgerSearchModel = None
music_search: AsymmetricSearchModel = None
image_search: ImageSearchModel = None