Expose API endpoint to (re-)generate embeddings from latest notes

- Provides mechanism to update notes from within application
  - Instead of having to pass the same arguments multiple times
    Pass it once (or rely on defaults when possible) and let app keep
    state and location of intermediary files

- Allows user to not have to deal with the internals of the application
  - E.g user doesn't have to specify the jsonl.gz or embeddings file path
    The app will still put those files in a default location
  - The user doesn't have to run the generation from the commandline
    as a separate step
This commit is contained in:
Debanjum Singh Solanky 2021-08-16 18:52:38 -07:00
parent 1c00c33e73
commit 04a9a6d62f

31
main.py
View file

@ -1,6 +1,8 @@
from typing import Optional from typing import Optional
from fastapi import FastAPI from fastapi import FastAPI
from search_type import asymmetric from search_type import asymmetric
from processor.org_mode.org_to_jsonl import org_to_jsonl
from utils.helpers import is_none_or_empty
import argparse import argparse
import pathlib import pathlib
import uvicorn import uvicorn
@ -20,7 +22,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[str] = 'notes'):
if t == 'notes': if t == 'notes':
# query notes # query notes
hits = asymmetric.query_notes( hits = asymmetric.query_notes(
q, user_query,
corpus_embeddings, corpus_embeddings,
entries, entries,
bi_encoder, bi_encoder,
@ -34,22 +36,47 @@ def search(q: str, n: Optional[int] = 5, t: Optional[str] = 'notes'):
return {} return {}
@app.get('/regenerate')
def regenerate():
org_to_jsonl(args.input_files, args.input_filter, args.compressed_jsonl, args.verbose)
# Extract Entries
global entries
entries = asymmetric.extract_entries(args.compressed_jsonl, args.verbose)
# Compute or Load Embeddings
global corpus_embeddings
corpus_embeddings = asymmetric.compute_embeddings(entries, bi_encoder, args.embeddings, regenerate=True, verbose=args.verbose)
if __name__ == '__main__': if __name__ == '__main__':
# Setup Argument Parser # Setup Argument Parser
parser = argparse.ArgumentParser(description="Expose API for Semantic Search") parser = argparse.ArgumentParser(description="Expose API for Semantic Search")
parser.add_argument('--input-files', '-i', nargs='*', help="List of org-mode files to process")
parser.add_argument('--input-filter', type=str, default=None, help="Regex filter for org-mode files to process")
parser.add_argument('--compressed-jsonl', '-j', type=pathlib.Path, default=pathlib.Path(".notes.jsonl.gz"), help="Compressed JSONL formatted notes file to compute embeddings from") parser.add_argument('--compressed-jsonl', '-j', type=pathlib.Path, default=pathlib.Path(".notes.jsonl.gz"), help="Compressed JSONL formatted notes file to compute embeddings from")
parser.add_argument('--embeddings', '-e', type=pathlib.Path, default=pathlib.Path(".notes_embeddings.pt"), help="File to save/load model embeddings to/from") parser.add_argument('--embeddings', '-e', type=pathlib.Path, default=pathlib.Path(".notes_embeddings.pt"), help="File to save/load model embeddings to/from")
parser.add_argument('--regenerate', action='store_true', default=False, help="Regenerate embeddings from org-mode files. Default: false")
parser.add_argument('--verbose', action='count', help="Show verbose conversion logs. Default: 0") parser.add_argument('--verbose', action='count', help="Show verbose conversion logs. Default: 0")
args = parser.parse_args() args = parser.parse_args()
# Input Validation
if is_none_or_empty(args.input_files) and is_none_or_empty(args.input_filter):
print("At least one of org-files or org-file-filter is required to be specified")
exit(1)
# Initialize Model # Initialize Model
bi_encoder, cross_encoder, top_k = asymmetric.initialize_model() bi_encoder, cross_encoder, top_k = asymmetric.initialize_model()
# Map notes in Org-Mode files to (compressed) JSONL formatted file
if not args.compressed_jsonl.exists() or args.regenerate:
org_to_jsonl(args.input_files, args.input_filter, args.compressed_jsonl, args.verbose)
# Extract Entries # Extract Entries
entries = asymmetric.extract_entries(args.compressed_jsonl, args.verbose) entries = asymmetric.extract_entries(args.compressed_jsonl, args.verbose)
# Compute or Load Embeddings # Compute or Load Embeddings
corpus_embeddings = asymmetric.compute_embeddings(entries, bi_encoder, args.embeddings, args.verbose) corpus_embeddings = asymmetric.compute_embeddings(entries, bi_encoder, args.embeddings, regenerate=args.regenerate, verbose=args.verbose)
# Start Application Server # Start Application Server
uvicorn.run(app) uvicorn.run(app)