From 179153dc5ad650be33f6162a17267132f76c98fb Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Fri, 14 Jan 2022 20:54:38 -0500 Subject: [PATCH] Rename RawConfig Types for Consistency - Naming convention - [ContentType][ConfigType]Config - Where [ConfigType] ~ Content, Search, Processor - Where [ContentType] ~ Text, Image, Asymmetric, Symmetric, Conversation - Current Configs: - Content: - Org Notes - Org Music - Image - Ledger/Beancount - Search: - Asymmetric - Symmetric - Image - Processor: - Conversation --- src/search_type/asymmetric.py | 6 +-- src/search_type/image_search.py | 6 +-- src/search_type/symmetric_ledger.py | 6 +-- src/utils/rawconfig.py | 63 +++++++++++++---------------- tests/conftest.py | 24 +++++------ tests/test_asymmetric_search.py | 6 +-- tests/test_client.py | 14 +++---- tests/test_image_search.py | 6 +-- 8 files changed, 62 insertions(+), 69 deletions(-) diff --git a/src/search_type/asymmetric.py b/src/search_type/asymmetric.py index 670cc39e..2c06176d 100644 --- a/src/search_type/asymmetric.py +++ b/src/search_type/asymmetric.py @@ -15,10 +15,10 @@ from sentence_transformers import SentenceTransformer, CrossEncoder, util from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_model from src.processor.org_mode.org_to_jsonl import org_to_jsonl from src.utils.config import TextSearchModel -from src.utils.rawconfig import AsymmetricConfig, TextSearchConfig +from src.utils.rawconfig import AsymmetricSearchConfig, TextContentConfig -def initialize_model(search_config: AsymmetricConfig): +def initialize_model(search_config: AsymmetricSearchConfig): "Initialize model for assymetric semantic search. That is, where query smaller than results" torch.set_num_threads(4) @@ -162,7 +162,7 @@ def collate_results(hits, entries, count=5): in hits[0:count]] -def setup(config: TextSearchConfig, search_config: AsymmetricConfig, regenerate: bool, verbose: bool=False) -> TextSearchModel: +def setup(config: TextContentConfig, search_config: AsymmetricSearchConfig, regenerate: bool, verbose: bool=False) -> TextSearchModel: # Initialize Model bi_encoder, cross_encoder, top_k = initialize_model(search_config) diff --git a/src/search_type/image_search.py b/src/search_type/image_search.py index 9f4d04f2..77ac52f2 100644 --- a/src/search_type/image_search.py +++ b/src/search_type/image_search.py @@ -13,10 +13,10 @@ import torch from src.utils.helpers import resolve_absolute_path, load_model import src.utils.exiftool as exiftool from src.utils.config import ImageSearchModel -from src.utils.rawconfig import ImageSearchConfig, ImageSearchTypeConfig +from src.utils.rawconfig import ImageContentConfig, ImageSearchConfig -def initialize_model(search_config: ImageSearchTypeConfig): +def initialize_model(search_config: ImageSearchConfig): # Initialize Model torch.set_num_threads(4) @@ -160,7 +160,7 @@ def collate_results(hits, image_names, image_directory, count=5): in hits[0:count]] -def setup(config: ImageSearchConfig, search_config: ImageSearchTypeConfig, regenerate: bool, verbose: bool=False) -> ImageSearchModel: +def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenerate: bool, verbose: bool=False) -> ImageSearchModel: # Initialize Model encoder = initialize_model(search_config) diff --git a/src/search_type/symmetric_ledger.py b/src/search_type/symmetric_ledger.py index c7d9cc9a..f63a1c98 100644 --- a/src/search_type/symmetric_ledger.py +++ b/src/search_type/symmetric_ledger.py @@ -13,10 +13,10 @@ from sentence_transformers import SentenceTransformer, CrossEncoder, util from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_model from src.processor.ledger.beancount_to_jsonl import beancount_to_jsonl from src.utils.config import TextSearchModel -from src.utils.rawconfig import SymmetricConfig, TextSearchConfig +from src.utils.rawconfig import SymmetricSearchConfig, TextContentConfig -def initialize_model(search_config: SymmetricConfig): +def initialize_model(search_config: SymmetricSearchConfig): "Initialize model for symmetric semantic search. That is, where query of similar size to results" torch.set_num_threads(4) @@ -154,7 +154,7 @@ def collate_results(hits, entries, count=5): in hits[0:count]] -def setup(config: TextSearchConfig, search_config: SymmetricConfig, regenerate: bool, verbose: bool) -> TextSearchModel: +def setup(config: TextContentConfig, search_config: SymmetricSearchConfig, regenerate: bool, verbose: bool) -> TextSearchModel: # Initialize Model bi_encoder, cross_encoder, top_k = initialize_model(search_config) diff --git a/src/utils/rawconfig.py b/src/utils/rawconfig.py index bbb2de31..e87c22fd 100644 --- a/src/utils/rawconfig.py +++ b/src/utils/rawconfig.py @@ -13,57 +13,52 @@ class ConfigBase(BaseModel): alias_generator = to_snake_case_from_dash allow_population_by_field_name = True -class SearchConfig(ConfigBase): - input_files: Optional[List[str]] - input_filter: Optional[str] - embeddings_file: Optional[Path] - -class TextSearchConfig(ConfigBase): +class TextContentConfig(ConfigBase): compressed_jsonl: Optional[Path] input_files: Optional[List[str]] input_filter: Optional[str] embeddings_file: Optional[Path] -class ImageSearchConfig(ConfigBase): +class ImageContentConfig(ConfigBase): use_xmp_metadata: Optional[str] batch_size: Optional[int] input_directory: Optional[Path] input_filter: Optional[str] embeddings_file: Optional[Path] -class ContentTypeConfig(ConfigBase): - org: Optional[TextSearchConfig] - ledger: Optional[TextSearchConfig] +class ContentConfig(ConfigBase): + org: Optional[TextContentConfig] + ledger: Optional[TextContentConfig] + image: Optional[ImageContentConfig] + music: Optional[TextContentConfig] + +class SymmetricSearchConfig(ConfigBase): + encoder: Optional[str] + cross_encoder: Optional[str] + model_directory: Optional[Path] + +class AsymmetricSearchConfig(ConfigBase): + encoder: Optional[str] + cross_encoder: Optional[str] + model_directory: Optional[Path] + +class ImageSearchConfig(ConfigBase): + encoder: Optional[str] + model_directory: Optional[Path] + +class SearchConfig(ConfigBase): + asymmetric: Optional[AsymmetricSearchConfig] + symmetric: Optional[SymmetricSearchConfig] image: Optional[ImageSearchConfig] - music: Optional[TextSearchConfig] - -class SymmetricConfig(ConfigBase): - encoder: Optional[str] - cross_encoder: Optional[str] - model_directory: Optional[Path] - -class AsymmetricConfig(ConfigBase): - encoder: Optional[str] - cross_encoder: Optional[str] - model_directory: Optional[Path] - -class ImageSearchTypeConfig(ConfigBase): - encoder: Optional[str] - model_directory: Optional[Path] - -class SearchTypeConfig(ConfigBase): - asymmetric: Optional[AsymmetricConfig] - symmetric: Optional[SymmetricConfig] - image: Optional[ImageSearchTypeConfig] class ConversationProcessorConfig(ConfigBase): openai_api_key: Optional[str] conversation_logfile: Optional[str] -class ProcessorConfigModel(ConfigBase): +class ProcessorConfig(ConfigBase): conversation: Optional[ConversationProcessorConfig] class FullConfig(ConfigBase): - content_type: Optional[ContentTypeConfig] - search_type: Optional[SearchTypeConfig] - processor: Optional[ProcessorConfigModel] + content_type: Optional[ContentConfig] + search_type: Optional[SearchConfig] + processor: Optional[ProcessorConfig] diff --git a/tests/conftest.py b/tests/conftest.py index bc71068c..45eee757 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,32 +1,30 @@ # Standard Packages import pytest -from pathlib import Path -from src import search_type # Internal Packages from src.search_type import asymmetric, image_search -from src.utils.rawconfig import AsymmetricConfig, ContentTypeConfig, ImageSearchConfig, ImageSearchTypeConfig, SearchTypeConfig, SymmetricConfig, TextSearchConfig +from src.utils.rawconfig import ContentConfig, TextContentConfig, ImageContentConfig, SearchConfig, SymmetricSearchConfig, AsymmetricSearchConfig, ImageSearchConfig @pytest.fixture(scope='session') def search_config(tmp_path_factory): model_dir = tmp_path_factory.mktemp('data') - search_config = SearchTypeConfig() + search_config = SearchConfig() - search_config.asymmetric = SymmetricConfig( + search_config.asymmetric = SymmetricSearchConfig( encoder = "sentence-transformers/paraphrase-MiniLM-L6-v2", cross_encoder = "cross-encoder/ms-marco-MiniLM-L-6-v2", model_directory = model_dir ) - search_config.asymmetric = AsymmetricConfig( + search_config.asymmetric = AsymmetricSearchConfig( encoder = "sentence-transformers/msmarco-MiniLM-L-6-v3", cross_encoder = "cross-encoder/ms-marco-MiniLM-L-6-v2", model_directory = model_dir ) - search_config.image = ImageSearchTypeConfig( + search_config.image = ImageSearchConfig( encoder = "clip-ViT-B-32", model_directory = model_dir ) @@ -39,8 +37,8 @@ def model_dir(search_config): model_dir = search_config.asymmetric.model_directory # Generate Image Embeddings from Test Images - content_config = ContentTypeConfig() - content_config.image = ImageSearchConfig( + content_config = ContentConfig() + content_config.image = ImageContentConfig( input_directory = 'tests/data', embeddings_file = model_dir.joinpath('.image_embeddings.pt'), batch_size = 10, @@ -49,7 +47,7 @@ def model_dir(search_config): image_search.setup(content_config.image, search_config.image, regenerate=False, verbose=True) # Generate Notes Embeddings from Test Notes - content_config.org = TextSearchConfig( + content_config.org = TextContentConfig( input_files = ['tests/data/main_readme.org', 'tests/data/interface_emacs_readme.org'], input_filter = None, compressed_jsonl = model_dir.joinpath('.notes.jsonl.gz'), @@ -62,14 +60,14 @@ def model_dir(search_config): @pytest.fixture(scope='session') def content_config(model_dir): - content_config = ContentTypeConfig() - content_config.org = TextSearchConfig( + content_config = ContentConfig() + content_config.org = TextContentConfig( input_files = ['tests/data/main_readme.org', 'tests/data/interface_emacs_readme.org'], input_filter = None, compressed_jsonl = model_dir.joinpath('.notes.jsonl.gz'), embeddings_file = model_dir.joinpath('.note_embeddings.pt')) - content_config.image = ImageSearchConfig( + content_config.image = ImageContentConfig( input_directory = 'tests/data', embeddings_file = model_dir.joinpath('.image_embeddings.pt'), batch_size = 10, diff --git a/tests/test_asymmetric_search.py b/tests/test_asymmetric_search.py index 14666002..78b1df9c 100644 --- a/tests/test_asymmetric_search.py +++ b/tests/test_asymmetric_search.py @@ -1,12 +1,12 @@ # Internal Packages from src.main import model from src.search_type import asymmetric -from src.utils.rawconfig import ContentTypeConfig, SearchTypeConfig +from src.utils.rawconfig import ContentConfig, SearchConfig # Test # ---------------------------------------------------------------------------------------------------- -def test_asymmetric_setup(content_config: ContentTypeConfig, search_config: SearchTypeConfig): +def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchConfig): # Act # Regenerate notes embeddings during asymmetric setup notes_model = asymmetric.setup(content_config.org, search_config.asymmetric, regenerate=True) @@ -17,7 +17,7 @@ def test_asymmetric_setup(content_config: ContentTypeConfig, search_config: Sear # ---------------------------------------------------------------------------------------------------- -def test_asymmetric_search(content_config: ContentTypeConfig, search_config: SearchTypeConfig): +def test_asymmetric_search(content_config: ContentConfig, search_config: SearchConfig): # Arrange model.notes_search = asymmetric.setup(content_config.org, search_config.asymmetric, regenerate=False) query = "How to git install application?" diff --git a/tests/test_client.py b/tests/test_client.py index 9b695550..2f00a745 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -9,7 +9,7 @@ import pytest from src.main import app, model, config from src.search_type import asymmetric, image_search from src.utils.helpers import resolve_absolute_path -from src.utils.rawconfig import ContentTypeConfig, SearchTypeConfig +from src.utils.rawconfig import ContentConfig, SearchConfig # Arrange @@ -30,7 +30,7 @@ def test_search_with_invalid_content_type(): # ---------------------------------------------------------------------------------------------------- -def test_search_with_valid_content_type(content_config: ContentTypeConfig, search_config: SearchTypeConfig): +def test_search_with_valid_content_type(content_config: ContentConfig, search_config: SearchConfig): # Arrange config.content_type = content_config config.search_type = search_config @@ -53,7 +53,7 @@ def test_regenerate_with_invalid_content_type(): # ---------------------------------------------------------------------------------------------------- -def test_regenerate_with_valid_content_type(content_config: ContentTypeConfig, search_config: SearchTypeConfig): +def test_regenerate_with_valid_content_type(content_config: ContentConfig, search_config: SearchConfig): # Arrange config.content_type = content_config config.search_type = search_config @@ -67,7 +67,7 @@ def test_regenerate_with_valid_content_type(content_config: ContentTypeConfig, s # ---------------------------------------------------------------------------------------------------- @pytest.mark.skip(reason="Flaky test. Search doesn't always return expected image path.") -def test_image_search(content_config: ContentTypeConfig, search_config: SearchTypeConfig): +def test_image_search(content_config: ContentConfig, search_config: SearchConfig): # Arrange config.content_type = content_config config.search_type = search_config @@ -90,7 +90,7 @@ def test_image_search(content_config: ContentTypeConfig, search_config: SearchTy # ---------------------------------------------------------------------------------------------------- -def test_notes_search(content_config: ContentTypeConfig, search_config: SearchTypeConfig): +def test_notes_search(content_config: ContentConfig, search_config: SearchConfig): # Arrange model.notes_search = asymmetric.setup(content_config.org, search_config.asymmetric, regenerate=False) user_query = "How to git install application?" @@ -106,7 +106,7 @@ def test_notes_search(content_config: ContentTypeConfig, search_config: SearchTy # ---------------------------------------------------------------------------------------------------- -def test_notes_search_with_include_filter(content_config: ContentTypeConfig, search_config: SearchTypeConfig): +def test_notes_search_with_include_filter(content_config: ContentConfig, search_config: SearchConfig): # Arrange model.notes_search = asymmetric.setup(content_config.org, search_config.asymmetric, regenerate=False) user_query = "How to git install application? +Emacs" @@ -122,7 +122,7 @@ def test_notes_search_with_include_filter(content_config: ContentTypeConfig, sea # ---------------------------------------------------------------------------------------------------- -def test_notes_search_with_exclude_filter(content_config: ContentTypeConfig, search_config: SearchTypeConfig): +def test_notes_search_with_exclude_filter(content_config: ContentConfig, search_config: SearchConfig): # Arrange model.notes_search = asymmetric.setup(content_config.org, search_config.asymmetric, regenerate=False) user_query = "How to git install application? -clone" diff --git a/tests/test_image_search.py b/tests/test_image_search.py index 254bcb7f..bb01ebe4 100644 --- a/tests/test_image_search.py +++ b/tests/test_image_search.py @@ -5,12 +5,12 @@ import pytest from src.main import model from src.search_type import image_search from src.utils.helpers import resolve_absolute_path -from src.utils.rawconfig import ContentTypeConfig, SearchTypeConfig +from src.utils.rawconfig import ContentConfig, SearchConfig # Test # ---------------------------------------------------------------------------------------------------- -def test_image_search_setup(content_config: ContentTypeConfig, search_config: SearchTypeConfig): +def test_image_search_setup(content_config: ContentConfig, search_config: SearchConfig): # Act # Regenerate image search embeddings during image setup image_search_model = image_search.setup(content_config.image, search_config.image, regenerate=True) @@ -22,7 +22,7 @@ def test_image_search_setup(content_config: ContentTypeConfig, search_config: Se # ---------------------------------------------------------------------------------------------------- @pytest.mark.skip(reason="results inconsistent currently") -def test_image_search(content_config: ContentTypeConfig, search_config: SearchTypeConfig): +def test_image_search(content_config: ContentConfig, search_config: SearchConfig): # Arrange model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False) query_expected_image_pairs = [("brown kitten next to plant", "kitten_park.jpg"),