diff --git a/src/main.py b/src/main.py
index 002300f9..3c9fa2a5 100644
--- a/src/main.py
+++ b/src/main.py
@@ -8,7 +8,7 @@ import uvicorn
 from fastapi import FastAPI
 
 # Internal Packages
-from search_type import asymmetric
+from search_type import asymmetric, symmetric_ledger
 from utils.helpers import get_from_dict
 from utils.cli import cli
 
@@ -38,6 +38,18 @@ def search(q: str, n: Optional[int] = 5, t: Optional[str] = None):
         # collate and return results
         return asymmetric.collate_results(hits, entries, results_count)
 
+    if (t == 'ledger' or t == None) and ledger_search_enabled:
+        # query transactions
+        hits = symmetric_ledger.query_transactions(
+            user_query,
+            transaction_embeddings,
+            transactions,
+            symmetric_encoder,
+            symmetric_cross_encoder)
+
+        # collate and return results
+        return symmetric_ledger.collate_results(hits, transactions, results_count)
+
     else:
         return {}
 
@@ -56,16 +68,28 @@ def regenerate(t: Optional[str] = None):
             regenerate=True,
             verbose=args.verbose)
 
+    if (t == 'ledger' or t == None) and ledger_search_enabled:
+        # Extract Entries, Generate Embeddings
+        global transaction_embeddings
+        global transactions
+        transactions, transaction_embeddings, _, _, _ = symmetric_ledger.setup(
+            ledger_config['input-files'],
+            ledger_config['input-filter'],
+            pathlib.Path(ledger_config['compressed-jsonl']),
+            pathlib.Path(ledger_config['embeddings-file']),
+            regenerate=True,
+            verbose=args.verbose)
 
     return {'status': 'ok', 'message': 'regeneration completed'}
 
 
 if __name__ == '__main__':
     args = cli(sys.argv[1:])
 
-    org_config = get_from_dict(args.config, 'content-type', 'org')
+    # Initialize Org Notes Search
+    org_config = get_from_dict(args.config, 'content-type', 'org')
     notes_search_enabled = False
-    if 'input-files' in org_config or 'input-filter' in org_config:
+    if org_config and ('input-files' in org_config or 'input-filter' in org_config):
         notes_search_enabled = True
         entries, corpus_embeddings, bi_encoder, cross_encoder, top_k = asymmetric.setup(
             org_config['input-files'],
@@ -75,6 +99,18 @@ if __name__ == '__main__':
             args.regenerate,
             args.verbose)
 
+    # Initialize Ledger Search
+    ledger_config = get_from_dict(args.config, 'content-type', 'ledger')
+    ledger_search_enabled = False
+    if ledger_config and ('input-files' in ledger_config or 'input-filter' in ledger_config):
+        ledger_search_enabled = True
+        transactions, transaction_embeddings, symmetric_encoder, symmetric_cross_encoder, _ = symmetric_ledger.setup(
+            ledger_config['input-files'],
+            ledger_config['input-filter'],
+            pathlib.Path(ledger_config['compressed-jsonl']),
+            pathlib.Path(ledger_config['embeddings-file']),
+            args.regenerate,
+            args.verbose)
 
     # Start Application Server
     uvicorn.run(app)
diff --git a/src/processor/ledger/__init__.py b/src/processor/ledger/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/processor/ledger/beancount_to_jsonl.py b/src/processor/ledger/beancount_to_jsonl.py
new file mode 100644
index 00000000..829df471
--- /dev/null
+++ b/src/processor/ledger/beancount_to_jsonl.py
@@ -0,0 +1,127 @@
+#!/usr/bin/env python3
+
+# Standard Packages
+import json
+import argparse
+import pathlib
+import glob
+import gzip
+
+# Internal Packages
+from processor.org_mode import orgnode
+from utils.helpers import get_absolute_path, is_none_or_empty
+
+
+# Define Functions
+def beancount_to_jsonl(beancount_files, beancount_file_filter, output_file, verbose=0):
+    # Input Validation
+    if is_none_or_empty(beancount_files) and is_none_or_empty(beancount_file_filter):
+        print("At least one of beancount-files or beancount-file-filter is required to be specified")
+        exit(1)
+
+    # Get Beancount Files to Process
+    beancount_files = get_beancount_files(beancount_files, beancount_file_filter, verbose)
+
+    # Extract Entries from specified Beancount files
+    entries = extract_beancount_entries(beancount_files)
+
+    # Process Each Entry from All Transaction Files
+    jsonl_data = convert_beancount_entries_to_jsonl(entries, verbose=verbose)
+
+    # Compress JSONL formatted Data
+    if output_file.suffix == ".gz":
+        compress_jsonl_data(jsonl_data, output_file, verbose=verbose)
+    elif output_file.suffix == ".jsonl":
+        dump_jsonl(jsonl_data, output_file, verbose=verbose)
+
+    return entries
+
+
+def dump_jsonl(jsonl_data, output_path, verbose=0):
+    "Write List of JSON objects to JSON line file"
+    with open(get_absolute_path(output_path), 'w', encoding='utf-8') as f:
+        f.write(jsonl_data)
+
+    if verbose > 0:
+        print(f'Wrote {len(jsonl_data)} lines to jsonl at {output_path}')
+
+
+def compress_jsonl_data(jsonl_data, output_path, verbose=0):
+    with gzip.open(get_absolute_path(output_path), 'wt') as gzip_file:
+        gzip_file.write(jsonl_data)
+
+    if verbose > 0:
+        print(f'Wrote {len(jsonl_data)} lines to gzip compressed jsonl at {output_path}')
+
+
+def load_jsonl(input_path, verbose=0):
+    "Read List of JSON objects from JSON line file"
+    data = []
+    with open(get_absolute_path(input_path), 'r', encoding='utf-8') as f:
+        for line in f:
+            data.append(json.loads(line.rstrip('\n|\r')))
+
+    if verbose > 0:
+        print(f'Loaded {len(data)} records from {input_path}')
+
+    return data
+
+
+def get_beancount_files(beancount_files=None, beancount_file_filter=None, verbose=0):
+    "Get Beancount files to process"
+    absolute_beancount_files, filtered_beancount_files = set(), set()
+    if beancount_files:
+        absolute_beancount_files = {get_absolute_path(beancount_file)
+                                    for beancount_file
+                                    in beancount_files}
+    if beancount_file_filter:
+        filtered_beancount_files = set(glob.glob(get_absolute_path(beancount_file_filter)))
+
+    all_beancount_files = absolute_beancount_files | filtered_beancount_files
+
+    files_with_non_beancount_extensions = {beancount_file for beancount_file in all_beancount_files if not beancount_file.endswith(".bean")}
+    if any(files_with_non_beancount_extensions):
+        print(f"[Warning] There may be non-Beancount files in the input set: {files_with_non_beancount_extensions}")
+
+    if verbose > 0:
+        print(f'Processing files: {all_beancount_files}')
+
+    return all_beancount_files
+
+
+def extract_beancount_entries(beancount_files):
+    "Extract entries from specified Beancount files"
+    entries = []
+    for beancount_file in beancount_files:
+        with open(beancount_file) as f:
+            entries.extend(
+                f.read().split('\n\n'))
+
+    return entries
+
+
+def convert_beancount_entries_to_jsonl(entries, verbose=0):
+    "Convert each Beancount transaction to JSON and collate as JSONL"
+    jsonl = ''
+    for entry in entries:
+        entry_dict = {'Title': entry}
+        # Convert Dictionary to JSON and Append to JSONL string
+        jsonl += f'{json.dumps(entry_dict, ensure_ascii=False)}\n'
+
+    if verbose > 0:
+        print(f"Converted {len(entries)} transactions to jsonl format")
+
+    return jsonl
+
+
+if __name__ == '__main__':
+    # Setup Argument Parser
+    parser = argparse.ArgumentParser(description="Map Beancount transactions into (compressed) JSONL format")
+    parser.add_argument('--output-file', '-o', type=pathlib.Path, required=True, help="Output file for (compressed) JSONL formatted transactions. Expected file extensions: jsonl or jsonl.gz")
+    parser.add_argument('--input-files', '-i', nargs='*', help="List of Beancount files to process")
+    parser.add_argument('--input-filter', type=str, default=None, help="Regex filter for Beancount files to process")
+    parser.add_argument('--verbose', '-v', action='count', default=0, help="Show verbose conversion logs. Default: 0")
+    args = parser.parse_args()
+
+    # Map transactions in beancount files to (compressed) JSONL formatted file
+    beancount_to_jsonl(args.input_files, args.input_filter, args.output_file, args.verbose)
diff --git a/src/search_type/symmetric_ledger.py b/src/search_type/symmetric_ledger.py
new file mode 100644
index 00000000..6c0bc4c1
--- /dev/null
+++ b/src/search_type/symmetric_ledger.py
@@ -0,0 +1,187 @@
+# Standard Packages
+import json
+import time
+import gzip
+import os
+import sys
+import re
+import argparse
+import pathlib
+
+# External Packages
+import torch
+from sentence_transformers import SentenceTransformer, CrossEncoder, util
+
+# Internal Packages
+from utils.helpers import get_absolute_path
+from processor.ledger.beancount_to_jsonl import beancount_to_jsonl
+
+
+def initialize_model():
+    "Initialize model for symmetric semantic search. That is, where the query is of similar size to the results"
+    bi_encoder = SentenceTransformer('sentence-transformers/paraphrase-MiniLM-L6-v2')  # The encoder encodes all entries to use for semantic search
+    top_k = 100  # Number of entries we want to retrieve with the bi-encoder
+    cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')  # The cross-encoder re-ranks the results to improve quality
+    return bi_encoder, cross_encoder, top_k
+
+
+def extract_entries(notesfile, verbose=0):
+    "Load entries from compressed jsonl"
+    entries = []
+    with gzip.open(get_absolute_path(notesfile), 'rt', encoding='utf8') as jsonl:
+        for line in jsonl:
+            note = json.loads(line.strip())
+            note_string = f'{note["Title"]} \t {note["Tags"] if "Tags" in note else ""} \n {note["Body"] if "Body" in note else ""}'
+            entries.extend([note_string])
+
+    if verbose > 0:
+        print(f"Loaded {len(entries)} entries from {notesfile}")
+
+    return entries
+
+
+def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, verbose=0):
+    "Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
+    # Load pre-computed embeddings from file if exists
+    if embeddings_file.exists() and not regenerate:
+        corpus_embeddings = torch.load(get_absolute_path(embeddings_file))
+        if verbose > 0:
+            print(f"Loaded embeddings from {embeddings_file}")
+
+    else:  # Else compute the corpus_embeddings from scratch, which can take a while
+        corpus_embeddings = bi_encoder.encode(entries, convert_to_tensor=True, show_progress_bar=True)
+        torch.save(corpus_embeddings, get_absolute_path(embeddings_file))
+        if verbose > 0:
+            print(f"Computed embeddings and saved them to {embeddings_file}")
+
+    return corpus_embeddings
+
+
+def query_transactions(raw_query, corpus_embeddings, entries, bi_encoder, cross_encoder, top_k=100):
+    "Search all transactions for entries that answer the query"
+    # Separate natural query from explicit required, blocked words filters
+    query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")])
+    required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")])
+    blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])
+
+    # Encode the query using the bi-encoder
+    question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
+
+    # Find relevant entries for the query
+    hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)
+    hits = hits[0]  # Get the hits for the first query
+
+    # Filter results using explicit filters
+    hits = explicit_filter(hits, entries, required_words, blocked_words)
+    if hits is None or len(hits) == 0:
+        return hits
+
+    # Score all retrieved entries using the cross-encoder
+    cross_inp = [[query, entries[hit['corpus_id']]] for hit in hits]
+    cross_scores = cross_encoder.predict(cross_inp)
+
+    # Store cross-encoder scores in results dictionary for ranking
+    for idx in range(len(cross_scores)):
+        hits[idx]['cross-score'] = cross_scores[idx]
+
+    # Order results by cross encoder score followed by biencoder score
+    hits.sort(key=lambda x: x['score'], reverse=True)  # sort by biencoder score
+    hits.sort(key=lambda x: x['cross-score'], reverse=True)  # sort by cross encoder score
+
+    return hits
+
+
+def explicit_filter(hits, entries, required_words, blocked_words):
+    hits_by_word_set = [(set(word.lower()
+                             for word
+                             in re.split(
+                                 ',|\.| |\]|\[|\(|\)|\{|\}',
+                                 entries[hit['corpus_id']])
+                             if word != ""),
+                         hit)
+                        for hit in hits]
+
+    if len(required_words) == 0 and len(blocked_words) == 0:
+        return hits
+    if len(required_words) > 0:
+        return [hit for (words_in_entry, hit) in hits_by_word_set
+                if required_words.intersection(words_in_entry) and not blocked_words.intersection(words_in_entry)]
+    if len(blocked_words) > 0:
+        return [hit for (words_in_entry, hit) in hits_by_word_set
+                if not blocked_words.intersection(words_in_entry)]
+    return hits
+
+
+def render_results(hits, entries, count=5, display_biencoder_results=False):
+    "Render the Results returned by Search for the Query"
+    if display_biencoder_results:
+        # Output of top hits from bi-encoder
+        print("\n-------------------------\n")
+        print(f"Top-{count} Bi-Encoder Retrieval hits")
+        hits = sorted(hits, key=lambda x: x['score'], reverse=True)
+        for hit in hits[0:count]:
+            print(f"Score: {hit['score']:.3f}\n------------\n{entries[hit['corpus_id']]}")
+
+    # Output of top hits from re-ranker
+    print("\n-------------------------\n")
+    print(f"Top-{count} Cross-Encoder Re-ranker hits")
+    hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
+    for hit in hits[0:count]:
+        print(f"CrossScore: {hit['cross-score']:.3f}\n-----------------\n{entries[hit['corpus_id']]}")
+
+
+def collate_results(hits, entries, count=5):
+    return [
+        {
+            "Entry": entries[hit['corpus_id']],
+            "Score": f"{hit['cross-score']:.3f}"
+        }
+        for hit
+        in hits[0:count]]
+
+
+def setup(input_files, input_filter, compressed_jsonl, embeddings, regenerate=False, verbose=False):
+    # Initialize Model
+    bi_encoder, cross_encoder, top_k = initialize_model()
+
+    # Map transactions in Beancount files to (compressed) JSONL formatted file
+    if not compressed_jsonl.exists() or regenerate:
+        beancount_to_jsonl(input_files, input_filter, compressed_jsonl, verbose)
+
+    # Extract Entries
+    entries = extract_entries(compressed_jsonl, verbose)
+
+    # Compute or Load Embeddings
+    corpus_embeddings = compute_embeddings(entries, bi_encoder, embeddings, regenerate=regenerate, verbose=verbose)
+
+    return entries, corpus_embeddings, bi_encoder, cross_encoder, top_k
+
+
+if __name__ == '__main__':
+    # Setup Argument Parser
+    parser = argparse.ArgumentParser(description="Run semantic search on Beancount transactions")
+    parser.add_argument('--input-files', '-i', nargs='*', help="List of Beancount files to process")
+    parser.add_argument('--input-filter', type=str, default=None, help="Regex filter for Beancount files to process")
+    parser.add_argument('--compressed-jsonl', '-j', type=pathlib.Path, default=pathlib.Path(".transactions.jsonl.gz"), help="Compressed JSONL formatted transactions file to compute embeddings from")
+    parser.add_argument('--embeddings', '-e', type=pathlib.Path, default=pathlib.Path(".transaction_embeddings.pt"), help="File to save/load model embeddings to/from")
+    parser.add_argument('--regenerate', action='store_true', default=False, help="Regenerate embeddings from Beancount files. Default: false")
+    parser.add_argument('--results-count', '-n', default=5, type=int, help="Number of results to render. Default: 5")
+    parser.add_argument('--interactive', action='store_true', default=False, help="Interactive mode allows user to run queries on the model. Default: false")
+    parser.add_argument('--verbose', action='count', default=0, help="Show verbose conversion logs. Default: 0")
+    args = parser.parse_args()
+
+    entries, corpus_embeddings, bi_encoder, cross_encoder, top_k = setup(args.input_files, args.input_filter, args.compressed_jsonl, args.embeddings, args.regenerate, args.verbose)
+
+    # Run User Queries on Entries in Interactive Mode
+    while args.interactive:
+        # get query from user
+        user_query = input("Enter your query: ")
+        if user_query == "exit":
+            exit(0)
+
+        # query transactions
+        hits = query_transactions(user_query, corpus_embeddings, entries, bi_encoder, cross_encoder, top_k)
+
+        # render results
+        render_results(hits, entries, count=args.results_count)
diff --git a/src/utils/cli.py b/src/utils/cli.py
index d70cb168..cd1519d3 100644
--- a/src/utils/cli.py
+++ b/src/utils/cli.py
@@ -48,6 +48,11 @@ default_config = {
         {
             'compressed-jsonl': '.notes.jsonl.gz',
             'embeddings-file': '.note_embeddings.pt'
+        },
+        'ledger':
+        {
+            'compressed-jsonl': '.transactions.jsonl.gz',
+            'embeddings-file': '.transaction_embeddings.pt'
         }
     },
     'search-type':
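
Usage sketch (not part of the patch): one way the new symmetric_ledger module could be driven outside the FastAPI app, mirroring its __main__ block. The ledger.bean path, query string and output file names below are placeholders, and the repo's src/ directory is assumed to be on PYTHONPATH.

# Standalone sketch: build the transaction corpus, then query it the way /search?t=ledger does
import pathlib
from search_type import symmetric_ledger

# Convert the Beancount file to compressed JSONL and compute (or load) its embeddings
transactions, embeddings, bi_encoder, cross_encoder, top_k = symmetric_ledger.setup(
    ['ledger.bean'],                             # placeholder input file
    None,                                        # no input-filter glob
    pathlib.Path('.transactions.jsonl.gz'),
    pathlib.Path('.transaction_embeddings.pt'),
    regenerate=False,
    verbose=1)

# Plain words form the semantic query; "+word" and "-word" add explicit include/exclude filters
hits = symmetric_ledger.query_transactions(
    "monthly coffee expenses", embeddings, transactions, bi_encoder, cross_encoder, top_k)
print(symmetric_ledger.collate_results(hits, transactions, count=5))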