2021-08-17 12:59:58 +02:00
|
|
|
# Standard Packages
|
|
|
|
import sys
|
|
|
|
import argparse
|
|
|
|
import pathlib
|
2021-08-16 02:50:08 +02:00
|
|
|
from typing import Optional
|
2021-08-17 12:59:58 +02:00
|
|
|
|
|
|
|
# External Packages
|
|
|
|
import uvicorn
|
2021-08-22 03:47:55 +02:00
|
|
|
import yaml
|
2021-08-16 02:50:08 +02:00
|
|
|
from fastapi import FastAPI
|
2021-08-17 12:59:58 +02:00
|
|
|
|
|
|
|
# Internal Packages
|
2021-08-17 01:04:45 +02:00
|
|
|
from search_type import asymmetric
|
2021-08-17 03:52:38 +02:00
|
|
|
from processor.org_mode.org_to_jsonl import org_to_jsonl
|
2021-08-22 03:47:55 +02:00
|
|
|
from utils.helpers import is_none_or_empty, get_absolute_path, get_from_dict, merge_dicts
|
2021-08-17 12:59:58 +02:00
|
|
|
|
2021-08-16 02:50:08 +02:00
|
|
|
|
|
|
|
app = FastAPI()
|
|
|
|
|
|
|
|
|
|
|
|
@app.get('/search')
def search(q: str, n: Optional[int] = 5, t: Optional[str] = None):
    """Search the indexed notes for entries matching a query.

    Reads the module-level search state (`notes_search_enabled`,
    `corpus_embeddings`, `entries`, `bi_encoder`, `cross_encoder`, `top_k`)
    that the `__main__` section of this file populates at startup.

    Args:
        q: The user's query. An empty query short-circuits to an empty result.
        n: Number of results requested; forwarded to
            `asymmetric.collate_results`.
        t: Search type. Only 'notes' (or no type at all) is handled.

    Returns:
        The value of `asymmetric.collate_results(...)` for the query hits, or
        an empty dict when the query is empty, the search type is not
        handled, or notes search is not enabled.
    """
    if not q:
        print('No query param (q) passed in API call to initiate search')
        return {}

    user_query = q
    results_count = n

    # The type check comes first so the module-level flag is only consulted
    # for requests that actually target notes.
    notes_requested = t is None or t == 'notes'
    if not (notes_requested and notes_search_enabled):
        return {}

    # query notes
    hits = asymmetric.query_notes(
        user_query,
        corpus_embeddings,
        entries,
        bi_encoder,
        cross_encoder,
        top_k)

    # collate and return results
    return asymmetric.collate_results(hits, entries, results_count)
|
|
|
|
|
|
|
|
|
2021-08-17 03:52:38 +02:00
|
|
|
@app.get('/regenerate')
def regenerate(t: Optional[str] = None):
    """Rebuild the note entries and their embeddings from the source files.

    Rebinds the module-level `entries` and `corpus_embeddings` with the first
    two values returned by `asymmetric.setup(..., regenerate=True)`; the
    remaining three returned values are discarded.

    Args:
        t: Search type to regenerate. Only 'notes' (or no type at all)
            triggers regeneration, and only when notes search is enabled.

    Returns:
        A status dict. The same payload is returned whether or not a
        regeneration was actually triggered.
    """
    global corpus_embeddings
    global entries

    if (t is None or t == 'notes') and notes_search_enabled:
        # Extract Entries, Generate Embeddings
        # Notes search is enabled when either 'input-files' or 'input-filter'
        # is configured, so the other key may be absent: use .get() rather
        # than indexing to avoid a KeyError.
        # NOTE(review): assumes asymmetric.setup accepts None for an unset
        # input-files / input-filter - confirm against its implementation.
        entries, corpus_embeddings, _, _, _ = asymmetric.setup(
            org_config.get('input-files'),
            org_config.get('input-filter'),
            pathlib.Path(org_config['compressed-jsonl']),
            pathlib.Path(org_config['embeddings-file']),
            regenerate=True,
            verbose=args.verbose)

    return {'status': 'ok', 'message': 'regeneration completed'}
