mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-30 19:03:01 +01:00
Use type specific model for other search types too. Expose them via SearchModels
- Wrap Image, Music, Ledger search into the type of SearchModel they use. This is similar to what was done for the notes model by wrapping its config into an AsymmetricSearchModel. - Use the uber wrapper class to expose all type-specific search models
This commit is contained in:
parent
352d2930ee
commit
f4dd9cd117
2 changed files with 32 additions and 37 deletions
49
src/main.py
49
src/main.py
|
@ -38,43 +38,26 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
|
||||||
|
|
||||||
if (t == SearchType.Music or t == None) and search_settings.music_search_enabled:
|
if (t == SearchType.Music or t == None) and search_settings.music_search_enabled:
|
||||||
# query music library
|
# query music library
|
||||||
hits = asymmetric.query_notes(
|
hits = asymmetric.query_notes(user_query, model.music_search)
|
||||||
user_query,
|
|
||||||
song_embeddings,
|
|
||||||
songs,
|
|
||||||
song_encoder,
|
|
||||||
song_cross_encoder,
|
|
||||||
song_top_k)
|
|
||||||
|
|
||||||
# collate and return results
|
# collate and return results
|
||||||
return asymmetric.collate_results(hits, songs, results_count)
|
return asymmetric.collate_results(hits, model.music_search.entries, results_count)
|
||||||
|
|
||||||
if (t == SearchType.Ledger or t == None) and search_settings.ledger_search_enabled:
|
if (t == SearchType.Ledger or t == None) and search_settings.ledger_search_enabled:
|
||||||
# query transactions
|
# query transactions
|
||||||
hits = symmetric_ledger.query_transactions(
|
hits = symmetric_ledger.query_transactions(user_query, model.ledger_search)
|
||||||
user_query,
|
|
||||||
transaction_embeddings,
|
|
||||||
transactions,
|
|
||||||
symmetric_encoder,
|
|
||||||
symmetric_cross_encoder)
|
|
||||||
|
|
||||||
# collate and return results
|
# collate and return results
|
||||||
return symmetric_ledger.collate_results(hits, transactions, results_count)
|
return symmetric_ledger.collate_results(hits, model.ledger_search.entries, results_count)
|
||||||
|
|
||||||
if (t == SearchType.Image or t == None) and search_settings.image_search_enabled:
|
if (t == SearchType.Image or t == None) and search_settings.image_search_enabled:
|
||||||
# query transactions
|
# query transactions
|
||||||
hits = image_search.query_images(
|
hits = image_search.query_images(user_query, model.image_search, args.verbose)
|
||||||
user_query,
|
|
||||||
image_embeddings,
|
|
||||||
image_metadata_embeddings,
|
|
||||||
image_encoder,
|
|
||||||
results_count,
|
|
||||||
args.verbose)
|
|
||||||
|
|
||||||
# collate and return results
|
# collate and return results
|
||||||
return image_search.collate_results(
|
return image_search.collate_results(
|
||||||
hits,
|
hits,
|
||||||
image_names,
|
model.image_search.image_names,
|
||||||
image_config['input-directory'],
|
image_config['input-directory'],
|
||||||
results_count)
|
results_count)
|
||||||
|
|
||||||
|
@ -96,9 +79,7 @@ def regenerate(t: Optional[SearchType] = None):
|
||||||
|
|
||||||
if (t == SearchType.Music or t == None) and search_settings.music_search_enabled:
|
if (t == SearchType.Music or t == None) and search_settings.music_search_enabled:
|
||||||
# Extract Entries, Generate Song Embeddings
|
# Extract Entries, Generate Song Embeddings
|
||||||
global song_embeddings
|
model.music_search = asymmetric.setup(
|
||||||
global songs
|
|
||||||
songs, song_embeddings, _, _, _ = asymmetric.setup(
|
|
||||||
song_config['input-files'],
|
song_config['input-files'],
|
||||||
song_config['input-filter'],
|
song_config['input-filter'],
|
||||||
pathlib.Path(song_config['compressed-jsonl']),
|
pathlib.Path(song_config['compressed-jsonl']),
|
||||||
|
@ -108,9 +89,7 @@ def regenerate(t: Optional[SearchType] = None):
|
||||||
|
|
||||||
if (t == SearchType.Ledger or t == None) and search_settings.ledger_search_enabled:
|
if (t == SearchType.Ledger or t == None) and search_settings.ledger_search_enabled:
|
||||||
# Extract Entries, Generate Embeddings
|
# Extract Entries, Generate Embeddings
|
||||||
global transaction_embeddings
|
model.ledger_search = symmetric_ledger.setup(
|
||||||
global transactions
|
|
||||||
transactions, transaction_embeddings, _, _, _ = symmetric_ledger.setup(
|
|
||||||
ledger_config['input-files'],
|
ledger_config['input-files'],
|
||||||
ledger_config['input-filter'],
|
ledger_config['input-filter'],
|
||||||
pathlib.Path(ledger_config['compressed-jsonl']),
|
pathlib.Path(ledger_config['compressed-jsonl']),
|
||||||
|
@ -120,11 +99,7 @@ def regenerate(t: Optional[SearchType] = None):
|
||||||
|
|
||||||
if (t == SearchType.Image or t == None) and search_settings.image_search_enabled:
|
if (t == SearchType.Image or t == None) and search_settings.image_search_enabled:
|
||||||
# Extract Images, Generate Embeddings
|
# Extract Images, Generate Embeddings
|
||||||
global image_embeddings
|
model.image_search = image_search.setup(
|
||||||
global image_metadata_embeddings
|
|
||||||
global image_names
|
|
||||||
|
|
||||||
image_names, image_embeddings, image_metadata_embeddings, _ = image_search.setup(
|
|
||||||
pathlib.Path(image_config['input-directory']),
|
pathlib.Path(image_config['input-directory']),
|
||||||
pathlib.Path(image_config['embeddings-file']),
|
pathlib.Path(image_config['embeddings-file']),
|
||||||
regenerate=True,
|
regenerate=True,
|
||||||
|
@ -153,7 +128,7 @@ if __name__ == '__main__':
|
||||||
music_search_enabled = False
|
music_search_enabled = False
|
||||||
if song_config and ('input-files' in song_config or 'input-filter' in song_config):
|
if song_config and ('input-files' in song_config or 'input-filter' in song_config):
|
||||||
search_settings.music_search_enabled = True
|
search_settings.music_search_enabled = True
|
||||||
songs, song_embeddings, song_encoder, song_cross_encoder, song_top_k = asymmetric.setup(
|
model.music_search = asymmetric.setup(
|
||||||
song_config['input-files'],
|
song_config['input-files'],
|
||||||
song_config['input-filter'],
|
song_config['input-filter'],
|
||||||
pathlib.Path(song_config['compressed-jsonl']),
|
pathlib.Path(song_config['compressed-jsonl']),
|
||||||
|
@ -165,7 +140,7 @@ if __name__ == '__main__':
|
||||||
ledger_config = get_from_dict(args.config, 'content-type', 'ledger')
|
ledger_config = get_from_dict(args.config, 'content-type', 'ledger')
|
||||||
if ledger_config and ('input-files' in ledger_config or 'input-filter' in ledger_config):
|
if ledger_config and ('input-files' in ledger_config or 'input-filter' in ledger_config):
|
||||||
search_settings.ledger_search_enabled = True
|
search_settings.ledger_search_enabled = True
|
||||||
transactions, transaction_embeddings, symmetric_encoder, symmetric_cross_encoder, _ = symmetric_ledger.setup(
|
model.ledger_search = symmetric_ledger.setup(
|
||||||
ledger_config['input-files'],
|
ledger_config['input-files'],
|
||||||
ledger_config['input-filter'],
|
ledger_config['input-filter'],
|
||||||
pathlib.Path(ledger_config['compressed-jsonl']),
|
pathlib.Path(ledger_config['compressed-jsonl']),
|
||||||
|
@ -177,7 +152,7 @@ if __name__ == '__main__':
|
||||||
image_config = get_from_dict(args.config, 'content-type', 'image')
|
image_config = get_from_dict(args.config, 'content-type', 'image')
|
||||||
if image_config and 'input-directory' in image_config:
|
if image_config and 'input-directory' in image_config:
|
||||||
search_settings.image_search_enabled = True
|
search_settings.image_search_enabled = True
|
||||||
image_names, image_embeddings, image_metadata_embeddings, image_encoder = image_search.setup(
|
model.image_search = image_search.setup(
|
||||||
pathlib.Path(image_config['input-directory']),
|
pathlib.Path(image_config['input-directory']),
|
||||||
pathlib.Path(image_config['embeddings-file']),
|
pathlib.Path(image_config['embeddings-file']),
|
||||||
batch_size=image_config['batch-size'],
|
batch_size=image_config['batch-size'],
|
||||||
|
|
|
@ -27,6 +27,26 @@ class AsymmetricSearchModel():
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
|
|
||||||
|
|
||||||
|
class LedgerSearchModel():
    """Bundle of the state used to search ledger transactions.

    The constructor stores each argument on the instance under the same
    name; nothing is validated, copied or transformed.
    """

    def __init__(self, transactions, transaction_embeddings, symmetric_encoder, symmetric_cross_encoder, top_k):
        self.transactions = transactions
        self.transaction_embeddings = transaction_embeddings
        self.symmetric_encoder = symmetric_encoder
        self.symmetric_cross_encoder = symmetric_cross_encoder
        self.top_k = top_k

    @property
    def entries(self):
        """Alias for ``transactions``.

        The search endpoint collates ledger results from
        ``model.ledger_search.entries``; without this alias that attribute
        lookup fails because only ``transactions`` is stored.
        """
        return self.transactions

    @entries.setter
    def entries(self, value):
        # Keep plain assignment to ``entries`` working and in sync with
        # ``transactions`` so both names always refer to the same object.
        self.transactions = value
|
||||||
|
|
||||||
|
|
||||||
|
class ImageSearchModel():
    """Bundle of the state used to search images.

    The constructor stores each argument on the instance under the same
    name; nothing is validated, copied or transformed.
    """

    def __init__(self, image_names, image_embeddings, image_metadata_embeddings, image_encoder):
        self.image_names = image_names
        self.image_embeddings = image_embeddings
        self.image_metadata_embeddings = image_metadata_embeddings
        self.image_encoder = image_encoder
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class SearchModels():
    """Wrapper exposing one search model per search type.

    Every field defaults to ``None``, so a model is only present once it
    has been assigned.
    """
    notes_search: AsymmetricSearchModel = None
    ledger_search: LedgerSearchModel = None
    music_search: AsymmetricSearchModel = None
    image_search: ImageSearchModel = None
|
||||||
|
|
Loading…
Reference in a new issue