From 3328a41f08b635ed1f2742de4059f8339f95ebba Mon Sep 17 00:00:00 2001 From: sabaimran Date: Fri, 17 Nov 2023 23:08:52 -0800 Subject: [PATCH 1/3] Update types of base config models for pydantic 2.0 --- src/khoj/routers/indexer.py | 7 ++++-- src/khoj/search_type/image_search.py | 4 ++-- src/khoj/utils/rawconfig.py | 36 ++++++++++++++-------------- 3 files changed, 25 insertions(+), 22 deletions(-) diff --git a/src/khoj/routers/indexer.py b/src/khoj/routers/indexer.py index a7a1249d..4f9ddb30 100644 --- a/src/khoj/routers/indexer.py +++ b/src/khoj/routers/indexer.py @@ -63,7 +63,7 @@ async def update( request: Request, files: list[UploadFile], force: bool = False, - t: Optional[Union[state.SearchType, str]] = None, + t: Optional[Union[state.SearchType, str]] = state.SearchType.All, client: Optional[str] = None, user_agent: Optional[str] = Header(None), referer: Optional[str] = Header(None), @@ -182,13 +182,16 @@ def configure_content( files: Optional[dict[str, dict[str, str]]], search_models: SearchModels, regenerate: bool = False, - t: Optional[state.SearchType] = None, + t: Optional[state.SearchType] = state.SearchType.All, full_corpus: bool = True, user: KhojUser = None, ) -> tuple[Optional[ContentIndex], bool]: content_index = ContentIndex() success = True + if t is not None and t in [type.value for type in state.SearchType]: + t = state.SearchType(t) + if t is not None and not t.value in [type.value for type in state.SearchType]: logger.warning(f"🚨 Invalid search type: {t}") return None, False diff --git a/src/khoj/search_type/image_search.py b/src/khoj/search_type/image_search.py index 214118fc..8c0a3cdb 100644 --- a/src/khoj/search_type/image_search.py +++ b/src/khoj/search_type/image_search.py @@ -229,7 +229,7 @@ def collate_results(hits, image_names, output_directory, image_files_url, count= # Add the image metadata to the results results += [ - SearchResponse.parse_obj( + SearchResponse.model_validate( { "entry": f"{image_files_url}/{target_image_name}", "score": f"{hit['score']:.9f}", @@ -237,7 +237,7 @@ def collate_results(hits, image_names, output_directory, image_files_url, count= "image_score": f"{hit['image_score']:.9f}", "metadata_score": f"{hit['metadata_score']:.9f}", }, - "corpus_id": hit["corpus_id"], + "corpus_id": str(hit["corpus_id"]), } ) ] diff --git a/src/khoj/utils/rawconfig.py b/src/khoj/utils/rawconfig.py index 67016bf7..4c97aedd 100644 --- a/src/khoj/utils/rawconfig.py +++ b/src/khoj/utils/rawconfig.py @@ -14,7 +14,7 @@ from khoj.utils.helpers import to_snake_case_from_dash class ConfigBase(BaseModel): class Config: alias_generator = to_snake_case_from_dash - allow_population_by_field_name = True + populate_by_name = True def __getitem__(self, item): return getattr(self, item) @@ -29,8 +29,8 @@ class TextConfigBase(ConfigBase): class TextContentConfig(ConfigBase): - input_files: Optional[List[Path]] - input_filter: Optional[List[str]] + input_files: Optional[List[Path]] = None + input_filter: Optional[List[str]] = None index_heading_entries: Optional[bool] = False @@ -50,31 +50,31 @@ class NotionContentConfig(ConfigBase): class ImageContentConfig(ConfigBase): - input_directories: Optional[List[Path]] - input_filter: Optional[List[str]] + input_directories: Optional[List[Path]] = None + input_filter: Optional[List[str]] = None embeddings_file: Path use_xmp_metadata: bool batch_size: int class ContentConfig(ConfigBase): - org: Optional[TextContentConfig] - image: Optional[ImageContentConfig] - markdown: Optional[TextContentConfig] - pdf: Optional[TextContentConfig] - plaintext: Optional[TextContentConfig] - github: Optional[GithubContentConfig] - notion: Optional[NotionContentConfig] + org: Optional[TextContentConfig] = None + image: Optional[ImageContentConfig] = None + markdown: Optional[TextContentConfig] = None + pdf: Optional[TextContentConfig] = None + plaintext: Optional[TextContentConfig] = None + github: Optional[GithubContentConfig] = None + notion: Optional[NotionContentConfig] = None class ImageSearchConfig(ConfigBase): encoder: str - encoder_type: Optional[str] - model_directory: Optional[Path] + encoder_type: Optional[str] = None + model_directory: Optional[Path] = None class SearchConfig(ConfigBase): - image: Optional[ImageSearchConfig] + image: Optional[ImageSearchConfig] = None class OpenAIProcessorConfig(ConfigBase): @@ -95,7 +95,7 @@ class ConversationProcessorConfig(ConfigBase): class ProcessorConfig(ConfigBase): - conversation: Optional[ConversationProcessorConfig] + conversation: Optional[ConversationProcessorConfig] = None class AppConfig(ConfigBase): @@ -113,8 +113,8 @@ class FullConfig(ConfigBase): class SearchResponse(ConfigBase): entry: str score: float - cross_score: Optional[float] - additional: Optional[dict] + cross_score: Optional[float] = None + additional: Optional[dict] = None corpus_id: str From f180b2ba94eafe6e6eed8b586f54c873430253b2 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Fri, 17 Nov 2023 23:26:15 -0800 Subject: [PATCH 2/3] Resolve mypy errors for various data types --- src/khoj/processor/conversation/utils.py | 11 +++++++---- src/khoj/routers/api.py | 2 +- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index b0d401fa..ecd4f8ad 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -151,17 +151,20 @@ def truncate_messages( ) system_message = messages.pop() + assert type(system_message.content) == str system_message_tokens = len(encoder.encode(system_message.content)) - tokens = sum([len(encoder.encode(message.content)) for message in messages]) + tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str]) while (tokens + system_message_tokens) > max_prompt_size and len(messages) > 1: messages.pop() - tokens = sum([len(encoder.encode(message.content)) for message in messages]) + assert type(system_message.content) == str + tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str]) # Truncate current message if still over max supported prompt size by model if (tokens + system_message_tokens) > max_prompt_size: - current_message = "\n".join(messages[0].content.split("\n")[:-1]) - original_question = "\n".join(messages[0].content.split("\n")[-1:]) + assert type(system_message.content) == str + current_message = "\n".join(messages[0].content.split("\n")[:-1]) if type(messages[0].content) == str else "" + original_question = "\n".join(messages[0].content.split("\n")[-1:]) if type(messages[0].content) == str else "" original_question_tokens = len(encoder.encode(original_question)) remaining_tokens = max_prompt_size - original_question_tokens - system_message_tokens truncated_message = encoder.decode(encoder.encode(current_message)[:remaining_tokens]).strip() diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index be2643bd..b384d8a3 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -296,7 +296,7 @@ async def get_all_filenames( client=client, ) - return await sync_to_async(list)(EntryAdapters.aget_all_filenames_by_source(user, content_source)) + return await sync_to_async(list)(EntryAdapters.aget_all_filenames_by_source(user, content_source)) # type: ignore[call-arg] @api.post("/config/data/conversation/model", status_code=200) From 6d249645a6dadbf7a8a7acb09d94c20063a7de17 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Sat, 18 Nov 2023 00:04:18 -0800 Subject: [PATCH 3/3] Fix interpretation of the default search type --- src/khoj/routers/indexer.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/src/khoj/routers/indexer.py b/src/khoj/routers/indexer.py index 4f9ddb30..ccb65063 100644 --- a/src/khoj/routers/indexer.py +++ b/src/khoj/routers/indexer.py @@ -204,7 +204,7 @@ def configure_content( try: # Initialize Org Notes Search - if (search_type == None or search_type == state.SearchType.Org.value) and files["org"]: + if (search_type == state.SearchType.All.value or search_type == state.SearchType.Org.value) and files["org"]: logger.info("🦄 Setting up search for orgmode notes") # Extract Entries, Generate Notes Embeddings text_search.setup( @@ -220,7 +220,9 @@ def configure_content( try: # Initialize Markdown Search - if (search_type == None or search_type == state.SearchType.Markdown.value) and files["markdown"]: + if (search_type == state.SearchType.All.value or search_type == state.SearchType.Markdown.value) and files[ + "markdown" + ]: logger.info("💎 Setting up search for markdown notes") # Extract Entries, Generate Markdown Embeddings text_search.setup( @@ -237,7 +239,7 @@ def configure_content( try: # Initialize PDF Search - if (search_type == None or search_type == state.SearchType.Pdf.value) and files["pdf"]: + if (search_type == state.SearchType.All.value or search_type == state.SearchType.Pdf.value) and files["pdf"]: logger.info("🖨️ Setting up search for pdf") # Extract Entries, Generate PDF Embeddings text_search.setup( @@ -254,7 +256,9 @@ def configure_content( try: # Initialize Plaintext Search - if (search_type == None or search_type == state.SearchType.Plaintext.value) and files["plaintext"]: + if (search_type == state.SearchType.All.value or search_type == state.SearchType.Plaintext.value) and files[ + "plaintext" + ]: logger.info("📄 Setting up search for plaintext") # Extract Entries, Generate Plaintext Embeddings text_search.setup( @@ -272,7 +276,7 @@ def configure_content( try: # Initialize Image Search if ( - (search_type == None or search_type == state.SearchType.Image.value) + (search_type == state.SearchType.All.value or search_type == state.SearchType.Image.value) and content_config and content_config.image and search_models.image_search @@ -289,7 +293,9 @@ def configure_content( try: github_config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first() - if (search_type == None or search_type == state.SearchType.Github.value) and github_config is not None: + if ( + search_type == state.SearchType.All.value or search_type == state.SearchType.Github.value + ) and github_config is not None: logger.info("🐙 Setting up search for github") # Extract Entries, Generate Github Embeddings text_search.setup( @@ -308,7 +314,9 @@ def configure_content( try: # Initialize Notion Search notion_config = NotionConfig.objects.filter(user=user).first() - if (search_type == None or search_type in state.SearchType.Notion.value) and notion_config: + if ( + search_type == state.SearchType.All.value or search_type in state.SearchType.Notion.value + ) and notion_config: logger.info("🔌 Setting up search for notion") text_search.setup( NotionToEntries,