From 10e4065e053c7041f4fd04a1a760af806d0655ce Mon Sep 17 00:00:00 2001 From: Saba Date: Sat, 4 Dec 2021 11:43:48 -0500 Subject: [PATCH] Consolidate the search config models and pass verbose as a top level flag --- src/main.py | 71 ++++++++++++----------------- src/search_type/asymmetric.py | 13 +++--- src/search_type/image_search.py | 11 +++-- src/search_type/symmetric_ledger.py | 9 ++-- src/utils/cli.py | 4 +- src/utils/config.py | 34 ++------------ 6 files changed, 52 insertions(+), 90 deletions(-) diff --git a/src/main.py b/src/main.py index fc9afa1f..abbe923e 100644 --- a/src/main.py +++ b/src/main.py @@ -13,16 +13,16 @@ from fastapi.templating import Jinja2Templates from src.search_type import asymmetric, symmetric_ledger, image_search from src.utils.helpers import get_absolute_path from src.utils.cli import cli -from src.utils.config import SearchType, SearchModels, TextSearchConfig, ImageSearchConfig, SearchConfig, ProcessorConfig, ConversationProcessorConfig -from src.utils.rawconfig import FullConfig +from src.utils.config import SearchType, SearchModels, ProcessorConfig, ConversationProcessorConfigDTO +from src.utils.rawconfig import FullConfigModel from src.processor.conversation.gpt import converse, message_to_log, message_to_prompt, understand # Application Global State model = SearchModels() -search_config = SearchConfig() processor_config = ProcessorConfig() config = {} config_file = "" +verbose = 0 app = FastAPI() app.mount("/views", StaticFiles(directory="views"), name="views") @@ -32,12 +32,12 @@ templates = Jinja2Templates(directory="views/") def ui(request: Request): return templates.TemplateResponse("config.html", context={'request': request}) -@app.get('/config', response_model=FullConfig) +@app.get('/config', response_model=FullConfigModel) def config(): return config @app.post('/config') -async def config(updated_config: FullConfig): +async def config(updated_config: FullConfigModel): global config config = updated_config with open(config_file, 'w') as outfile: @@ -83,7 +83,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None): return image_search.collate_results( hits, model.image_search.image_names, - search_config.image.input_directory, + config.content_type.image.input_directory, results_count) else: @@ -92,22 +92,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None): @app.get('/regenerate') def regenerate(t: Optional[SearchType] = None): - if (t == SearchType.Notes or t == None) and search_config.notes: - # Extract Entries, Generate Embeddings - model.notes_search = asymmetric.setup(search_config.notes, regenerate=True) - - if (t == SearchType.Music or t == None) and search_config.music: - # Extract Entries, Generate Song Embeddings - model.music_search = asymmetric.setup(search_config.music, regenerate=True) - - if (t == SearchType.Ledger or t == None) and search_config.ledger: - # Extract Entries, Generate Embeddings - model.ledger_search = symmetric_ledger.setup(search_config.ledger, regenerate=True) - - if (t == SearchType.Image or t == None) and search_config.image: - # Extract Images, Generate Embeddings - model.image_search = image_search.setup(search_config.image, regenerate=True) - + initialize_search(regenerate=True) return {'status': 'ok', 'message': 'regeneration completed'} @@ -128,41 +113,40 @@ def chat(q: str): return {'status': 'ok', 'response': gpt_response} -def initialize_search(regenerate, verbose): +def initialize_search(regenerate: bool, t: SearchType = None): model = SearchModels() - search_config = SearchConfig() # Initialize Org Notes Search - if config.content_type.org: - search_config.notes = TextSearchConfig(config.content_type.org, verbose) - model.notes_search = asymmetric.setup(search_config.notes, regenerate=regenerate) + if (t == SearchType.Notes or t == None) and config.content_type.org: + # Extract Entries, Generate Notes Embeddings + model.notes_search = asymmetric.setup(config.content_type.org, regenerate=regenerate, verbose=verbose) # Initialize Org Music Search - if config.content_type.music: - search_config.music = TextSearchConfig(config.content_type.music, verbose) - model.music_search = asymmetric.setup(search_config.music, regenerate=regenerate) + if (t == SearchType.Music or t == None) and config.content_type.music: + # Extract Entries, Generate Music Embeddings + model.music_search = asymmetric.setup(config.content_type.music, regenerate=regenerate, verbose=verbose) # Initialize Ledger Search - if config.content_type.ledger: - search_config.ledger = TextSearchConfig(config.content_type.org, verbose) - model.ledger_search = symmetric_ledger.setup(search_config.ledger, regenerate=regenerate) + if (t == SearchType.Ledger or t == None) and config.content_type.ledger: + # Extract Entries, Generate Ledger Embeddings + model.ledger_search = symmetric_ledger.setup(config.content_type.ledger, regenerate=regenerate, verbose=verbose) # Initialize Image Search - if config.content_type.image: - search_config.image = ImageSearchConfig(config.content_type.image, verbose) - model.image_search = image_search.setup(search_config.image, regenerate=regenerate) + if (t == SearchType.Image or t == None) and config.content_type.image: + # Extract Entries, Generate Image Embeddings + model.image_search = image_search.setup(config.content_type.image, regenerate=regenerate, verbose=verbose) - return model, search_config + return model -def initialize_processor(verbose): +def initialize_processor(): if not config.processor: return processor_config = ProcessorConfig() # Initialize Conversation Processor - processor_config.conversation = ConversationProcessorConfig(config.processor.conversation, verbose) + processor_config.conversation = ConversationProcessorConfigDTO(config.processor.conversation, verbose) conversation_logfile = processor_config.conversation.conversation_logfile if processor_config.conversation.verbose: @@ -211,14 +195,17 @@ if __name__ == '__main__': # Stores the file path to the config file. config_file = args.config_file + # Store the verbose flag + verbose = args.verbose + # Store the raw config data. config = args.config - # Initialize Search from Config - model, search_config = initialize_search(args.regenerate, args.verbose) + # Initialize the search model from Config + model = initialize_search(args.regenerate) # Initialize Processor from Config - processor_config = initialize_processor(args.verbose) + processor_config = initialize_processor() # Start Application Server if args.socket: diff --git a/src/search_type/asymmetric.py b/src/search_type/asymmetric.py index fc4e72e4..b65d667f 100644 --- a/src/search_type/asymmetric.py +++ b/src/search_type/asymmetric.py @@ -14,7 +14,8 @@ from sentence_transformers import SentenceTransformer, CrossEncoder, util # Internal Packages from src.utils.helpers import get_absolute_path, resolve_absolute_path from src.processor.org_mode.org_to_jsonl import org_to_jsonl -from src.utils.config import TextSearchModel, TextSearchConfig +from src.utils.config import TextSearchModel +from src.utils.rawconfig import TextSearchConfigModel def initialize_model(): @@ -148,22 +149,22 @@ def collate_results(hits, entries, count=5): in hits[0:count]] -def setup(config: TextSearchConfig, regenerate: bool) -> TextSearchModel: +def setup(config: TextSearchConfigModel, regenerate: bool, verbose: bool) -> TextSearchModel: # Initialize Model bi_encoder, cross_encoder, top_k = initialize_model() # Map notes in Org-Mode files to (compressed) JSONL formatted file if not resolve_absolute_path(config.compressed_jsonl).exists() or regenerate: - org_to_jsonl(config.input_files, config.input_filter, config.compressed_jsonl, config.verbose) + org_to_jsonl(config.input_files, config.input_filter, config.compressed_jsonl, verbose) # Extract Entries - entries = extract_entries(config.compressed_jsonl, config.verbose) + entries = extract_entries(config.compressed_jsonl, verbose) top_k = min(len(entries), top_k) # top_k hits can't be more than the total entries in corpus # Compute or Load Embeddings - corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate, verbose=config.verbose) + corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate, verbose=verbose) - return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k, verbose=config.verbose) + return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k, verbose=verbose) if __name__ == '__main__': diff --git a/src/search_type/image_search.py b/src/search_type/image_search.py index 95acc801..1b7341a5 100644 --- a/src/search_type/image_search.py +++ b/src/search_type/image_search.py @@ -10,9 +10,10 @@ from tqdm import trange import torch # Internal Packages -from src.utils.helpers import get_absolute_path, resolve_absolute_path +from src.utils.helpers import resolve_absolute_path import src.utils.exiftool as exiftool -from src.utils.config import ImageSearchModel, ImageSearchConfig +from src.utils.config import ImageSearchModel +from src.utils.rawconfig import ImageSearchConfigModel def initialize_model(): @@ -153,7 +154,7 @@ def collate_results(hits, image_names, image_directory, count=5): in hits[0:count]] -def setup(config: ImageSearchConfig, regenerate: bool) -> ImageSearchModel: +def setup(config: ImageSearchConfigModel, regenerate: bool, verbose: bool) -> ImageSearchModel: # Initialize Model encoder = initialize_model() @@ -170,13 +171,13 @@ def setup(config: ImageSearchConfig, regenerate: bool) -> ImageSearchModel: batch_size=config.batch_size, regenerate=regenerate, use_xmp_metadata=config.use_xmp_metadata, - verbose=config.verbose) + verbose=verbose) return ImageSearchModel(image_names, image_embeddings, image_metadata_embeddings, encoder, - config.verbose) + verbose) if __name__ == '__main__': diff --git a/src/search_type/symmetric_ledger.py b/src/search_type/symmetric_ledger.py index cbd3c42b..c507bbf5 100644 --- a/src/search_type/symmetric_ledger.py +++ b/src/search_type/symmetric_ledger.py @@ -12,7 +12,8 @@ from sentence_transformers import SentenceTransformer, CrossEncoder, util # Internal Packages from src.utils.helpers import get_absolute_path, resolve_absolute_path from src.processor.ledger.beancount_to_jsonl import beancount_to_jsonl -from src.utils.config import TextSearchModel, TextSearchConfig +from src.utils.config import TextSearchModel +from src.utils.rawconfig import TextSearchConfigModel def initialize_model(): @@ -140,7 +141,7 @@ def collate_results(hits, entries, count=5): in hits[0:count]] -def setup(config: TextSearchConfig, regenerate: bool) -> TextSearchModel: +def setup(config: TextSearchConfigModel, regenerate: bool, verbose: bool) -> TextSearchModel: # Initialize Model bi_encoder, cross_encoder, top_k = initialize_model() @@ -153,9 +154,9 @@ def setup(config: TextSearchConfig, regenerate: bool) -> TextSearchModel: top_k = min(len(entries), top_k) # Compute or Load Embeddings - corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate, verbose=config.verbose) + corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate, verbose=verbose) - return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k, verbose=config.verbose) + return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k, verbose=verbose) if __name__ == '__main__': diff --git a/src/utils/cli.py b/src/utils/cli.py index 504a6d99..00d0566e 100644 --- a/src/utils/cli.py +++ b/src/utils/cli.py @@ -8,7 +8,7 @@ import yaml # Internal Packages from src.utils.helpers import is_none_or_empty, get_absolute_path, resolve_absolute_path, merge_dicts -from src.utils.rawconfig import FullConfig +from src.utils.rawconfig import FullConfigModel def cli(args=None): if is_none_or_empty(args): @@ -37,7 +37,7 @@ def cli(args=None): with open(get_absolute_path(args.config_file), 'r', encoding='utf-8') as config_file: config_from_file = yaml.safe_load(config_file) args.config = merge_dicts(priority_dict=config_from_file, default_dict=args.config) - args.config = FullConfig.parse_obj(args.config) + args.config = FullConfigModel.parse_obj(args.config) if args.org_files: args.config['content-type']['org']['input-files'] = args.org_files diff --git a/src/utils/config.py b/src/utils/config.py index c40e7ef0..02cb53f2 100644 --- a/src/utils/config.py +++ b/src/utils/config.py @@ -4,9 +4,7 @@ from dataclasses import dataclass from pathlib import Path # Internal Packages -from src.utils.helpers import get_from_dict - -from src.utils.rawconfig import TextSearchConfigModel, ImageSearchConfigModel, ProcessorConversationConfigModel +from src.utils.rawconfig import ProcessorConversationConfigModel class SearchType(str, Enum): @@ -44,33 +42,7 @@ class SearchModels(): image_search: ImageSearchModel = None -class TextSearchConfigModel(): - def __init__(self, text_search_config: TextSearchConfigModel, verbose: bool): - self.input_files = text_search_config.input_files - self.input_filter = text_search_config.input_filter - self.compressed_jsonl = Path(text_search_config.compressed_jsonl) - self.embeddings_file = Path(text_search_config.embeddings_file) - self.verbose = verbose - - -class ImageSearchConfigModel(): - def __init__(self, image_search_config: ImageSearchConfigModel, verbose): - self.input_directory = Path(image_search_config.input_directory) - self.embeddings_file = Path(image_search_config.embeddings_file) - self.batch_size = image_search_config.batch_size - self.use_xmp_metadata = image_search_config.use_xmp_metadata - self.verbose = verbose - - -@dataclass -class SearchConfig(): - notes: TextSearchConfigModel = None - ledger: TextSearchConfigModel = None - music: TextSearchConfigModel = None - image: ImageSearchConfigModel = None - - -class ConversationProcessorConfig(): +class ConversationProcessorConfigDTO(): def __init__(self, processor_config: ProcessorConversationConfigModel, verbose: bool): self.openai_api_key = processor_config.open_api_key self.conversation_logfile = Path(processor_config.conversation_logfile) @@ -81,4 +53,4 @@ class ConversationProcessorConfig(): @dataclass class ProcessorConfig(): - conversation: ConversationProcessorConfig = None + conversation: ConversationProcessorConfigDTO = None