Update unit tests, files with removing model suffix to config types

This commit is contained in:
Saba 2021-12-09 08:50:38 -05:00
parent b3eac888fb
commit d65190c3ee
9 changed files with 55 additions and 55 deletions

View file

@ -13,13 +13,13 @@ from fastapi.templating import Jinja2Templates
from src.search_type import asymmetric, symmetric_ledger, image_search from src.search_type import asymmetric, symmetric_ledger, image_search
from src.utils.helpers import get_absolute_path from src.utils.helpers import get_absolute_path
from src.utils.cli import cli from src.utils.cli import cli
from src.utils.config import SearchType, SearchModels, ProcessorConfig, ConversationProcessorConfigDTO from src.utils.config import SearchType, SearchModels, ProcessorConfigModel, ConversationProcessorConfigModel
from src.utils.rawconfig import FullConfigModel from src.utils.rawconfig import FullConfig
from src.processor.conversation.gpt import converse, message_to_log, message_to_prompt, understand from src.processor.conversation.gpt import converse, message_to_log, message_to_prompt, understand
# Application Global State # Application Global State
model = SearchModels() model = SearchModels()
processor_config = ProcessorConfig() processor_config = ProcessorConfigModel()
config = {} config = {}
config_file = "" config_file = ""
verbose = 0 verbose = 0
@ -32,12 +32,12 @@ templates = Jinja2Templates(directory="views/")
def ui(request: Request): def ui(request: Request):
return templates.TemplateResponse("config.html", context={'request': request}) return templates.TemplateResponse("config.html", context={'request': request})
@app.get('/config', response_model=FullConfigModel) @app.get('/config', response_model=FullConfig)
def config(): def config():
return config return config
@app.post('/config') @app.post('/config')
async def config(updated_config: FullConfigModel): async def config(updated_config: FullConfig):
global config global config
config = updated_config config = updated_config
with open(config_file, 'w') as outfile: with open(config_file, 'w') as outfile:
@ -92,7 +92,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
@app.get('/regenerate') @app.get('/regenerate')
def regenerate(t: Optional[SearchType] = None): def regenerate(t: Optional[SearchType] = None):
initialize_search(regenerate=True) initialize_search(regenerate=True, t=t)
return {'status': 'ok', 'message': 'regeneration completed'} return {'status': 'ok', 'message': 'regeneration completed'}
@ -113,7 +113,7 @@ def chat(q: str):
return {'status': 'ok', 'response': gpt_response} return {'status': 'ok', 'response': gpt_response}
def initialize_search(regenerate: bool, t: SearchType = None): def initialize_search(config: FullConfig, regenerate: bool, t: SearchType = None):
model = SearchModels() model = SearchModels()
# Initialize Org Notes Search # Initialize Org Notes Search
@ -139,14 +139,14 @@ def initialize_search(regenerate: bool, t: SearchType = None):
return model return model
def initialize_processor(): def initialize_processor(config: FullConfig):
if not config.processor: if not config.processor:
return return
processor_config = ProcessorConfig() processor_config = ProcessorConfigModel()
# Initialize Conversation Processor # Initialize Conversation Processor
processor_config.conversation = ConversationProcessorConfigDTO(config.processor.conversation, verbose) processor_config.conversation = ConversationProcessorConfigModel(config.processor.conversation, verbose)
conversation_logfile = processor_config.conversation.conversation_logfile conversation_logfile = processor_config.conversation.conversation_logfile
if processor_config.conversation.verbose: if processor_config.conversation.verbose:
@ -202,10 +202,10 @@ if __name__ == '__main__':
config = args.config config = args.config
# Initialize the search model from Config # Initialize the search model from Config
model = initialize_search(args.regenerate) model = initialize_search(args.config, args.regenerate)
# Initialize Processor from Config # Initialize Processor from Config
processor_config = initialize_processor() processor_config = initialize_processor(args.config)
# Start Application Server # Start Application Server
if args.socket: if args.socket:

View file