|
2021-08-17 03:52:38 +02:00
|
|
|
|
|
|
|
|
2021-08-17 13:00:45 +02:00
|
|
|
def cli(args=None):
    """Parse commandline arguments and resolve the effective configuration.

    Config priority: commandline arguments > config file > default config.

    Args:
        args: Argument list to parse. When `is_none_or_empty(args)` holds,
            `sys.argv[1:]` is parsed instead.

    Returns:
        The parsed `argparse.Namespace`, with the resolved configuration
        dict attached as `args.config`.

    Raises:
        SystemExit: With status 1 when neither --org-files nor --config-file
            is passed (argparse itself may also exit on invalid arguments).
    """
    # Local import keeps this block self-contained without touching the
    # file-level import section.
    import copy

    if is_none_or_empty(args):
        args = sys.argv[1:]

    # Setup Argument Parser for the Commandline Interface
    parser = argparse.ArgumentParser(description="Expose API for Semantic Search")
    parser.add_argument('--org-files', '-i', nargs='*', help="List of org-mode files to process")
    parser.add_argument('--org-filter', type=str, default=None, help="Regex filter for org-mode files to process")
    parser.add_argument('--config-file', '-c', type=pathlib.Path, help="YAML file with user configuration")
    parser.add_argument('--regenerate', action='store_true', default=False, help="Regenerate model embeddings from source files. Default: false")
    parser.add_argument('--verbose', '-v', action='count', default=0, help="Show verbose conversion logs. Default: 0")

    args = parser.parse_args(args)

    if not (args.config_file or args.org_files):
        # The message lists exactly the flags this check accepts.
        print("Require at least 1 of --org-files or --config-file flags to be passed from commandline")
        # sys.exit rather than the site-provided exit(), which is not
        # guaranteed to exist (e.g. under `python -S`).
        sys.exit(1)

    # Config Priority: Cmd Args > Config File > Default Config
    # Work on a deep copy: the overrides below write into nested dicts and
    # must not leak into the shared module-level default_config.
    args.config = copy.deepcopy(default_config)

    if args.config_file and args.config_file.exists():
        with open(get_absolute_path(args.config_file), 'r', encoding='utf-8') as config_file:
            config_from_file = yaml.safe_load(config_file)
        args.config = merge_dicts(priority_dict=config_from_file, default_dict=args.config)

    if args.org_files:
        args.config['content-type']['org']['input-files'] = args.org_files

    if args.org_filter:
        args.config['content-type']['org']['input-filter'] = args.org_filter

    return args
|
|
|
|
|
|
|
|
|
|
|
|
# Baseline configuration; values from a config file and from commandline
# arguments are layered on top of it in cli().
default_config = {
    # Per content-type settings
    "content-type": {
        "org": {
            "compressed-jsonl": ".notes.jsonl.gz",
            "embeddings-file": ".note_embeddings.pt",
        },
    },
    # Per search-type settings
    "search-type": {
        "asymmetric": {
            "encoder": "sentence-transformers/msmarco-MiniLM-L-6-v3",
            "cross-encoder": "cross-encoder/ms-marco-MiniLM-L-6-v2",
        },
    },
}
|
2021-08-17 13:00:45 +02:00
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Resolve configuration from the commandline / config file / defaults
    args = cli()
    org_config = get_from_dict(args.config, 'content-type', 'org')

    # Notes search is enabled when either org input source is configured
    notes_search_enabled = 'input-files' in org_config or 'input-filter' in org_config
    if notes_search_enabled:
        # Only one of the two input keys may be present, so read them with
        # .get() instead of indexing (which raised KeyError for the other).
        # NOTE(review): assumes asymmetric.setup accepts None for an unset
        # input-files / input-filter - confirm against its implementation.
        entries, corpus_embeddings, bi_encoder, cross_encoder, top_k = asymmetric.setup(
            org_config.get('input-files'),
            org_config.get('input-filter'),
            pathlib.Path(org_config['compressed-jsonl']),
            pathlib.Path(org_config['embeddings-file']),
            args.regenerate,
            args.verbose)

    # Start Application Server
    uvicorn.run(app)
|