mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-27 17:35:07 +01:00
Modularize Code. Wrap Search, Model Config in Classes. Add Tests
Details - Rename method query_* to query in search_types for standardization - Wrapping Config code in classes simplified mocking test config - Reduce args being passed to a function by passing them as a single argument wrapped in a class - Minimize setup in main.py:__main__. Put most of it into functions. These functions can be mocked if required in tests later too. Setup Flow: CLI_Args|Config_YAML -> (Text|Image)SearchConfig -> (Text|Image)SearchModel
This commit is contained in:
parent
f4dd9cd117
commit
d5597442f4
6 changed files with 201 additions and 154 deletions
128
src/main.py
128
src/main.py
|
@ -11,12 +11,12 @@ from fastapi import FastAPI
|
||||||
from search_type import asymmetric, symmetric_ledger, image_search
|
from search_type import asymmetric, symmetric_ledger, image_search
|
||||||
from utils.helpers import get_from_dict
|
from utils.helpers import get_from_dict
|
||||||
from utils.cli import cli
|
from utils.cli import cli
|
||||||
from utils.config import SearchType, SearchSettings, SearchModels
|
from utils.config import SearchType, SearchModels, TextSearchConfig, ImageSearchConfig, SearchConfig
|
||||||
|
|
||||||
|
|
||||||
# Application Global State
|
# Application Global State
|
||||||
model = SearchModels()
|
model = SearchModels()
|
||||||
search_settings = SearchSettings()
|
search_config = SearchConfig()
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
|
@ -29,36 +29,36 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
|
||||||
user_query = q
|
user_query = q
|
||||||
results_count = n
|
results_count = n
|
||||||
|
|
||||||
if (t == SearchType.Notes or t == None) and search_settings.notes_search_enabled:
|
if (t == SearchType.Notes or t == None) and model.notes_search:
|
||||||
# query notes
|
# query notes
|
||||||
hits = asymmetric.query_notes(user_query, model.notes_search)
|
hits = asymmetric.query(user_query, model.notes_search)
|
||||||
|
|
||||||
# collate and return results
|
# collate and return results
|
||||||
return asymmetric.collate_results(hits, model.notes_search.entries, results_count)
|
return asymmetric.collate_results(hits, model.notes_search.entries, results_count)
|
||||||
|
|
||||||
if (t == SearchType.Music or t == None) and search_settings.music_search_enabled:
|
if (t == SearchType.Music or t == None) and model.music_search:
|
||||||
# query music library
|
# query music library
|
||||||
hits = asymmetric.query_notes(user_query, model.music_search)
|
hits = asymmetric.query(user_query, model.music_search)
|
||||||
|
|
||||||
# collate and return results
|
# collate and return results
|
||||||
return asymmetric.collate_results(hits, model.music_search.entries, results_count)
|
return asymmetric.collate_results(hits, model.music_search.entries, results_count)
|
||||||
|
|
||||||
if (t == SearchType.Ledger or t == None) and search_settings.ledger_search_enabled:
|
if (t == SearchType.Ledger or t == None) and model.ledger_search:
|
||||||
# query transactions
|
# query transactions
|
||||||
hits = symmetric_ledger.query_transactions(user_query, model.ledger_search)
|
hits = symmetric_ledger.query(user_query, model.ledger_search)
|
||||||
|
|
||||||
# collate and return results
|
# collate and return results
|
||||||
return symmetric_ledger.collate_results(hits, model.ledger_search.entries, results_count)
|
return symmetric_ledger.collate_results(hits, model.ledger_search.entries, results_count)
|
||||||
|
|
||||||
if (t == SearchType.Image or t == None) and search_settings.image_search_enabled:
|
if (t == SearchType.Image or t == None) and model.image_search:
|
||||||
# query transactions
|
# query transactions
|
||||||
hits = image_search.query_images(user_query, model.image_search, args.verbose)
|
hits = image_search.query(user_query, results_count, model.image_search)
|
||||||
|
|
||||||
# collate and return results
|
# collate and return results
|
||||||
return image_search.collate_results(
|
return image_search.collate_results(
|
||||||
hits,
|
hits,
|
||||||
model.image_search.image_names,
|
model.image_search.image_names,
|
||||||
image_config['input-directory'],
|
search_config.image.input_directory,
|
||||||
results_count)
|
results_count)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
@ -67,98 +67,58 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
|
||||||
|
|
||||||
@app.get('/regenerate')
|
@app.get('/regenerate')
|
||||||
def regenerate(t: Optional[SearchType] = None):
|
def regenerate(t: Optional[SearchType] = None):
|
||||||
if (t == SearchType.Notes or t == None) and search_settings.notes_search_enabled:
|
if (t == SearchType.Notes or t == None) and search_config.notes:
|
||||||
# Extract Entries, Generate Embeddings
|
# Extract Entries, Generate Embeddings
|
||||||
models.notes_search = asymmetric.setup(
|
model.notes_search = asymmetric.setup(search_config.notes, regenerate=True)
|
||||||
org_config['input-files'],
|
|
||||||
org_config['input-filter'],
|
|
||||||
pathlib.Path(org_config['compressed-jsonl']),
|
|
||||||
pathlib.Path(org_config['embeddings-file']),
|
|
||||||
regenerate=True,
|
|
||||||
verbose=args.verbose)
|
|
||||||
|
|
||||||
if (t == SearchType.Music or t == None) and search_settings.music_search_enabled:
|
if (t == SearchType.Music or t == None) and search_config.music:
|
||||||
# Extract Entries, Generate Song Embeddings
|
# Extract Entries, Generate Song Embeddings
|
||||||
model.music_search = asymmetric.setup(
|
model.music_search = asymmetric.setup(search_config.music, regenerate=True)
|
||||||
song_config['input-files'],
|
|
||||||
song_config['input-filter'],
|
|
||||||
pathlib.Path(song_config['compressed-jsonl']),
|
|
||||||
pathlib.Path(song_config['embeddings-file']),
|
|
||||||
regenerate=True,
|
|
||||||
verbose=args.verbose)
|
|
||||||
|
|
||||||
if (t == SearchType.Ledger or t == None) and search_settings.ledger_search_enabled:
|
if (t == SearchType.Ledger or t == None) and search_config.ledger:
|
||||||
# Extract Entries, Generate Embeddings
|
# Extract Entries, Generate Embeddings
|
||||||
model.ledger_search = symmetric_ledger.setup(
|
model.ledger_search = symmetric_ledger.setup(search_config.ledger, regenerate=True)
|
||||||
ledger_config['input-files'],
|
|
||||||
ledger_config['input-filter'],
|
|
||||||
pathlib.Path(ledger_config['compressed-jsonl']),
|
|
||||||
pathlib.Path(ledger_config['embeddings-file']),
|
|
||||||
regenerate=True,
|
|
||||||
verbose=args.verbose)
|
|
||||||
|
|
||||||
if (t == SearchType.Image or t == None) and search_settings.image_search_enabled:
|
if (t == SearchType.Image or t == None) and search_config.image:
|
||||||
# Extract Images, Generate Embeddings
|
# Extract Images, Generate Embeddings
|
||||||
model.image_search = image_search.setup(
|
model.image_search = image_search.setup(search_config.image, regenerate=True)
|
||||||
pathlib.Path(image_config['input-directory']),
|
|
||||||
pathlib.Path(image_config['embeddings-file']),
|
|
||||||
regenerate=True,
|
|
||||||
verbose=args.verbose)
|
|
||||||
|
|
||||||
return {'status': 'ok', 'message': 'regeneration completed'}
|
return {'status': 'ok', 'message': 'regeneration completed'}
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
def initialize_search(config, regenerate, verbose):
|
||||||
args = cli(sys.argv[1:])
|
model = SearchModels()
|
||||||
|
search_config = SearchConfig()
|
||||||
|
|
||||||
# Initialize Org Notes Search
|
# Initialize Org Notes Search
|
||||||
org_config = get_from_dict(args.config, 'content-type', 'org')
|
search_config.notes = TextSearchConfig.create_from_dictionary(config, ('content-type', 'org'), verbose)
|
||||||
if org_config and ('input-files' in org_config or 'input-filter' in org_config):
|
if search_config.notes:
|
||||||
search_settings.notes_search_enabled = True
|
model.notes_search = asymmetric.setup(search_config.notes, regenerate=regenerate)
|
||||||
model.notes_search = asymmetric.setup(
|
|
||||||
org_config['input-files'],
|
|
||||||
org_config['input-filter'],
|
|
||||||
pathlib.Path(org_config['compressed-jsonl']),
|
|
||||||
pathlib.Path(org_config['embeddings-file']),
|
|
||||||
args.regenerate,
|
|
||||||
args.verbose)
|
|
||||||
|
|
||||||
# Initialize Org Music Search
|
# Initialize Org Music Search
|
||||||
song_config = get_from_dict(args.config, 'content-type', 'music')
|
search_config.music = TextSearchConfig.create_from_dictionary(config, ('content-type', 'music'), verbose)
|
||||||
music_search_enabled = False
|
if search_config.music:
|
||||||
if song_config and ('input-files' in song_config or 'input-filter' in song_config):
|
model.music_search = asymmetric.setup(search_config.music, regenerate=regenerate)
|
||||||
search_settings.music_search_enabled = True
|
|
||||||
model.music_search = asymmetric.setup(
|
|
||||||
song_config['input-files'],
|
|
||||||
song_config['input-filter'],
|
|
||||||
pathlib.Path(song_config['compressed-jsonl']),
|
|
||||||
pathlib.Path(song_config['embeddings-file']),
|
|
||||||
args.regenerate,
|
|
||||||
args.verbose)
|
|
||||||
|
|
||||||
# Initialize Ledger Search
|
# Initialize Ledger Search
|
||||||
ledger_config = get_from_dict(args.config, 'content-type', 'ledger')
|
search_config.ledger = TextSearchConfig.create_from_dictionary(config, ('content-type', 'ledger'), verbose)
|
||||||
if ledger_config and ('input-files' in ledger_config or 'input-filter' in ledger_config):
|
if search_config.ledger:
|
||||||
search_settings.ledger_search_enabled = True
|
model.ledger_search = symmetric_ledger.setup(search_config.ledger, regenerate=regenerate)
|
||||||
model.ledger_search = symmetric_ledger.setup(
|
|
||||||
ledger_config['input-files'],
|
|
||||||
ledger_config['input-filter'],
|
|
||||||
pathlib.Path(ledger_config['compressed-jsonl']),
|
|
||||||
pathlib.Path(ledger_config['embeddings-file']),
|
|
||||||
args.regenerate,
|
|
||||||
args.verbose)
|
|
||||||
|
|
||||||
# Initialize Image Search
|
# Initialize Image Search
|
||||||
image_config = get_from_dict(args.config, 'content-type', 'image')
|
search_config.image = ImageSearchConfig.create_from_dictionary(config, ('content-type', 'image'), verbose)
|
||||||
if image_config and 'input-directory' in image_config:
|
if search_config.image:
|
||||||
search_settings.image_search_enabled = True
|
model.image_search = image_search.setup(search_config.image, regenerate=regenerate)
|
||||||
model.image_search = image_search.setup(
|
|
||||||
pathlib.Path(image_config['input-directory']),
|
return model, search_config
|
||||||
pathlib.Path(image_config['embeddings-file']),
|
|
||||||
batch_size=image_config['batch-size'],
|
|
||||||
regenerate=args.regenerate,
|
if __name__ == '__main__':
|
||||||
use_xmp_metadata={'yes': True, 'no': False}[image_config['use-xmp-metadata']],
|
# Load config from CLI
|
||||||
verbose=args.verbose)
|
args = cli(sys.argv[1:])
|
||||||
|
|
||||||
|
# Initialize Search from Config
|
||||||
|
model, search_config = initialize_search(args.config, args.regenerate, args.verbose)
|
||||||
|
|
||||||
# Start Application Server
|
# Start Application Server
|
||||||
uvicorn.run(app)
|
uvicorn.run(app)
|
||||||
|
|
|
@ -17,7 +17,7 @@ from sentence_transformers import SentenceTransformer, CrossEncoder, util
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from utils.helpers import get_absolute_path, resolve_absolute_path
|
from utils.helpers import get_absolute_path, resolve_absolute_path
|
||||||
from processor.org_mode.org_to_jsonl import org_to_jsonl
|
from processor.org_mode.org_to_jsonl import org_to_jsonl
|
||||||
from utils.config import AsymmetricSearchModel
|
from utils.config import TextSearchModel, TextSearchConfig
|
||||||
|
|
||||||
|
|
||||||
def initialize_model():
|
def initialize_model():
|
||||||
|
@ -66,7 +66,7 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, v
|
||||||
return corpus_embeddings
|
return corpus_embeddings
|
||||||
|
|
||||||
|
|
||||||
def query_notes(raw_query: str, model: AsymmetricSearchModel):
|
def query(raw_query: str, model: TextSearchModel):
|
||||||
"Search all notes for entries that answer the query"
|
"Search all notes for entries that answer the query"
|
||||||
# Separate natural query from explicit required, blocked words filters
|
# Separate natural query from explicit required, blocked words filters
|
||||||
query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")])
|
query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")])
|
||||||
|
@ -151,21 +151,21 @@ def collate_results(hits, entries, count=5):
|
||||||
in hits[0:count]]
|
in hits[0:count]]
|
||||||
|
|
||||||
|
|
||||||
def setup(input_files, input_filter, compressed_jsonl, embeddings, regenerate=False, verbose=False):
|
def setup(config: TextSearchConfig, regenerate: bool) -> TextSearchModel:
|
||||||
# Initialize Model
|
# Initialize Model
|
||||||
bi_encoder, cross_encoder, top_k = initialize_model()
|
bi_encoder, cross_encoder, top_k = initialize_model()
|
||||||
|
|
||||||
# Map notes in Org-Mode files to (compressed) JSONL formatted file
|
# Map notes in Org-Mode files to (compressed) JSONL formatted file
|
||||||
if not resolve_absolute_path(compressed_jsonl).exists() or regenerate:
|
if not resolve_absolute_path(config.compressed_jsonl).exists() or regenerate:
|
||||||
org_to_jsonl(input_files, input_filter, compressed_jsonl, verbose)
|
org_to_jsonl(config.input_files, config.input_filter, config.compressed_jsonl, config.verbose)
|
||||||
|
|
||||||
# Extract Entries
|
# Extract Entries
|
||||||
entries = extract_entries(compressed_jsonl, verbose)
|
entries = extract_entries(config.compressed_jsonl, config.verbose)
|
||||||
|
|
||||||
# Compute or Load Embeddings
|
# Compute or Load Embeddings
|
||||||
corpus_embeddings = compute_embeddings(entries, bi_encoder, embeddings, regenerate=regenerate, verbose=verbose)
|
corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate, verbose=config.verbose)
|
||||||
|
|
||||||
return AsymmetricSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k)
|
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k, verbose=config.verbose)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -191,7 +191,7 @@ if __name__ == '__main__':
|
||||||
exit(0)
|
exit(0)
|
||||||
|
|
||||||
# query notes
|
# query notes
|
||||||
hits = query_notes(user_query, corpus_embeddings, entries, bi_encoder, cross_encoder, top_k)
|
hits = query(user_query, corpus_embeddings, entries, bi_encoder, cross_encoder, top_k)
|
||||||
|
|
||||||
# render results
|
# render results
|
||||||
render_results(hits, entries, count=args.results_count)
|
render_results(hits, entries, count=args.results_count)
|
||||||
|
|
|
@ -12,6 +12,8 @@ import torch
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from utils.helpers import get_absolute_path, resolve_absolute_path
|
from utils.helpers import get_absolute_path, resolve_absolute_path
|
||||||
import utils.exiftool as exiftool
|
import utils.exiftool as exiftool
|
||||||
|
from utils.config import ImageSearchModel, ImageSearchConfig
|
||||||
|
|
||||||
|
|
||||||
def initialize_model():
|
def initialize_model():
|
||||||
# Initialize Model
|
# Initialize Model
|
||||||
|
@ -93,30 +95,31 @@ def extract_metadata(image_name, verbose=0):
|
||||||
return image_processed_metadata
|
return image_processed_metadata
|
||||||
|
|
||||||
|
|
||||||
def query_images(query, image_embeddings, image_metadata_embeddings, model, count=3, verbose=0):
|
def query(raw_query, count, model: ImageSearchModel):
|
||||||
# Set query to image content if query is a filepath
|
# Set query to image content if query is a filepath
|
||||||
if pathlib.Path(query).is_file():
|
if pathlib.Path(raw_query).is_file():
|
||||||
query_imagepath = resolve_absolute_path(pathlib.Path(query), strict=True)
|
query_imagepath = resolve_absolute_path(pathlib.Path(raw_query), strict=True)
|
||||||
query = copy.deepcopy(Image.open(query_imagepath))
|
query = copy.deepcopy(Image.open(query_imagepath))
|
||||||
if verbose > 0:
|
if model.verbose > 0:
|
||||||
print(f"Find Images similar to Image at {query_imagepath}")
|
print(f"Find Images similar to Image at {query_imagepath}")
|
||||||
else:
|
else:
|
||||||
if verbose > 0:
|
query = raw_query
|
||||||
|
if model.verbose > 0:
|
||||||
print(f"Find Images by Text: {query}")
|
print(f"Find Images by Text: {query}")
|
||||||
|
|
||||||
# Now we encode the query (which can either be an image or a text string)
|
# Now we encode the query (which can either be an image or a text string)
|
||||||
query_embedding = model.encode([query], convert_to_tensor=True, show_progress_bar=False)
|
query_embedding = model.image_encoder.encode([query], convert_to_tensor=True, show_progress_bar=False)
|
||||||
|
|
||||||
# Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings.
|
# Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings.
|
||||||
image_hits = {result['corpus_id']: result['score']
|
image_hits = {result['corpus_id']: result['score']
|
||||||
for result
|
for result
|
||||||
in util.semantic_search(query_embedding, image_embeddings, top_k=count)[0]}
|
in util.semantic_search(query_embedding, model.image_embeddings, top_k=count)[0]}
|
||||||
|
|
||||||
# Compute top_k ranked images based on cosine-similarity b/w query and all image metadata embeddings.
|
# Compute top_k ranked images based on cosine-similarity b/w query and all image metadata embeddings.
|
||||||
if image_metadata_embeddings:
|
if model.image_metadata_embeddings:
|
||||||
metadata_hits = {result['corpus_id']: result['score']
|
metadata_hits = {result['corpus_id']: result['score']
|
||||||
for result
|
for result
|
||||||
in util.semantic_search(query_embedding, image_metadata_embeddings, top_k=count)[0]}
|
in util.semantic_search(query_embedding, model.image_metadata_embeddings, top_k=count)[0]}
|
||||||
|
|
||||||
# Sum metadata, image scores of the highest ranked images
|
# Sum metadata, image scores of the highest ranked images
|
||||||
for corpus_id, score in metadata_hits.items():
|
for corpus_id, score in metadata_hits.items():
|
||||||
|
@ -150,20 +153,30 @@ def collate_results(hits, image_names, image_directory, count=5):
|
||||||
in hits[0:count]]
|
in hits[0:count]]
|
||||||
|
|
||||||
|
|
||||||
def setup(image_directory, embeddings_file, batch_size=50, regenerate=False, use_xmp_metadata=False, verbose=0):
|
def setup(config: ImageSearchConfig, regenerate: bool) -> ImageSearchModel:
|
||||||
# Initialize Model
|
# Initialize Model
|
||||||
model = initialize_model()
|
model = initialize_model()
|
||||||
|
|
||||||
# Extract Entries
|
# Extract Entries
|
||||||
image_directory = resolve_absolute_path(image_directory, strict=True)
|
image_directory = resolve_absolute_path(config.input_directory, strict=True)
|
||||||
image_names = extract_entries(image_directory, verbose)
|
image_names = extract_entries(config.input_directory, config.verbose)
|
||||||
|
|
||||||
# Compute or Load Embeddings
|
# Compute or Load Embeddings
|
||||||
embeddings_file = resolve_absolute_path(embeddings_file)
|
embeddings_file = resolve_absolute_path(config.embeddings_file)
|
||||||
image_embeddings, image_metadata_embeddings = compute_embeddings(image_names, model, embeddings_file,
|
image_embeddings, image_metadata_embeddings = compute_embeddings(
|
||||||
batch_size=batch_size, regenerate=regenerate, use_xmp_metadata=use_xmp_metadata, verbose=verbose)
|
image_names,
|
||||||
|
model,
|
||||||
|
embeddings_file,
|
||||||
|
batch_size=config.batch_size,
|
||||||
|
regenerate=regenerate,
|
||||||
|
use_xmp_metadata=config.use_xmp_metadata,
|
||||||
|
verbose=config.verbose)
|
||||||
|
|
||||||
return image_names, image_embeddings, image_metadata_embeddings, model
|
return ImageSearchModel(image_names,
|
||||||
|
image_embeddings,
|
||||||
|
image_metadata_embeddings,
|
||||||
|
model,
|
||||||
|
config.verbose)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -187,7 +200,7 @@ if __name__ == '__main__':
|
||||||
exit(0)
|
exit(0)
|
||||||
|
|
||||||
# query images
|
# query images
|
||||||
hits = query_images(user_query, image_embeddings, image_metadata_embeddings, model, args.results_count, args.verbose)
|
hits = query(user_query, image_embeddings, image_metadata_embeddings, model, args.results_count, args.verbose)
|
||||||
|
|
||||||
# render results
|
# render results
|
||||||
render_results(hits, image_names, args.image_directory, count=args.results_count)
|
render_results(hits, image_names, args.image_directory, count=args.results_count)
|
||||||
|
|
|
@ -15,6 +15,7 @@ from sentence_transformers import SentenceTransformer, CrossEncoder, util
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from utils.helpers import get_absolute_path, resolve_absolute_path
|
from utils.helpers import get_absolute_path, resolve_absolute_path
|
||||||
from processor.ledger.beancount_to_jsonl import beancount_to_jsonl
|
from processor.ledger.beancount_to_jsonl import beancount_to_jsonl
|
||||||
|
from utils.config import TextSearchModel, TextSearchConfig
|
||||||
|
|
||||||
|
|
||||||
def initialize_model():
|
def initialize_model():
|
||||||
|
@ -59,7 +60,7 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, v
|
||||||
return corpus_embeddings
|
return corpus_embeddings
|
||||||
|
|
||||||
|
|
||||||
def query_transactions(raw_query, corpus_embeddings, entries, bi_encoder, cross_encoder, top_k=100):
|
def query(raw_query, model: TextSearchModel):
|
||||||
"Search all notes for entries that answer the query"
|
"Search all notes for entries that answer the query"
|
||||||
# Separate natural query from explicit required, blocked words filters
|
# Separate natural query from explicit required, blocked words filters
|
||||||
query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")])
|
query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")])
|
||||||
|
@ -67,20 +68,20 @@ def query_transactions(raw_query, corpus_embeddings, entries, bi_encoder, cross_
|
||||||
blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])
|
blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])
|
||||||
|
|
||||||
# Encode the query using the bi-encoder
|
# Encode the query using the bi-encoder
|
||||||
question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
|
question_embedding = model.bi_encoder.encode(query, convert_to_tensor=True)
|
||||||
|
|
||||||
# Find relevant entries for the query
|
# Find relevant entries for the query
|
||||||
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)
|
hits = util.semantic_search(question_embedding, model.corpus_embeddings, top_k=model.top_k)
|
||||||
hits = hits[0] # Get the hits for the first query
|
hits = hits[0] # Get the hits for the first query
|
||||||
|
|
||||||
# Filter results using explicit filters
|
# Filter results using explicit filters
|
||||||
hits = explicit_filter(hits, entries, required_words, blocked_words)
|
hits = explicit_filter(hits, model.entries, required_words, blocked_words)
|
||||||
if hits is None or len(hits) == 0:
|
if hits is None or len(hits) == 0:
|
||||||
return hits
|
return hits
|
||||||
|
|
||||||
# Score all retrieved entries using the cross-encoder
|
# Score all retrieved entries using the cross-encoder
|
||||||
cross_inp = [[query, entries[hit['corpus_id']]] for hit in hits]
|
cross_inp = [[query, model.entries[hit['corpus_id']]] for hit in hits]
|
||||||
cross_scores = cross_encoder.predict(cross_inp)
|
cross_scores = model.cross_encoder.predict(cross_inp)
|
||||||
|
|
||||||
# Store cross-encoder scores in results dictionary for ranking
|
# Store cross-encoder scores in results dictionary for ranking
|
||||||
for idx in range(len(cross_scores)):
|
for idx in range(len(cross_scores)):
|
||||||
|
@ -142,21 +143,21 @@ def collate_results(hits, entries, count=5):
|
||||||
in hits[0:count]]
|
in hits[0:count]]
|
||||||
|
|
||||||
|
|
||||||
def setup(input_files, input_filter, compressed_jsonl, embeddings, regenerate=False, verbose=False):
|
def setup(config: TextSearchConfig, regenerate: bool) -> TextSearchModel:
|
||||||
# Initialize Model
|
# Initialize Model
|
||||||
bi_encoder, cross_encoder, top_k = initialize_model()
|
bi_encoder, cross_encoder, top_k = initialize_model()
|
||||||
|
|
||||||
# Map notes in Org-Mode files to (compressed) JSONL formatted file
|
# Map notes in Org-Mode files to (compressed) JSONL formatted file
|
||||||
if not resolve_absolute_path(compressed_jsonl).exists() or regenerate:
|
if not resolve_absolute_path(config.compressed_jsonl).exists() or regenerate:
|
||||||
beancount_to_jsonl(input_files, input_filter, compressed_jsonl, verbose)
|
beancount_to_jsonl(config.input_files, config.input_filter, config.compressed_jsonl, config.verbose)
|
||||||
|
|
||||||
# Extract Entries
|
# Extract Entries
|
||||||
entries = extract_entries(compressed_jsonl, verbose)
|
entries = extract_entries(config.compressed_jsonl, config.verbose)
|
||||||
|
|
||||||
# Compute or Load Embeddings
|
# Compute or Load Embeddings
|
||||||
corpus_embeddings = compute_embeddings(entries, bi_encoder, embeddings, regenerate=regenerate, verbose=verbose)
|
corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate, verbose=config.verbose)
|
||||||
|
|
||||||
return entries, corpus_embeddings, bi_encoder, cross_encoder, top_k
|
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k, verbose=config.verbose)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -181,8 +182,8 @@ if __name__ == '__main__':
|
||||||
if user_query == "exit":
|
if user_query == "exit":
|
||||||
exit(0)
|
exit(0)
|
||||||
|
|
||||||
# query notes
|
# query
|
||||||
hits = query_transactions(user_query, corpus_embeddings, entries, bi_encoder, cross_encoder, top_k)
|
hits = query(user_query, corpus_embeddings, entries, bi_encoder, cross_encoder, top_k)
|
||||||
|
|
||||||
# render results
|
# render results
|
||||||
render_results(hits, entries, count=args.results_count)
|
render_results(hits, entries, count=args.results_count)
|
||||||
|
|
|
@ -6,8 +6,9 @@ import pytest
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from main import app, search_settings, model
|
from main import app, search_config, model
|
||||||
from search_type import asymmetric
|
from search_type import asymmetric
|
||||||
|
from utils.config import SearchConfig, TextSearchConfig
|
||||||
|
|
||||||
|
|
||||||
# Arrange
|
# Arrange
|
||||||
|
@ -60,14 +61,17 @@ def test_regenerate_with_valid_search_type():
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
def test_notes_search():
|
def test_notes_search():
|
||||||
# Arrange
|
# Arrange
|
||||||
input_files = [Path('tests/data/main_readme.org'), Path('tests/data/interface_emacs_readme.org')]
|
search_config = SearchConfig()
|
||||||
input_filter = None
|
search_config.notes = TextSearchConfig(
|
||||||
compressed_jsonl = Path('tests/data/.test.jsonl.gz')
|
input_files = [Path('tests/data/main_readme.org'), Path('tests/data/interface_emacs_readme.org')],
|
||||||
embeddings = Path('tests/data/.test_embeddings.pt')
|
input_filter = None,
|
||||||
|
compressed_jsonl = Path('tests/data/.test.jsonl.gz'),
|
||||||
|
embeddings_file = Path('tests/data/.test_embeddings.pt'),
|
||||||
|
verbose = 0)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
# Regenerate embeddings during asymmetric setup
|
# Regenerate embeddings during asymmetric setup
|
||||||
notes_model = asymmetric.setup(input_files, input_filter, compressed_jsonl, embeddings, regenerate=True, verbose=0)
|
notes_model = asymmetric.setup(search_config.notes, regenerate=True)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert len(notes_model.entries) == 10
|
assert len(notes_model.entries) == 10
|
||||||
|
@ -75,7 +79,6 @@ def test_notes_search():
|
||||||
|
|
||||||
# Arrange
|
# Arrange
|
||||||
model.notes_search = notes_model
|
model.notes_search = notes_model
|
||||||
search_settings.notes_search_enabled = True
|
|
||||||
user_query = "How to call semantic search from Emacs?"
|
user_query = "How to call semantic search from Emacs?"
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
|
@ -88,3 +91,30 @@ def test_notes_search():
|
||||||
assert "Semantic Search via Emacs" in search_result
|
assert "Semantic Search via Emacs" in search_result
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------------------------------------
|
||||||
|
def test_notes_regenerate():
|
||||||
|
# Arrange
|
||||||
|
search_config = SearchConfig()
|
||||||
|
search_config.notes = TextSearchConfig(
|
||||||
|
input_files = [Path('tests/data/main_readme.org'), Path('tests/data/interface_emacs_readme.org')],
|
||||||
|
input_filter = None,
|
||||||
|
compressed_jsonl = Path('tests/data/.test.jsonl.gz'),
|
||||||
|
embeddings_file = Path('tests/data/.test_embeddings.pt'),
|
||||||
|
verbose = 0)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
# Regenerate embeddings during asymmetric setup
|
||||||
|
notes_model = asymmetric.setup(search_config.notes, regenerate=True)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert len(notes_model.entries) == 10
|
||||||
|
assert len(notes_model.corpus_embeddings) == 10
|
||||||
|
|
||||||
|
# Arrange
|
||||||
|
model.notes_search = notes_model
|
||||||
|
|
||||||
|
# Act
|
||||||
|
response = client.get(f"/regenerate?t=notes")
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
|
@ -1,6 +1,10 @@
|
||||||
# System Packages
|
# System Packages
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Internal Packages
|
||||||
|
from utils.helpers import get_from_dict
|
||||||
|
|
||||||
|
|
||||||
class SearchType(str, Enum):
|
class SearchType(str, Enum):
|
||||||
|
@ -10,43 +14,82 @@ class SearchType(str, Enum):
|
||||||
Image = "image"
|
Image = "image"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
class TextSearchModel():
|
||||||
class SearchSettings():
|
def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, top_k, verbose):
|
||||||
notes_search_enabled: bool = False
|
|
||||||
ledger_search_enabled: bool = False
|
|
||||||
music_search_enabled: bool = False
|
|
||||||
image_search_enabled: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
class AsymmetricSearchModel():
|
|
||||||
def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, top_k):
|
|
||||||
self.entries = entries
|
self.entries = entries
|
||||||
self.corpus_embeddings = corpus_embeddings
|
self.corpus_embeddings = corpus_embeddings
|
||||||
self.bi_encoder = bi_encoder
|
self.bi_encoder = bi_encoder
|
||||||
self.cross_encoder = cross_encoder
|
self.cross_encoder = cross_encoder
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
|
self.verbose = verbose
|
||||||
|
|
||||||
class LedgerSearchModel():
|
|
||||||
def __init__(self, transactions, transaction_embeddings, symmetric_encoder, symmetric_cross_encoder, top_k):
|
|
||||||
self.transactions = transactions
|
|
||||||
self.transaction_embeddings = transaction_embeddings
|
|
||||||
self.symmetric_encoder = symmetric_encoder
|
|
||||||
self.symmetric_cross_encoder = symmetric_cross_encoder
|
|
||||||
self.top_k = top_k
|
|
||||||
|
|
||||||
|
|
||||||
class ImageSearchModel():
|
class ImageSearchModel():
|
||||||
def __init__(self, image_names, image_embeddings, image_metadata_embeddings, image_encoder):
|
def __init__(self, image_names, image_embeddings, image_metadata_embeddings, image_encoder, verbose):
|
||||||
|
self.image_encoder = image_encoder
|
||||||
self.image_names = image_names
|
self.image_names = image_names
|
||||||
self.image_embeddings = image_embeddings
|
self.image_embeddings = image_embeddings
|
||||||
self.image_metadata_embeddings = image_metadata_embeddings
|
self.image_metadata_embeddings = image_metadata_embeddings
|
||||||
self.image_encoder = image_encoder
|
self.image_encoder = image_encoder
|
||||||
|
self.verbose = verbose
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SearchModels():
|
class SearchModels():
|
||||||
notes_search: AsymmetricSearchModel = None
|
notes_search: TextSearchModel = None
|
||||||
ledger_search: LedgerSearchModel = None
|
ledger_search: TextSearchModel = None
|
||||||
music_search: AsymmetricSearchModel = None
|
music_search: TextSearchModel = None
|
||||||
image_search: ImageSearchModel = None
|
image_search: ImageSearchModel = None
|
||||||
|
|
||||||
|
|
||||||
|
class TextSearchConfig():
|
||||||
|
def __init__(self, input_files, input_filter, compressed_jsonl, embeddings_file, verbose):
|
||||||
|
self.input_files = input_files
|
||||||
|
self.input_filter = input_filter
|
||||||
|
self.compressed_jsonl = Path(compressed_jsonl)
|
||||||
|
self.embeddings_file = Path(embeddings_file)
|
||||||
|
self.verbose = verbose
|
||||||
|
|
||||||
|
|
||||||
|
def create_from_dictionary(config, key_tree, verbose):
|
||||||
|
text_config = get_from_dict(config, *key_tree)
|
||||||
|
search_enabled = text_config and ('input-files' in text_config or 'input-filter' in text_config)
|
||||||
|
if not search_enabled:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return TextSearchConfig(
|
||||||
|
input_files = text_config['input-files'],
|
||||||
|
input_filter = text_config['input-filter'],
|
||||||
|
compressed_jsonl = Path(text_config['compressed-jsonl']),
|
||||||
|
embeddings_file = Path(text_config['embeddings-file']),
|
||||||
|
verbose = verbose)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageSearchConfig():
|
||||||
|
def __init__(self, input_directory, embeddings_file, batch_size, use_xmp_metadata, verbose):
|
||||||
|
self.input_directory = input_directory
|
||||||
|
self.embeddings_file = Path(embeddings_file)
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.use_xmp_metadata = use_xmp_metadata
|
||||||
|
self.verbose = verbose
|
||||||
|
|
||||||
|
def create_from_dictionary(config, key_tree, verbose):
|
||||||
|
image_config = get_from_dict(config, *key_tree)
|
||||||
|
search_enabled = image_config and 'input-directory' in image_config
|
||||||
|
if not search_enabled:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return ImageSearchConfig(
|
||||||
|
input_directory = Path(image_config['input-directory']),
|
||||||
|
embeddings_file = Path(image_config['embeddings-file']),
|
||||||
|
batch_size = image_config['batch-size'],
|
||||||
|
use_xmp_metadata = {'yes': True, 'no': False}[image_config['use-xmp-metadata']],
|
||||||
|
verbose = verbose)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SearchConfig():
|
||||||
|
notes: TextSearchConfig = None
|
||||||
|
ledger: TextSearchConfig = None
|
||||||
|
music: TextSearchConfig = None
|
||||||
|
image: ImageSearchConfig = None
|
||||||
|
|
Loading…
Reference in a new issue