@ -15,7 +15,7 @@ from sentence_transformers import SentenceTransformer, CrossEncoder, util
from src.utils.helpers import get_absolute_path, resolve_absolute_path from src.utils.helpers import get_absolute_path, resolve_absolute_path
from src.processor.org_mode.org_to_jsonl import org_to_jsonl from src.processor.org_mode.org_to_jsonl import org_to_jsonl
from src.utils.config import TextSearchModel from src.utils.config import TextSearchModel
from src.utils.rawconfig import TextSearchConfigModel from src.utils.rawconfig import TextSearchConfig
def initialize_model(): def initialize_model():
@ -149,7 +149,7 @@ def collate_results(hits, entries, count=5):
in hits[0:count]] in hits[0:count]]
def setup(config: TextSearchConfigModel, regenerate: bool, verbose: bool) -> TextSearchModel: def setup(config: TextSearchConfig, regenerate: bool, verbose: bool) -> TextSearchModel:
# Initialize Model # Initialize Model
bi_encoder, cross_encoder, top_k = initialize_model() bi_encoder, cross_encoder, top_k = initialize_model()

View file

@ -13,7 +13,7 @@ import torch
from src.utils.helpers import resolve_absolute_path from src.utils.helpers import resolve_absolute_path
import src.utils.exiftool as exiftool import src.utils.exiftool as exiftool
from src.utils.config import ImageSearchModel from src.utils.config import ImageSearchModel
from src.utils.rawconfig import ImageSearchConfigModel from src.utils.rawconfig import ImageSearchConfig
def initialize_model(): def initialize_model():
@ -154,7 +154,7 @@ def collate_results(hits, image_names, image_directory, count=5):
in hits[0:count]] in hits[0:count]]
def setup(config: ImageSearchConfigModel, regenerate: bool, verbose: bool) -> ImageSearchModel: def setup(config: ImageSearchConfig, regenerate: bool, verbose: bool) -> ImageSearchModel:
# Initialize Model # Initialize Model
encoder = initialize_model() encoder = initialize_model()

View file

@ -13,7 +13,7 @@ from sentence_transformers import SentenceTransformer, CrossEncoder, util
from src.utils.helpers import get_absolute_path, resolve_absolute_path from src.utils.helpers import get_absolute_path, resolve_absolute_path
from src.processor.ledger.beancount_to_jsonl import beancount_to_jsonl from src.processor.ledger.beancount_to_jsonl import beancount_to_jsonl
from src.utils.config import TextSearchModel from src.utils.config import TextSearchModel
from src.utils.rawconfig import TextSearchConfigModel from src.utils.rawconfig import TextSearchConfig
def initialize_model(): def initialize_model():
@ -141,7 +141,7 @@ def collate_results(hits, entries, count=5):
in hits[0:count]] in hits[0:count]]
def setup(config: TextSearchConfigModel, regenerate: bool, verbose: bool) -> TextSearchModel: def setup(config: TextSearchConfig, regenerate: bool, verbose: bool) -> TextSearchModel:
# Initialize Model # Initialize Model
bi_encoder, cross_encoder, top_k = initialize_model() bi_encoder, cross_encoder, top_k = initialize_model()

View file

@ -8,7 +8,7 @@ import yaml
# Internal Packages # Internal Packages
from src.utils.helpers import is_none_or_empty, get_absolute_path, resolve_absolute_path, merge_dicts from src.utils.helpers import is_none_or_empty, get_absolute_path, resolve_absolute_path, merge_dicts
from src.utils.rawconfig import FullConfigModel from src.utils.rawconfig import FullConfig
def cli(args=None): def cli(args=None):
if is_none_or_empty(args): if is_none_or_empty(args):
@ -37,9 +37,9 @@ def cli(args=None):
with open(get_absolute_path(args.config_file), 'r', encoding='utf-8') as config_file: with open(get_absolute_path(args.config_file), 'r', encoding='utf-8') as config_file:
config_from_file = yaml.safe_load(config_file) config_from_file = yaml.safe_load(config_file)
args.config = merge_dicts(priority_dict=config_from_file, default_dict=args.config) args.config = merge_dicts(priority_dict=config_from_file, default_dict=args.config)
args.config = FullConfigModel.parse_obj(args.config) args.config = FullConfig.parse_obj(args.config)
else: else:
args.config = FullConfigModel.parse_obj(args.config) args.config = FullConfig.parse_obj(args.config)
if args.org_files: if args.org_files:
args.config.content_type.org.input_files = args.org_files args.config.content_type.org.input_files = args.org_files

View file

