mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 15:38:55 +01:00
Use YAML file to allow user to configure application. Add tests
- YAML Config - Can specify all params[1] earlier being passed via cmd args in config YAML - Can now also configure sentence-transformer models to use etc for search - [1] Config params - org files - compressed entries file config path - embeddings file config path - Include sample_config.yaml - Include sample .org file from this repo's readmes - CLI - Configuration Priority: Config via cmd > Config via YAML > Default Config - Test CLI, include test config.yml for the tests - Set default type to None unless set via query param to API. Run notes search if search_enabled, also if type is None (default). Prepares for running queries on all search types unless type specified in API query param - Update Readme
This commit is contained in:
parent
bafc86d583
commit
78a1f4ebb4
6 changed files with 141 additions and 32 deletions
|
@ -18,7 +18,7 @@
|
|||
Load ML model, generate embeddings and expose API to query specified org-mode files
|
||||
|
||||
#+begin_src shell
|
||||
python3 src/main.py --input-files ~/Notes/Schedule.org ~/Notes/Incoming.org --verbose
|
||||
python3 src/main.py --org-files ~/Notes/Schedule.org ~/Notes/Incoming.org -c sample_config.yml --verbose
|
||||
#+end_src
|
||||
|
||||
** Use
|
||||
|
|
|
@ -9,4 +9,5 @@ dependencies:
|
|||
- sentence-transformers
|
||||
- fastapi
|
||||
- uvicorn
|
||||
- pyyaml
|
||||
- pytest
|
11
sample_config.yml
Normal file
11
sample_config.yml
Normal file
|
@ -0,0 +1,11 @@
|
|||
content-type:
|
||||
org:
|
||||
input-files: ["src/tests/data/main_readme.org", "src/tests/data/interface_emacs_readme.org"]
|
||||
input-filter: null
|
||||
compressed-jsonl: ".notes.json.gz"
|
||||
embeddings-file: ".note_embeddings.pt"
|
||||
|
||||
search-type:
|
||||
asymmetric:
|
||||
encoder: "sentence-transformers/msmarco-MiniLM-L-6-v3"
|
||||
cross-encoder: "cross-encoder/ms-marco-MiniLM-L-6-v2"
|
80
src/main.py
80
src/main.py
|
@ -6,12 +6,13 @@ from typing import Optional
|
|||
|
||||
# External Packages
|
||||
import uvicorn
|
||||
import yaml
|
||||
from fastapi import FastAPI
|
||||
|
||||
# Internal Packages
|
||||
from search_type import asymmetric
|
||||
from processor.org_mode.org_to_jsonl import org_to_jsonl
|
||||
from utils.helpers import is_none_or_empty
|
||||
from utils.helpers import is_none_or_empty, get_absolute_path, get_from_dict, merge_dicts
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
|
@ -26,7 +27,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[str] = None):
|
|||
user_query = q
|
||||
results_count = n
|
||||
|
||||
if t == 'notes' or t == None:
|
||||
if (t == 'notes' or t == None) and notes_search_enabled:
|
||||
# query notes
|
||||
hits = asymmetric.query_notes(
|
||||
user_query,
|
||||
|
@ -45,35 +46,90 @@ def search(q: str, n: Optional[int] = 5, t: Optional[str] = None):
|
|||
|
||||
@app.get('/regenerate')
|
||||
def regenerate(t: Optional[str] = None):
|
||||
if t == 'notes' or t == None:
|
||||
if (t == 'notes' or t == None) and notes_search_enabled:
|
||||
# Extract Entries, Generate Embeddings
|
||||
global corpus_embeddings
|
||||
global entries
|
||||
entries, corpus_embeddings, _, _, _ = asymmetric.setup(args.input_files, args.input_filter, args.compressed_jsonl, args.embeddings, regenerate=True, verbose=args.verbose)
|
||||
entries, corpus_embeddings, _, _, _ = asymmetric.setup(
|
||||
org_config['input-files'],
|
||||
org_config['input-filter'],
|
||||
pathlib.Path(org_config['compressed-jsonl']),
|
||||
pathlib.Path(org_config['embeddings-file']),
|
||||
regenerate=True,
|
||||
verbose=args.verbose)
|
||||
|
||||
|
||||
return {'status': 'ok', 'message': 'regeneration completed'}
|
||||
|
||||
|
||||
def cli(args=None):
|
||||
if not args:
|
||||
if is_none_or_empty(args):
|
||||
args = sys.argv[1:]
|
||||
|
||||
# Setup Argument Parser for the Commandline Interface
|
||||
parser = argparse.ArgumentParser(description="Expose API for Semantic Search")
|
||||
parser.add_argument('--input-files', '-i', nargs='*', help="List of org-mode files to process")
|
||||
parser.add_argument('--input-filter', type=str, default=None, help="Regex filter for org-mode files to process")
|
||||
parser.add_argument('--compressed-jsonl', '-j', type=pathlib.Path, default=pathlib.Path(".notes.jsonl.gz"), help="Compressed JSONL formatted notes file to compute embeddings from")
|
||||
parser.add_argument('--embeddings', '-e', type=pathlib.Path, default=pathlib.Path(".notes_embeddings.pt"), help="File to save/load model embeddings to/from")
|
||||
parser.add_argument('--regenerate', action='store_true', default=False, help="Regenerate embeddings from org-mode files. Default: false")
|
||||
parser.add_argument('--org-files', '-i', nargs='*', help="List of org-mode files to process")
|
||||
parser.add_argument('--org-filter', type=str, default=None, help="Regex filter for org-mode files to process")
|
||||
parser.add_argument('--config-file', '-c', type=pathlib.Path, help="YAML file with user configuration")
|
||||
parser.add_argument('--regenerate', action='store_true', default=False, help="Regenerate model embeddings from source files. Default: false")
|
||||
parser.add_argument('--verbose', '-v', action='count', default=0, help="Show verbose conversion logs. Default: 0")
|
||||
args = parser.parse_args(args)
|
||||
|
||||
return parser.parse_args(args)
|
||||
if not (args.config_file or args.org_files):
|
||||
print(f"Require at least 1 of --org-file, --org-filter or --config-file flags to be passed from commandline")
|
||||
exit(1)
|
||||
|
||||
# Config Priority: Cmd Args > Config File > Default Config
|
||||
args.config = default_config
|
||||
if args.config_file and args.config_file.exists():
|
||||
with open(get_absolute_path(args.config_file), 'r', encoding='utf-8') as config_file:
|
||||
config_from_file = yaml.safe_load(config_file)
|
||||
args.config = merge_dicts(priority_dict=config_from_file, default_dict=args.config)
|
||||
|
||||
if args.org_files:
|
||||
args.config['content-type']['org']['input-files'] = args.org_files
|
||||
|
||||
if args.org_filter:
|
||||
args.config['content-type']['org']['input-filter'] = args.org_filter
|
||||
|
||||
return args
|
||||
|
||||
|
||||
default_config = {
|
||||
'content-type':
|
||||
{
|
||||
'org':
|
||||
{
|
||||
'compressed-jsonl': '.notes.jsonl.gz',
|
||||
'embeddings-file': '.note_embeddings.pt'
|
||||
}
|
||||
},
|
||||
'search-type':
|
||||
{
|
||||
'asymmetric':
|
||||
{
|
||||
'encoder': "sentence-transformers/msmarco-MiniLM-L-6-v3",
|
||||
'cross-encoder': "cross-encoder/ms-marco-MiniLM-L-6-v2"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = cli()
|
||||
org_config = get_from_dict(args.config, 'content-type', 'org')
|
||||
|
||||
notes_search_enabled = False
|
||||
if 'input-files' in org_config or 'input-filter' in org_config:
|
||||
notes_search_enabled = True
|
||||
entries, corpus_embeddings, bi_encoder, cross_encoder, top_k = asymmetric.setup(
|
||||
org_config['input-files'],
|
||||
org_config['input-filter'],
|
||||
pathlib.Path(org_config['compressed-jsonl']),
|
||||
pathlib.Path(org_config['embeddings-file']),
|
||||
args.regenerate,
|
||||
args.verbose)
|
||||
|
||||
entries, corpus_embeddings, bi_encoder, cross_encoder, top_k = asymmetric.setup(args.input_files, args.input_filter, args.compressed_jsonl, args.embeddings, args.regenerate, args.verbose)
|
||||
|
||||
# Start Application Server
|
||||
uvicorn.run(app)
|
||||
|
|
11
src/tests/data/config.yml
Normal file
11
src/tests/data/config.yml
Normal file
|
@ -0,0 +1,11 @@
|
|||
content-type:
|
||||
org:
|
||||
input-files: [ "~/first_from_config.org", "~/second_from_config.org" ]
|
||||
input-filter: "*.org"
|
||||
compressed-jsonl: ".notes.json.gz"
|
||||
embeddings-file: ".note_embeddings.pt"
|
||||
|
||||
search-type:
|
||||
asymmetric:
|
||||
encoder: "sentence-transformers/msmarco-MiniLM-L-6-v3"
|
||||
cross-encoder: "cross-encoder/ms-marco-MiniLM-L-6-v2"
|
|
@ -32,34 +32,64 @@ def test_asymmetric_setup():
|
|||
assert len(corpus_embeddings) == 10
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_cli_default():
|
||||
def test_cli_minimal_default():
|
||||
# Act
|
||||
args = cli(['--input-files=tests/data/test.org'])
|
||||
actual_args = cli(['--config-file=tests/data/config.yml'])
|
||||
|
||||
# Assert
|
||||
assert args.input_files == ['tests/data/test.org']
|
||||
assert args.input_filter == None
|
||||
assert args.compressed_jsonl == Path('.notes.jsonl.gz')
|
||||
assert args.embeddings == Path('.notes_embeddings.pt')
|
||||
assert args.regenerate == False
|
||||
assert args.verbose == 0
|
||||
|
||||
assert actual_args.config_file == Path('tests/data/config.yml')
|
||||
assert actual_args.regenerate == False
|
||||
assert actual_args.verbose == 0
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_cli_set_by_user():
|
||||
def test_cli_flags():
|
||||
# Act
|
||||
actual_args = cli(['--input-files=tests/data/test.org',
|
||||
'--input-filter=tests/data/*.org',
|
||||
'--compressed-jsonl=tests/data/.test.jsonl.gz',
|
||||
'--embeddings=tests/data/.test_embeddings.pt',
|
||||
actual_args = cli(['--config-file=tests/data/config.yml',
|
||||
'--regenerate',
|
||||
'-vvv'])
|
||||
|
||||
# Assert
|
||||
assert actual_args.input_files == ['tests/data/test.org']
|
||||
assert actual_args.input_filter == 'tests/data/*.org'
|
||||
assert actual_args.compressed_jsonl == Path('tests/data/.test.jsonl.gz')
|
||||
assert actual_args.embeddings == Path('tests/data/.test_embeddings.pt')
|
||||
assert actual_args.config_file == Path('tests/data/config.yml')
|
||||
assert actual_args.regenerate == True
|
||||
assert actual_args.verbose == 3
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_cli_config_from_file():
|
||||
# Act
|
||||
actual_args = cli(['--config-file=tests/data/config.yml',
|
||||
'--regenerate',
|
||||
'-vvv'])
|
||||
|
||||
# Assert
|
||||
assert actual_args.config_file == Path('tests/data/config.yml')
|
||||
assert actual_args.regenerate == True
|
||||
assert actual_args.config is not None
|
||||
assert actual_args.config['content-type']['org']['input-files'] == ['~/first_from_config.org', '~/second_from_config.org']
|
||||
assert actual_args.verbose == 3
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_cli_config_from_cmd_args():
|
||||
""
|
||||
# Act
|
||||
actual_args = cli(['--org-files=first.org'])
|
||||
|
||||
# Assert
|
||||
assert actual_args.org_files == ['first.org']
|
||||
assert actual_args.config_file is None
|
||||
assert actual_args.config is not None
|
||||
assert actual_args.config['content-type']['org']['input-files'] == ['first.org']
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_cli_config_from_cmd_args_override_config_file():
|
||||
# Act
|
||||
actual_args = cli(['--config-file=tests/data/config.yml',
|
||||
'--org-files=first.org'])
|
||||
|
||||
# Assert
|
||||
assert actual_args.org_files == ['first.org']
|
||||
assert actual_args.config_file == Path('tests/data/config.yml')
|
||||
assert actual_args.config is not None
|
||||
assert actual_args.config['content-type']['org']['input-files'] == ['first.org']
|
||||
|
|
Loading…
Reference in a new issue