mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-27 17:35:07 +01:00
Use YAML file to allow user to configure application. Add tests
- YAML Config
  - Can specify all params[1] that were earlier passed via cmd args in the config YAML
  - Can now also configure the sentence-transformer models, etc. to use for search
  - [1] Config params:
    - org files
    - compressed entries file config path
    - embeddings file config path
  - Include sample_config.yml
  - Include sample .org files from this repo's READMEs
- CLI
  - Configuration Priority: Config via cmd > Config via YAML > Default Config
  - Test CLI, include test config.yml for the tests
- Set default type to None unless set via query param to the API.
  Run notes search if search_enabled, and also if type is None (default).
  This prepares for running queries on all search types unless a type is specified in the API query param.
- Update Readme
This commit is contained in:
parent
bafc86d583
commit
78a1f4ebb4
6 changed files with 141 additions and 32 deletions
|
@ -18,7 +18,7 @@
|
||||||
Load ML model, generate embeddings and expose API to query specified org-mode files
|
Load ML model, generate embeddings and expose API to query specified org-mode files
|
||||||
|
|
||||||
#+begin_src shell
|
#+begin_src shell
|
||||||
python3 src/main.py --input-files ~/Notes/Schedule.org ~/Notes/Incoming.org --verbose
|
python3 src/main.py --org-files ~/Notes/Schedule.org ~/Notes/Incoming.org -c sample_config.yml --verbose
|
||||||
#+end_src
|
#+end_src
|
||||||
|
|
||||||
** Use
|
** Use
|
||||||
|
|
|
@ -9,4 +9,5 @@ dependencies:
|
||||||
- sentence-transformers
|
- sentence-transformers
|
||||||
- fastapi
|
- fastapi
|
||||||
- uvicorn
|
- uvicorn
|
||||||
|
- pyyaml
|
||||||
- pytest
|
- pytest
|
11
sample_config.yml
Normal file
11
sample_config.yml
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
content-type:
|
||||||
|
org:
|
||||||
|
input-files: ["src/tests/data/main_readme.org", "src/tests/data/interface_emacs_readme.org"]
|
||||||
|
input-filter: null
|
||||||
|
compressed-jsonl: ".notes.json.gz"
|
||||||
|
embeddings-file: ".note_embeddings.pt"
|
||||||
|
|
||||||
|
search-type:
|
||||||
|
asymmetric:
|
||||||
|
encoder: "sentence-transformers/msmarco-MiniLM-L-6-v3"
|
||||||
|
cross-encoder: "cross-encoder/ms-marco-MiniLM-L-6-v2"
|
80
src/main.py
80
src/main.py
|
@ -6,12 +6,13 @@ from typing import Optional
|
||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
import yaml
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from search_type import asymmetric
|
from search_type import asymmetric
|
||||||
from processor.org_mode.org_to_jsonl import org_to_jsonl
|
from processor.org_mode.org_to_jsonl import org_to_jsonl
|
||||||
from utils.helpers import is_none_or_empty
|
from utils.helpers import is_none_or_empty, get_absolute_path, get_from_dict, merge_dicts
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
@ -26,7 +27,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[str] = None):
|
||||||
user_query = q
|
user_query = q
|
||||||
results_count = n
|
results_count = n
|
||||||
|
|
||||||
if t == 'notes' or t == None:
|
if (t == 'notes' or t == None) and notes_search_enabled:
|
||||||
# query notes
|
# query notes
|
||||||
hits = asymmetric.query_notes(
|
hits = asymmetric.query_notes(
|
||||||
user_query,
|
user_query,
|
||||||
|
@ -45,35 +46,90 @@ def search(q: str, n: Optional[int] = 5, t: Optional[str] = None):
|
||||||
|
|
||||||
@app.get('/regenerate')
|
@app.get('/regenerate')
|
||||||
def regenerate(t: Optional[str] = None):
|
def regenerate(t: Optional[str] = None):
|
||||||
if t == 'notes' or t == None:
|
if (t == 'notes' or t == None) and notes_search_enabled:
|
||||||
# Extract Entries, Generate Embeddings
|
# Extract Entries, Generate Embeddings
|
||||||
global corpus_embeddings
|
global corpus_embeddings
|
||||||
global entries
|
global entries
|
||||||
entries, corpus_embeddings, _, _, _ = asymmetric.setup(args.input_files, args.input_filter, args.compressed_jsonl, args.embeddings, regenerate=True, verbose=args.verbose)
|
entries, corpus_embeddings, _, _, _ = asymmetric.setup(
|
||||||
|
org_config['input-files'],
|
||||||
|
org_config['input-filter'],
|
||||||
|
pathlib.Path(org_config['compressed-jsonl']),
|
||||||
|
pathlib.Path(org_config['embeddings-file']),
|
||||||
|
regenerate=True,
|
||||||
|
verbose=args.verbose)
|
||||||
|
|
||||||
|
|
||||||
return {'status': 'ok', 'message': 'regeneration completed'}
|
return {'status': 'ok', 'message': 'regeneration completed'}
|
||||||
|
|
||||||
|
|
||||||
def cli(args=None):
|
def cli(args=None):
|
||||||
if not args:
|
if is_none_or_empty(args):
|
||||||
args = sys.argv[1:]
|
args = sys.argv[1:]
|
||||||
|
|
||||||
# Setup Argument Parser for the Commandline Interface
|
# Setup Argument Parser for the Commandline Interface
|
||||||
parser = argparse.ArgumentParser(description="Expose API for Semantic Search")
|
parser = argparse.ArgumentParser(description="Expose API for Semantic Search")
|
||||||
parser.add_argument('--input-files', '-i', nargs='*', help="List of org-mode files to process")
|
parser.add_argument('--org-files', '-i', nargs='*', help="List of org-mode files to process")
|
||||||
parser.add_argument('--input-filter', type=str, default=None, help="Regex filter for org-mode files to process")
|
parser.add_argument('--org-filter', type=str, default=None, help="Regex filter for org-mode files to process")
|
||||||
parser.add_argument('--compressed-jsonl', '-j', type=pathlib.Path, default=pathlib.Path(".notes.jsonl.gz"), help="Compressed JSONL formatted notes file to compute embeddings from")
|
parser.add_argument('--config-file', '-c', type=pathlib.Path, help="YAML file with user configuration")
|
||||||
parser.add_argument('--embeddings', '-e', type=pathlib.Path, default=pathlib.Path(".notes_embeddings.pt"), help="File to save/load model embeddings to/from")
|
parser.add_argument('--regenerate', action='store_true', default=False, help="Regenerate model embeddings from source files. Default: false")
|
||||||
parser.add_argument('--regenerate', action='store_true', default=False, help="Regenerate embeddings from org-mode files. Default: false")
|
|
||||||
parser.add_argument('--verbose', '-v', action='count', default=0, help="Show verbose conversion logs. Default: 0")
|
parser.add_argument('--verbose', '-v', action='count', default=0, help="Show verbose conversion logs. Default: 0")
|
||||||
|
args = parser.parse_args(args)
|
||||||
|
|
||||||
return parser.parse_args(args)
|
if not (args.config_file or args.org_files):
|
||||||
|
print(f"Require at least 1 of --org-file, --org-filter or --config-file flags to be passed from commandline")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
# Config Priority: Cmd Args > Config File > Default Config
|
||||||
|
args.config = default_config
|
||||||
|
if args.config_file and args.config_file.exists():
|
||||||
|
with open(get_absolute_path(args.config_file), 'r', encoding='utf-8') as config_file:
|
||||||
|
config_from_file = yaml.safe_load(config_file)
|
||||||
|
args.config = merge_dicts(priority_dict=config_from_file, default_dict=args.config)
|
||||||
|
|
||||||
|
if args.org_files:
|
||||||
|
args.config['content-type']['org']['input-files'] = args.org_files
|
||||||
|
|
||||||
|
if args.org_filter:
|
||||||
|
args.config['content-type']['org']['input-filter'] = args.org_filter
|
||||||
|
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
default_config = {
|
||||||
|
'content-type':
|
||||||
|
{
|
||||||
|
'org':
|
||||||
|
{
|
||||||
|
'compressed-jsonl': '.notes.jsonl.gz',
|
||||||
|
'embeddings-file': '.note_embeddings.pt'
|
||||||
|
}
|
||||||
|
},
|
||||||
|
'search-type':
|
||||||
|
{
|
||||||
|
'asymmetric':
|
||||||
|
{
|
||||||
|
'encoder': "sentence-transformers/msmarco-MiniLM-L-6-v3",
|
||||||
|
'cross-encoder': "cross-encoder/ms-marco-MiniLM-L-6-v2"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
args = cli()
|
args = cli()
|
||||||
|
org_config = get_from_dict(args.config, 'content-type', 'org')
|
||||||
|
|
||||||
|
notes_search_enabled = False
|
||||||
|
if 'input-files' in org_config or 'input-filter' in org_config:
|
||||||
|
notes_search_enabled = True
|
||||||
|
entries, corpus_embeddings, bi_encoder, cross_encoder, top_k = asymmetric.setup(
|
||||||
|
org_config['input-files'],
|
||||||
|
org_config['input-filter'],
|
||||||
|
pathlib.Path(org_config['compressed-jsonl']),
|
||||||
|
pathlib.Path(org_config['embeddings-file']),
|
||||||
|
args.regenerate,
|
||||||
|
args.verbose)
|
||||||
|
|
||||||
entries, corpus_embeddings, bi_encoder, cross_encoder, top_k = asymmetric.setup(args.input_files, args.input_filter, args.compressed_jsonl, args.embeddings, args.regenerate, args.verbose)
|
|
||||||
|
|
||||||
# Start Application Server
|
# Start Application Server
|
||||||
uvicorn.run(app)
|
uvicorn.run(app)
|
||||||
|
|
11
src/tests/data/config.yml
Normal file
11
src/tests/data/config.yml
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
content-type:
|
||||||
|
org:
|
||||||
|
input-files: [ "~/first_from_config.org", "~/second_from_config.org" ]
|
||||||
|
input-filter: "*.org"
|
||||||
|
compressed-jsonl: ".notes.json.gz"
|
||||||
|
embeddings-file: ".note_embeddings.pt"
|
||||||
|
|
||||||
|
search-type:
|
||||||
|
asymmetric:
|
||||||
|
encoder: "sentence-transformers/msmarco-MiniLM-L-6-v3"
|
||||||
|
cross-encoder: "cross-encoder/ms-marco-MiniLM-L-6-v2"
|
|
@ -32,34 +32,64 @@ def test_asymmetric_setup():
|
||||||
assert len(corpus_embeddings) == 10
|
assert len(corpus_embeddings) == 10
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
def test_cli_minimal_default():
|
||||||
def test_cli_default():
|
|
||||||
# Act
|
# Act
|
||||||
args = cli(['--input-files=tests/data/test.org'])
|
actual_args = cli(['--config-file=tests/data/config.yml'])
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert args.input_files == ['tests/data/test.org']
|
assert actual_args.config_file == Path('tests/data/config.yml')
|
||||||
assert args.input_filter == None
|
assert actual_args.regenerate == False
|
||||||
assert args.compressed_jsonl == Path('.notes.jsonl.gz')
|
assert actual_args.verbose == 0
|
||||||
assert args.embeddings == Path('.notes_embeddings.pt')
|
|
||||||
assert args.regenerate == False
|
|
||||||
assert args.verbose == 0
|
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
def test_cli_set_by_user():
|
def test_cli_flags():
|
||||||
# Act
|
# Act
|
||||||
actual_args = cli(['--input-files=tests/data/test.org',
|
actual_args = cli(['--config-file=tests/data/config.yml',
|
||||||
'--input-filter=tests/data/*.org',
|
|
||||||
'--compressed-jsonl=tests/data/.test.jsonl.gz',
|
|
||||||
'--embeddings=tests/data/.test_embeddings.pt',
|
|
||||||
'--regenerate',
|
'--regenerate',
|
||||||
'-vvv'])
|
'-vvv'])
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert actual_args.input_files == ['tests/data/test.org']
|
assert actual_args.config_file == Path('tests/data/config.yml')
|
||||||
assert actual_args.input_filter == 'tests/data/*.org'
|
|
||||||
assert actual_args.compressed_jsonl == Path('tests/data/.test.jsonl.gz')
|
|
||||||
assert actual_args.embeddings == Path('tests/data/.test_embeddings.pt')
|
|
||||||
assert actual_args.regenerate == True
|
assert actual_args.regenerate == True
|
||||||
assert actual_args.verbose == 3
|
assert actual_args.verbose == 3
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------------------------------------
|
||||||
|
def test_cli_config_from_file():
|
||||||
|
# Act
|
||||||
|
actual_args = cli(['--config-file=tests/data/config.yml',
|
||||||
|
'--regenerate',
|
||||||
|
'-vvv'])
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert actual_args.config_file == Path('tests/data/config.yml')
|
||||||
|
assert actual_args.regenerate == True
|
||||||
|
assert actual_args.config is not None
|
||||||
|
assert actual_args.config['content-type']['org']['input-files'] == ['~/first_from_config.org', '~/second_from_config.org']
|
||||||
|
assert actual_args.verbose == 3
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------------------------------------
|
||||||
|
def test_cli_config_from_cmd_args():
|
||||||
|
""
|
||||||
|
# Act
|
||||||
|
actual_args = cli(['--org-files=first.org'])
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert actual_args.org_files == ['first.org']
|
||||||
|
assert actual_args.config_file is None
|
||||||
|
assert actual_args.config is not None
|
||||||
|
assert actual_args.config['content-type']['org']['input-files'] == ['first.org']
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------------------------------------
|
||||||
|
def test_cli_config_from_cmd_args_override_config_file():
|
||||||
|
# Act
|
||||||
|
actual_args = cli(['--config-file=tests/data/config.yml',
|
||||||
|
'--org-files=first.org'])
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert actual_args.org_files == ['first.org']
|
||||||
|
assert actual_args.config_file == Path('tests/data/config.yml')
|
||||||
|
assert actual_args.config is not None
|
||||||
|
assert actual_args.config['content-type']['org']['input-files'] == ['first.org']
|
||||||
|
|
Loading…
Reference in a new issue