@ -4,7 +4,7 @@ from dataclasses import dataclass
from pathlib import Path from pathlib import Path
# Internal Packages # Internal Packages
from src.utils.rawconfig import ProcessorConversationConfigModel from src.utils.rawconfig import ConversationProcessorConfig
class SearchType(str, Enum): class SearchType(str, Enum):
@ -42,8 +42,8 @@ class SearchModels():
image_search: ImageSearchModel = None image_search: ImageSearchModel = None
class ConversationProcessorConfigDTO(): class ConversationProcessorConfigModel():
def __init__(self, processor_config: ProcessorConversationConfigModel, verbose: bool): def __init__(self, processor_config: ConversationProcessorConfig, verbose: bool):
self.openai_api_key = processor_config.open_api_key self.openai_api_key = processor_config.open_api_key
self.conversation_logfile = Path(processor_config.conversation_logfile) self.conversation_logfile = Path(processor_config.conversation_logfile)
self.chat_log = '' self.chat_log = ''
@ -52,5 +52,5 @@ class ConversationProcessorConfigDTO():
self.verbose = verbose self.verbose = verbose
@dataclass @dataclass
class ProcessorConfig(): class ProcessorConfigModel():
conversation: ConversationProcessorConfigDTO = None conversation: ConversationProcessorConfigModel = None

View file

@ -8,55 +8,55 @@ from pydantic import BaseModel
# Internal Packages # Internal Packages
from src.utils.helpers import to_snake_case_from_dash from src.utils.helpers import to_snake_case_from_dash
class ConfigBaseModel(BaseModel): class ConfigBase(BaseModel):
class Config: class Config:
alias_generator = to_snake_case_from_dash alias_generator = to_snake_case_from_dash
allow_population_by_field_name = True allow_population_by_field_name = True
class SearchConfigModel(ConfigBaseModel): class SearchConfig(ConfigBase):
input_files: Optional[List[str]] input_files: Optional[List[str]]
input_filter: Optional[str] input_filter: Optional[str]
embeddings_file: Optional[Path] embeddings_file: Optional[Path]
class TextSearchConfigModel(ConfigBaseModel): class TextSearchConfig(ConfigBase):
compressed_jsonl: Optional[Path] compressed_jsonl: Optional[Path]
input_files: Optional[List[str]] input_files: Optional[List[str]]
input_filter: Optional[str] input_filter: Optional[str]
embeddings_file: Optional[Path] embeddings_file: Optional[Path]
class ImageSearchConfigModel(ConfigBaseModel): class ImageSearchConfig(ConfigBase):
use_xmp_metadata: Optional[str] use_xmp_metadata: Optional[str]
batch_size: Optional[int] batch_size: Optional[int]
input_directory: Optional[Path] input_directory: Optional[Path]
input_filter: Optional[str] input_filter: Optional[str]
embeddings_file: Optional[Path] embeddings_file: Optional[Path]
class ContentTypeModel(ConfigBaseModel): class ContentTypeConfig(ConfigBase):
org: Optional[TextSearchConfigModel] org: Optional[TextSearchConfig]
ledger: Optional[TextSearchConfigModel] ledger: Optional[TextSearchConfig]
image: Optional[ImageSearchConfigModel] image: Optional[ImageSearchConfig]
music: Optional[TextSearchConfigModel] music: Optional[TextSearchConfig]
class AsymmetricConfigModel(ConfigBaseModel): class AsymmetricConfig(ConfigBase):
encoder: Optional[str] encoder: Optional[str]
cross_encoder: Optional[str] cross_encoder: Optional[str]
class ImageSearchTypeConfigModel(ConfigBaseModel): class ImageSearchTypeConfig(ConfigBase):
encoder: Optional[str] encoder: Optional[str]
class SearchTypeConfigModel(ConfigBaseModel): class SearchTypeConfig(ConfigBase):
asymmetric: Optional[AsymmetricConfigModel] asymmetric: Optional[AsymmetricConfig]
image: Optional[ImageSearchTypeConfigModel] image: Optional[ImageSearchTypeConfig]
class ProcessorConversationConfigModel(ConfigBaseModel): class ConversationProcessorConfig(ConfigBase):
open_api_key: Optional[str] open_api_key: Optional[str]
conversation_logfile: Optional[str] conversation_logfile: Optional[str]
conversation_history: Optional[str] conversation_history: Optional[str]
class ProcessorConfigModel(ConfigBaseModel): class ProcessorConfigModel(ConfigBase):
conversation: Optional[ProcessorConversationConfigModel] conversation: Optional[ConversationProcessorConfig]
class FullConfigModel(ConfigBaseModel): class FullConfig(ConfigBase):
content_type: Optional[ContentTypeModel] content_type: Optional[ContentTypeConfig]
search_type: Optional[SearchTypeConfigModel] search_type: Optional[SearchTypeConfig]
processor: Optional[ProcessorConfigModel] processor: Optional[ProcessorConfigModel]

