mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-27 09:25:06 +01:00
Use type-specific models for other search types too. Expose them via SearchModels
- Wrap Image, Music, Ledger search into the type of SearchModel they use, similar to what was done for the notes model by wrapping its config into an AsymmetricSearchModel. - Use the uber wrapper class to expose all type-specific search models
This commit is contained in:
parent
352d2930ee
commit
f4dd9cd117
2 changed files with 32 additions and 37 deletions
49
src/main.py
49
src/main.py
|
@ -38,43 +38,26 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
|
|||
|
||||
if (t == SearchType.Music or t == None) and search_settings.music_search_enabled:
|
||||
# query music library
|
||||
hits = asymmetric.query_notes(
|
||||
user_query,
|
||||
song_embeddings,
|
||||
songs,
|
||||
song_encoder,
|
||||
song_cross_encoder,
|
||||
song_top_k)
|
||||
hits = asymmetric.query_notes(user_query, model.music_search)
|
||||
|
||||
# collate and return results
|
||||
return asymmetric.collate_results(hits, songs, results_count)
|
||||
return asymmetric.collate_results(hits, model.music_search.entries, results_count)
|
||||
|
||||
if (t == SearchType.Ledger or t == None) and search_settings.ledger_search_enabled:
|
||||
# query transactions
|
||||
hits = symmetric_ledger.query_transactions(
|
||||
user_query,
|
||||
transaction_embeddings,
|
||||
transactions,
|
||||
symmetric_encoder,
|
||||
symmetric_cross_encoder)
|
||||
hits = symmetric_ledger.query_transactions(user_query, model.ledger_search)
|
||||
|
||||
# collate and return results
|
||||
return symmetric_ledger.collate_results(hits, transactions, results_count)
|
||||
return symmetric_ledger.collate_results(hits, model.ledger_search.entries, results_count)
|
||||
|
||||
if (t == SearchType.Image or t == None) and search_settings.image_search_enabled:
|
||||
# query images
|
||||
hits = image_search.query_images(
|
||||
user_query,
|
||||
image_embeddings,
|
||||
image_metadata_embeddings,
|
||||
image_encoder,
|
||||
results_count,
|
||||
args.verbose)
|
||||
hits = image_search.query_images(user_query, model.image_search, args.verbose)
|
||||
|
||||
# collate and return results
|
||||
return image_search.collate_results(
|
||||
hits,
|
||||
image_names,
|
||||
model.image_search.image_names,
|
||||
image_config['input-directory'],
|
||||
results_count)
|
||||
|
||||
|
@ -96,9 +79,7 @@ def regenerate(t: Optional[SearchType] = None):
|
|||
|
||||
if (t == SearchType.Music or t == None) and search_settings.music_search_enabled:
|
||||
# Extract Entries, Generate Song Embeddings
|
||||
global song_embeddings
|
||||
global songs
|
||||
songs, song_embeddings, _, _, _ = asymmetric.setup(
|
||||
model.music_search = asymmetric.setup(
|
||||
song_config['input-files'],
|
||||
song_config['input-filter'],
|
||||
pathlib.Path(song_config['compressed-jsonl']),
|
||||
|
@ -108,9 +89,7 @@ def regenerate(t: Optional[SearchType] = None):
|
|||
|
||||
if (t == SearchType.Ledger or t == None) and search_settings.ledger_search_enabled:
|
||||
# Extract Entries, Generate Embeddings
|
||||
global transaction_embeddings
|
||||
global transactions
|
||||
transactions, transaction_embeddings, _, _, _ = symmetric_ledger.setup(
|
||||
model.ledger_search = symmetric_ledger.setup(
|
||||
ledger_config['input-files'],
|
||||
ledger_config['input-filter'],
|
||||
pathlib.Path(ledger_config['compressed-jsonl']),
|
||||
|
@ -120,11 +99,7 @@ def regenerate(t: Optional[SearchType] = None):
|
|||
|
||||
if (t == SearchType.Image or t == None) and search_settings.image_search_enabled:
|
||||
# Extract Images, Generate Embeddings
|
||||
global image_embeddings
|
||||
global image_metadata_embeddings
|
||||
global image_names
|
||||
|
||||
image_names, image_embeddings, image_metadata_embeddings, _ = image_search.setup(
|
||||
model.image_search = image_search.setup(
|
||||
pathlib.Path(image_config['input-directory']),
|
||||
pathlib.Path(image_config['embeddings-file']),
|
||||
regenerate=True,
|
||||
|
@ -153,7 +128,7 @@ if __name__ == '__main__':
|
|||
music_search_enabled = False
|
||||
if song_config and ('input-files' in song_config or 'input-filter' in song_config):
|
||||
search_settings.music_search_enabled = True
|
||||
songs, song_embeddings, song_encoder, song_cross_encoder, song_top_k = asymmetric.setup(
|
||||
model.music_search = asymmetric.setup(
|
||||
song_config['input-files'],
|
||||
song_config['input-filter'],
|
||||
pathlib.Path(song_config['compressed-jsonl']),
|
||||
|
@ -165,7 +140,7 @@ if __name__ == '__main__':
|
|||
ledger_config = get_from_dict(args.config, 'content-type', 'ledger')
|
||||
if ledger_config and ('input-files' in ledger_config or 'input-filter' in ledger_config):
|
||||
search_settings.ledger_search_enabled = True
|
||||
transactions, transaction_embeddings, symmetric_encoder, symmetric_cross_encoder, _ = symmetric_ledger.setup(
|
||||
model.ledger_search = symmetric_ledger.setup(
|
||||
ledger_config['input-files'],
|
||||
ledger_config['input-filter'],
|
||||
pathlib.Path(ledger_config['compressed-jsonl']),
|
||||
|
@ -177,7 +152,7 @@ if __name__ == '__main__':
|
|||
image_config = get_from_dict(args.config, 'content-type', 'image')
|
||||
if image_config and 'input-directory' in image_config:
|
||||
search_settings.image_search_enabled = True
|
||||
image_names, image_embeddings, image_metadata_embeddings, image_encoder = image_search.setup(
|
||||
model.image_search = image_search.setup(
|
||||
pathlib.Path(image_config['input-directory']),
|
||||
pathlib.Path(image_config['embeddings-file']),
|
||||
batch_size=image_config['batch-size'],
|
||||
|
|
|
@ -27,6 +27,26 @@ class AsymmetricSearchModel():
|
|||
self.top_k = top_k
|
||||
|
||||
|
||||
class LedgerSearchModel():
    """Holder for the objects used to search ledger transactions.

    The constructor stores every argument unchanged on the instance under
    the same name; no validation or conversion is performed.

    The ``entries`` property is an alias for ``transactions``. The search
    endpoint collates ledger hits via ``model.ledger_search.entries``, which
    would raise ``AttributeError`` if only ``transactions`` existed.
    """

    def __init__(self, transactions, transaction_embeddings, symmetric_encoder, symmetric_cross_encoder, top_k):
        self.transactions = transactions
        self.transaction_embeddings = transaction_embeddings
        self.symmetric_encoder = symmetric_encoder
        self.symmetric_cross_encoder = symmetric_cross_encoder
        self.top_k = top_k

    @property
    def entries(self):
        """Return the stored transactions (alias of ``transactions``)."""
        return self.transactions

    @entries.setter
    def entries(self, value):
        # Keep assignment working as it did when ``entries`` was a plain
        # attribute name, while keeping both names pointing at one object.
        self.transactions = value
|
||||
|
||||
|
||||
class ImageSearchModel:
    """Holder for the objects used to search images.

    Each constructor argument is stored unchanged on the instance under the
    attribute of the same name.
    """

    def __init__(self, image_names, image_embeddings, image_metadata_embeddings, image_encoder):
        # Paired assignments keep the original left-to-right attribute order.
        self.image_names, self.image_embeddings = image_names, image_embeddings
        self.image_metadata_embeddings, self.image_encoder = image_metadata_embeddings, image_encoder
|
||||
|
||||
|
||||
@dataclass
class SearchModels:
    """Wrapper exposing one search model per content type.

    Every field defaults to ``None`` and holds whatever model object is
    assigned to it.
    """

    # Field order is part of the generated __init__ signature; keep it stable.
    notes_search: AsymmetricSearchModel = None
    ledger_search: LedgerSearchModel = None
    music_search: AsymmetricSearchModel = None
    image_search: ImageSearchModel = None
|
||||
|
|
Loading…
Reference in a new issue