View file

@ -4,7 +4,7 @@ from pathlib import Path
# Internal Packages # Internal Packages
from src.search_type import asymmetric, image_search from src.search_type import asymmetric, image_search
from src.utils.rawconfig import ContentTypeModel, ImageSearchConfigModel, TextSearchConfigModel from src.utils.rawconfig import ContentTypeConfig, ImageSearchConfig, TextSearchConfig
@pytest.fixture(scope='session') @pytest.fixture(scope='session')
@ -12,8 +12,8 @@ def model_dir(tmp_path_factory):
model_dir = tmp_path_factory.mktemp('data') model_dir = tmp_path_factory.mktemp('data')
# Generate Image Embeddings from Test Images # Generate Image Embeddings from Test Images
search_config = ContentTypeModel() search_config = ContentTypeConfig()
search_config.image = ImageSearchConfigModel( search_config.image = ImageSearchConfig(
input_directory = Path('tests/data'), input_directory = Path('tests/data'),
embeddings_file = model_dir.joinpath('.image_embeddings.pt'), embeddings_file = model_dir.joinpath('.image_embeddings.pt'),
batch_size = 10, batch_size = 10,
@ -22,7 +22,7 @@ def model_dir(tmp_path_factory):
image_search.setup(search_config.image, regenerate=False, verbose=True) image_search.setup(search_config.image, regenerate=False, verbose=True)
# Generate Notes Embeddings from Test Notes # Generate Notes Embeddings from Test Notes
search_config.org = TextSearchConfigModel( search_config.org = TextSearchConfig(
input_files = [Path('tests/data/main_readme.org'), Path('tests/data/interface_emacs_readme.org')], input_files = [Path('tests/data/main_readme.org'), Path('tests/data/interface_emacs_readme.org')],
input_filter = None, input_filter = None,
compressed_jsonl = model_dir.joinpath('.notes.jsonl.gz'), compressed_jsonl = model_dir.joinpath('.notes.jsonl.gz'),
@ -35,14 +35,14 @@ def model_dir(tmp_path_factory):
@pytest.fixture(scope='session') @pytest.fixture(scope='session')
def search_config(model_dir): def search_config(model_dir):
search_config = ContentTypeModel() search_config = ContentTypeConfig()
search_config.org = TextSearchConfigModel( search_config.org = TextSearchConfig(
input_files = [Path('tests/data/main_readme.org'), Path('tests/data/interface_emacs_readme.org')], input_files = [Path('tests/data/main_readme.org'), Path('tests/data/interface_emacs_readme.org')],
input_filter = None, input_filter = None,
compressed_jsonl = model_dir.joinpath('.notes.jsonl.gz'), compressed_jsonl = model_dir.joinpath('.notes.jsonl.gz'),
embeddings_file = model_dir.joinpath('.note_embeddings.pt')) embeddings_file = model_dir.joinpath('.note_embeddings.pt'))
search_config.image = ImageSearchConfigModel( search_config.image = ImageSearchConfig(
input_directory = Path('tests/data'), input_directory = Path('tests/data'),
embeddings_file = Path('tests/data/.image_embeddings.pt'), embeddings_file = Path('tests/data/.image_embeddings.pt'),
batch_size = 10, batch_size = 10,

View file

@ -8,13 +8,13 @@ from fastapi.testclient import TestClient
from src.main import app, model, config from src.main import app, model, config
from src.search_type import asymmetric, image_search from src.search_type import asymmetric, image_search
from src.utils.helpers import resolve_absolute_path from src.utils.helpers import resolve_absolute_path
from src.utils.rawconfig import FullConfigModel from src.utils.rawconfig import FullConfig
# Arrange # Arrange
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
client = TestClient(app) client = TestClient(app)
config = FullConfigModel() config = FullConfig()
# Test # Test
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------