From ff5c10c221fa7e668d404293715c55388eb06c00 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Mon, 4 Nov 2024 19:45:28 -0800 Subject: [PATCH] Do not CRUD on entries, files & conversations in DB for null user Increase defense-in-depth by reducing paths to create, read, update or delete entries, files and conversations in DB when user is unset. --- src/khoj/configure.py | 10 ++- src/khoj/database/adapters/__init__.py | 64 ++++++++++++++++--- .../processor/content/docx/docx_to_entries.py | 4 +- .../content/github/github_to_entries.py | 4 +- .../content/images/image_to_entries.py | 4 +- .../content/markdown/markdown_to_entries.py | 4 +- .../content/notion/notion_to_entries.py | 4 +- .../content/org_mode/org_to_entries.py | 4 +- .../processor/content/pdf/pdf_to_entries.py | 4 +- .../content/plaintext/plaintext_to_entries.py | 4 +- src/khoj/processor/content/text_to_entries.py | 4 +- src/khoj/routers/api.py | 2 +- src/khoj/routers/api_content.py | 4 +- src/khoj/routers/helpers.py | 4 +- src/khoj/routers/notion.py | 2 +- src/khoj/search_type/text_search.py | 2 +- src/khoj/utils/fs_syncer.py | 3 +- tests/conftest.py | 6 +- tests/test_client.py | 6 +- 19 files changed, 92 insertions(+), 47 deletions(-) diff --git a/src/khoj/configure.py b/src/khoj/configure.py index a1f4a7db..002413b8 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -253,7 +253,7 @@ def configure_server( logger.info(message) if not init: - initialize_content(regenerate, search_type, user) + initialize_content(user, regenerate, search_type) except Exception as e: logger.error(f"Failed to load some search models: {e}", exc_info=True) @@ -263,17 +263,17 @@ def setup_default_agent(user: KhojUser): AgentAdapters.create_default_agent(user) -def initialize_content(regenerate: bool, search_type: Optional[SearchType] = None, user: KhojUser = None): +def initialize_content(user: KhojUser, regenerate: bool, search_type: Optional[SearchType] = None): # Initialize Content from Config if state.search_models: try: logger.info("📬 Updating content index...") all_files = collect_files(user=user) status = configure_content( + user, all_files, regenerate, search_type, - user=user, ) if not status: raise RuntimeError("Failed to update content index") @@ -338,9 +338,7 @@ def configure_middleware(app): def update_content_index(): for user in get_all_users(): all_files = collect_files(user=user) - success = configure_content(all_files, user=user) - all_files = collect_files(user=None) - success = configure_content(all_files, user=None) + success = configure_content(user, all_files) if not success: raise RuntimeError("Failed to update content index") logger.info("📪 Content index updated via Scheduler") diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 6676eefa..ad149d58 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -416,6 +416,11 @@ def get_all_users() -> BaseManager[KhojUser]: return KhojUser.objects.all() +def check_valid_user(user: KhojUser | None): + if not user: + raise ValueError("User not found") + + def get_user_github_config(user: KhojUser): config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first() return config @@ -815,6 +820,7 @@ class ConversationAdapters: def get_conversation_by_user( user: KhojUser, client_application: ClientApplication = None, conversation_id: str = None ) -> Optional[Conversation]: + check_valid_user(user) if conversation_id: conversation = ( Conversation.objects.filter(user=user, client=client_application, id=conversation_id) @@ -831,6 +837,7 @@ class ConversationAdapters: @staticmethod def get_conversation_sessions(user: KhojUser, client_application: ClientApplication = None): + check_valid_user(user) return ( Conversation.objects.filter(user=user, client=client_application) .prefetch_related("agent") @@ -841,6 +848,7 @@ class ConversationAdapters: async def aset_conversation_title( user: KhojUser, client_application: ClientApplication, conversation_id: str, title: str ): + check_valid_user(user) conversation = await Conversation.objects.filter( user=user, client=client_application, id=conversation_id ).afirst() @@ -858,6 +866,7 @@ class ConversationAdapters: async def acreate_conversation_session( user: KhojUser, client_application: ClientApplication = None, agent_slug: str = None, title: str = None ): + check_valid_user(user) if agent_slug: agent = await AgentAdapters.aget_readonly_agent_by_slug(agent_slug, user) if agent is None: @@ -874,6 +883,7 @@ class ConversationAdapters: def create_conversation_session( user: KhojUser, client_application: ClientApplication = None, agent_slug: str = None, title: str = None ): + check_valid_user(user) if agent_slug: agent = AgentAdapters.aget_readonly_agent_by_slug(agent_slug, user) if agent is None: @@ -890,6 +900,7 @@ class ConversationAdapters: title: str = None, create_new: bool = False, ) -> Optional[Conversation]: + check_valid_user(user) if create_new: return await ConversationAdapters.acreate_conversation_session(user, client_application) @@ -910,12 +921,14 @@ class ConversationAdapters: async def adelete_conversation_by_user( user: KhojUser, client_application: ClientApplication = None, conversation_id: str = None ): + check_valid_user(user) if conversation_id: return await Conversation.objects.filter(user=user, client=client_application, id=conversation_id).adelete() return await Conversation.objects.filter(user=user, client=client_application).adelete() @staticmethod def has_any_conversation_config(user: KhojUser): + check_valid_user(user) return ChatModelOptions.objects.filter(user=user).exists() @staticmethod @@ -953,7 +966,7 @@ class ConversationAdapters: @staticmethod async def aset_user_conversation_processor(user: KhojUser, conversation_processor_config_id: int): config = await ChatModelOptions.objects.filter(id=conversation_processor_config_id).afirst() - if not config: + if not config or user is None: return None new_config = await UserConversationConfig.objects.aupdate_or_create(user=user, defaults={"setting": config}) return new_config @@ -961,7 +974,7 @@ class ConversationAdapters: @staticmethod async def aset_user_voice_model(user: KhojUser, model_id: str): config = await VoiceModelOption.objects.filter(model_id=model_id).afirst() - if not config: + if not config or user is None: return None new_config = await UserVoiceModelConfig.objects.aupdate_or_create(user=user, defaults={"setting": config}) return new_config @@ -1146,6 +1159,7 @@ class ConversationAdapters: def create_conversation_from_public_conversation( user: KhojUser, public_conversation: PublicConversation, client_app: ClientApplication ): + check_valid_user(user) scrubbed_title = public_conversation.title if public_conversation.title else public_conversation.slug if scrubbed_title: scrubbed_title = scrubbed_title.replace("-", " ") @@ -1166,6 +1180,7 @@ class ConversationAdapters: conversation_id: str = None, user_message: str = None, ): + check_valid_user(user) slug = user_message.strip()[:200] if user_message else None if conversation_id: conversation = Conversation.objects.filter(user=user, client=client_application, id=conversation_id).first() @@ -1209,6 +1224,7 @@ class ConversationAdapters: @staticmethod async def aget_conversation_starters(user: KhojUser, max_results=3): + check_valid_user(user) all_questions = [] if await ReflectiveQuestion.objects.filter(user=user).aexists(): all_questions = await sync_to_async(ReflectiveQuestion.objects.filter(user=user).values_list)( @@ -1338,6 +1354,7 @@ class ConversationAdapters: @staticmethod def delete_message_by_turn_id(user: KhojUser, conversation_id: str, turn_id: str): + check_valid_user(user) conversation = ConversationAdapters.get_conversation_by_user(user, conversation_id=conversation_id) if not conversation or not conversation.conversation_log or not conversation.conversation_log.get("chat"): return False @@ -1356,51 +1373,62 @@ class FileObjectAdapters: @staticmethod def create_file_object(user: KhojUser, file_name: str, raw_text: str): + check_valid_user(user) return FileObject.objects.create(user=user, file_name=file_name, raw_text=raw_text) @staticmethod def get_file_object_by_name(user: KhojUser, file_name: str): + check_valid_user(user) return FileObject.objects.filter(user=user, file_name=file_name).first() @staticmethod def get_all_file_objects(user: KhojUser): + check_valid_user(user) return FileObject.objects.filter(user=user).all() @staticmethod def delete_file_object_by_name(user: KhojUser, file_name: str): + check_valid_user(user) return FileObject.objects.filter(user=user, file_name=file_name).delete() @staticmethod def delete_all_file_objects(user: KhojUser): + check_valid_user(user) return FileObject.objects.filter(user=user).delete() @staticmethod - async def async_update_raw_text(file_object: FileObject, new_raw_text: str): + async def aupdate_raw_text(file_object: FileObject, new_raw_text: str): file_object.raw_text = new_raw_text await file_object.asave() @staticmethod - async def async_create_file_object(user: KhojUser, file_name: str, raw_text: str): + async def acreate_file_object(user: KhojUser, file_name: str, raw_text: str): + check_valid_user(user) return await FileObject.objects.acreate(user=user, file_name=file_name, raw_text=raw_text) @staticmethod - async def async_get_file_objects_by_name(user: KhojUser, file_name: str, agent: Agent = None): + async def aget_file_objects_by_name(user: KhojUser, file_name: str, agent: Agent = None): + check_valid_user(user) return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name=file_name, agent=agent)) @staticmethod - async def async_get_file_objects_by_names(user: KhojUser, file_names: List[str]): + async def aget_file_objects_by_names(user: KhojUser, file_names: List[str]): + check_valid_user(user) return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name__in=file_names)) @staticmethod - async def async_get_all_file_objects(user: KhojUser): + async def aget_all_file_objects(user: KhojUser): + check_valid_user(user) return await sync_to_async(list)(FileObject.objects.filter(user=user)) @staticmethod - async def async_delete_file_object_by_name(user: KhojUser, file_name: str): + async def adelete_file_object_by_name(user: KhojUser, file_name: str): + check_valid_user(user) return await FileObject.objects.filter(user=user, file_name=file_name).adelete() @staticmethod - async def async_delete_all_file_objects(user: KhojUser): + async def adelete_all_file_objects(user: KhojUser): + check_valid_user(user) return await FileObject.objects.filter(user=user).adelete() @@ -1411,15 +1439,18 @@ class EntryAdapters: @staticmethod def does_entry_exist(user: KhojUser, hashed_value: str) -> bool: + check_valid_user(user) return Entry.objects.filter(user=user, hashed_value=hashed_value).exists() @staticmethod def delete_entry_by_file(user: KhojUser, file_path: str): + check_valid_user(user) deleted_count, _ = Entry.objects.filter(user=user, file_path=file_path).delete() return deleted_count @staticmethod def get_filtered_entries(user: KhojUser, file_type: str = None, file_source: str = None): + check_valid_user(user) queryset = Entry.objects.filter(user=user) if file_type is not None: @@ -1432,6 +1463,7 @@ class EntryAdapters: @staticmethod def delete_all_entries(user: KhojUser, file_type: str = None, file_source: str = None, batch_size=1000): + check_valid_user(user) deleted_count = 0 queryset = EntryAdapters.get_filtered_entries(user, file_type, file_source) while queryset.exists(): @@ -1443,6 +1475,7 @@ class EntryAdapters: @staticmethod async def adelete_all_entries(user: KhojUser, file_type: str = None, file_source: str = None, batch_size=1000): + check_valid_user(user) deleted_count = 0 queryset = EntryAdapters.get_filtered_entries(user, file_type, file_source) while await queryset.aexists(): @@ -1454,10 +1487,12 @@ class EntryAdapters: @staticmethod def get_existing_entry_hashes_by_file(user: KhojUser, file_path: str): + check_valid_user(user) return Entry.objects.filter(user=user, file_path=file_path).values_list("hashed_value", flat=True) @staticmethod def delete_entry_by_hash(user: KhojUser, hashed_values: List[str]): + check_valid_user(user) Entry.objects.filter(user=user, hashed_value__in=hashed_values).delete() @staticmethod @@ -1469,6 +1504,7 @@ class EntryAdapters: @staticmethod def user_has_entries(user: KhojUser): + check_valid_user(user) return Entry.objects.filter(user=user).exists() @staticmethod @@ -1477,6 +1513,7 @@ class EntryAdapters: @staticmethod async def auser_has_entries(user: KhojUser): + check_valid_user(user) return await Entry.objects.filter(user=user).aexists() @staticmethod @@ -1487,10 +1524,12 @@ class EntryAdapters: @staticmethod async def adelete_entry_by_file(user: KhojUser, file_path: str): + check_valid_user(user) return await Entry.objects.filter(user=user, file_path=file_path).adelete() @staticmethod async def adelete_entries_by_filenames(user: KhojUser, filenames: List[str], batch_size=1000): + check_valid_user(user) deleted_count = 0 for i in range(0, len(filenames), batch_size): batch = filenames[i : i + batch_size] @@ -1509,6 +1548,7 @@ class EntryAdapters: @staticmethod def get_all_filenames_by_source(user: KhojUser, file_source: str): + check_valid_user(user) return ( Entry.objects.filter(user=user, file_source=file_source) .distinct("file_path") @@ -1517,6 +1557,7 @@ class EntryAdapters: @staticmethod def get_size_of_indexed_data_in_mb(user: KhojUser): + check_valid_user(user) entries = Entry.objects.filter(user=user).iterator() total_size = sum(sys.getsizeof(entry.compiled) for entry in entries) return total_size / 1024 / 1024 @@ -1536,6 +1577,9 @@ class EntryAdapters: if agent != None: owner_filter |= Q(agent=agent) + if owner_filter == Q(): + return Entry.objects.none() + if len(word_filters) == 0 and len(file_filters) == 0 and len(date_filters) == 0: return Entry.objects.filter(owner_filter) @@ -1611,10 +1655,12 @@ class EntryAdapters: @staticmethod def get_unique_file_types(user: KhojUser): + check_valid_user(user) return Entry.objects.filter(user=user).values_list("file_type", flat=True).distinct() @staticmethod def get_unique_file_sources(user: KhojUser): + check_valid_user(user) return Entry.objects.filter(user=user).values_list("file_source", flat=True).distinct().all() diff --git a/src/khoj/processor/content/docx/docx_to_entries.py b/src/khoj/processor/content/docx/docx_to_entries.py index 19d9ba13..35c634f7 100644 --- a/src/khoj/processor/content/docx/docx_to_entries.py +++ b/src/khoj/processor/content/docx/docx_to_entries.py @@ -18,7 +18,7 @@ class DocxToEntries(TextToEntries): super().__init__() # Define Functions - def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: + def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]: # Extract required fields from config deletion_file_names = set([file for file in files if files[file] == b""]) files_to_process = set(files) - deletion_file_names @@ -35,13 +35,13 @@ class DocxToEntries(TextToEntries): # Identify, mark and merge any new entries with previous entries with timer("Identify new or updated entries", logger): num_new_embeddings, num_deleted_embeddings = self.update_embeddings( + user, current_entries, DbEntry.EntryType.DOCX, DbEntry.EntrySource.COMPUTER, "compiled", logger, deletion_file_names, - user, regenerate=regenerate, file_to_text_map=file_to_text_map, ) diff --git a/src/khoj/processor/content/github/github_to_entries.py b/src/khoj/processor/content/github/github_to_entries.py index 1f3dea00..2381bea8 100644 --- a/src/khoj/processor/content/github/github_to_entries.py +++ b/src/khoj/processor/content/github/github_to_entries.py @@ -48,7 +48,7 @@ class GithubToEntries(TextToEntries): else: return - def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: + def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]: if self.config.pat_token is None or self.config.pat_token == "": logger.error(f"Github PAT token is not set. Skipping github content") raise ValueError("Github PAT token is not set. Skipping github content") @@ -101,12 +101,12 @@ class GithubToEntries(TextToEntries): # Identify, mark and merge any new entries with previous entries with timer("Identify new or updated entries", logger): num_new_embeddings, num_deleted_embeddings = self.update_embeddings( + user, current_entries, DbEntry.EntryType.GITHUB, DbEntry.EntrySource.GITHUB, key="compiled", logger=logger, - user=user, ) return num_new_embeddings, num_deleted_embeddings diff --git a/src/khoj/processor/content/images/image_to_entries.py b/src/khoj/processor/content/images/image_to_entries.py index 87b9a009..134cca52 100644 --- a/src/khoj/processor/content/images/image_to_entries.py +++ b/src/khoj/processor/content/images/image_to_entries.py @@ -18,7 +18,7 @@ class ImageToEntries(TextToEntries): super().__init__() # Define Functions - def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: + def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]: # Extract required fields from config deletion_file_names = set([file for file in files if files[file] == b""]) files_to_process = set(files) - deletion_file_names @@ -35,13 +35,13 @@ class ImageToEntries(TextToEntries): # Identify, mark and merge any new entries with previous entries with timer("Identify new or updated entries", logger): num_new_embeddings, num_deleted_embeddings = self.update_embeddings( + user, current_entries, DbEntry.EntryType.IMAGE, DbEntry.EntrySource.COMPUTER, "compiled", logger, deletion_file_names, - user, regenerate=regenerate, file_to_text_map=file_to_text_map, ) diff --git a/src/khoj/processor/content/markdown/markdown_to_entries.py b/src/khoj/processor/content/markdown/markdown_to_entries.py index fdb0c549..c4ee03ef 100644 --- a/src/khoj/processor/content/markdown/markdown_to_entries.py +++ b/src/khoj/processor/content/markdown/markdown_to_entries.py @@ -19,7 +19,7 @@ class MarkdownToEntries(TextToEntries): super().__init__() # Define Functions - def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: + def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]: # Extract required fields from config deletion_file_names = set([file for file in files if files[file] == ""]) files_to_process = set(files) - deletion_file_names @@ -37,13 +37,13 @@ class MarkdownToEntries(TextToEntries): # Identify, mark and merge any new entries with previous entries with timer("Identify new or updated entries", logger): num_new_embeddings, num_deleted_embeddings = self.update_embeddings( + user, current_entries, DbEntry.EntryType.MARKDOWN, DbEntry.EntrySource.COMPUTER, "compiled", logger, deletion_file_names, - user, regenerate=regenerate, file_to_text_map=file_to_text_map, ) diff --git a/src/khoj/processor/content/notion/notion_to_entries.py b/src/khoj/processor/content/notion/notion_to_entries.py index fc6e296f..1e1ab4d3 100644 --- a/src/khoj/processor/content/notion/notion_to_entries.py +++ b/src/khoj/processor/content/notion/notion_to_entries.py @@ -79,7 +79,7 @@ class NotionToEntries(TextToEntries): self.body_params = {"page_size": 100} - def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: + def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]: current_entries = [] # Get all pages @@ -248,12 +248,12 @@ class NotionToEntries(TextToEntries): # Identify, mark and merge any new entries with previous entries with timer("Identify new or updated entries", logger): num_new_embeddings, num_deleted_embeddings = self.update_embeddings( + user, current_entries, DbEntry.EntryType.NOTION, DbEntry.EntrySource.NOTION, key="compiled", logger=logger, - user=user, ) return num_new_embeddings, num_deleted_embeddings diff --git a/src/khoj/processor/content/org_mode/org_to_entries.py b/src/khoj/processor/content/org_mode/org_to_entries.py index 1272da11..cfc17cc0 100644 --- a/src/khoj/processor/content/org_mode/org_to_entries.py +++ b/src/khoj/processor/content/org_mode/org_to_entries.py @@ -20,7 +20,7 @@ class OrgToEntries(TextToEntries): super().__init__() # Define Functions - def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: + def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]: deletion_file_names = set([file for file in files if files[file] == ""]) files_to_process = set(files) - deletion_file_names files = {file: files[file] for file in files_to_process} @@ -36,13 +36,13 @@ class OrgToEntries(TextToEntries): # Identify, mark and merge any new entries with previous entries with timer("Identify new or updated entries", logger): num_new_embeddings, num_deleted_embeddings = self.update_embeddings( + user, current_entries, DbEntry.EntryType.ORG, DbEntry.EntrySource.COMPUTER, "compiled", logger, deletion_file_names, - user, regenerate=regenerate, file_to_text_map=file_to_text_map, ) diff --git a/src/khoj/processor/content/pdf/pdf_to_entries.py b/src/khoj/processor/content/pdf/pdf_to_entries.py index f1ac5104..7d2bd384 100644 --- a/src/khoj/processor/content/pdf/pdf_to_entries.py +++ b/src/khoj/processor/content/pdf/pdf_to_entries.py @@ -19,7 +19,7 @@ class PdfToEntries(TextToEntries): super().__init__() # Define Functions - def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: + def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]: # Extract required fields from config deletion_file_names = set([file for file in files if files[file] == b""]) files_to_process = set(files) - deletion_file_names @@ -36,13 +36,13 @@ class PdfToEntries(TextToEntries): # Identify, mark and merge any new entries with previous entries with timer("Identify new or updated entries", logger): num_new_embeddings, num_deleted_embeddings = self.update_embeddings( + user, current_entries, DbEntry.EntryType.PDF, DbEntry.EntrySource.COMPUTER, "compiled", logger, deletion_file_names, - user, regenerate=regenerate, file_to_text_map=file_to_text_map, ) diff --git a/src/khoj/processor/content/plaintext/plaintext_to_entries.py b/src/khoj/processor/content/plaintext/plaintext_to_entries.py index 483e752f..64470c08 100644 --- a/src/khoj/processor/content/plaintext/plaintext_to_entries.py +++ b/src/khoj/processor/content/plaintext/plaintext_to_entries.py @@ -20,7 +20,7 @@ class PlaintextToEntries(TextToEntries): super().__init__() # Define Functions - def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: + def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]: deletion_file_names = set([file for file in files if files[file] == ""]) files_to_process = set(files) - deletion_file_names files = {file: files[file] for file in files_to_process} @@ -36,13 +36,13 @@ class PlaintextToEntries(TextToEntries): # Identify, mark and merge any new entries with previous entries with timer("Identify new or updated entries", logger): num_new_embeddings, num_deleted_embeddings = self.update_embeddings( + user, current_entries, DbEntry.EntryType.PLAINTEXT, DbEntry.EntrySource.COMPUTER, key="compiled", logger=logger, deletion_filenames=deletion_file_names, - user=user, regenerate=regenerate, file_to_text_map=file_to_text_map, ) diff --git a/src/khoj/processor/content/text_to_entries.py b/src/khoj/processor/content/text_to_entries.py index 181eb199..f013b28c 100644 --- a/src/khoj/processor/content/text_to_entries.py +++ b/src/khoj/processor/content/text_to_entries.py @@ -31,7 +31,7 @@ class TextToEntries(ABC): self.date_filter = DateFilter() @abstractmethod - def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: + def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]: ... @staticmethod @@ -114,13 +114,13 @@ class TextToEntries(ABC): def update_embeddings( self, + user: KhojUser, current_entries: List[Entry], file_type: str, file_source: str, key="compiled", logger: logging.Logger = None, deletion_filenames: Set[str] = None, - user: KhojUser = None, regenerate: bool = False, file_to_text_map: dict[str, str] = None, ): diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index f66fbce8..fc7dfe27 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -212,7 +212,7 @@ def update( logger.warning(error_msg) raise HTTPException(status_code=500, detail=error_msg) try: - initialize_content(regenerate=force, search_type=t, user=user) + initialize_content(user=user, regenerate=force, search_type=t) except Exception as e: error_msg = f"🚨 Failed to update server via API: {e}" logger.error(error_msg, exc_info=True) diff --git a/src/khoj/routers/api_content.py b/src/khoj/routers/api_content.py index 40a1fb78..9ac0db47 100644 --- a/src/khoj/routers/api_content.py +++ b/src/khoj/routers/api_content.py @@ -239,7 +239,7 @@ async def set_content_notion( if updated_config.token: # Trigger an async job to configure_content. Let it run without blocking the response. - background_tasks.add_task(run_in_executor, configure_content, {}, False, SearchType.Notion, user) + background_tasks.add_task(run_in_executor, configure_content, user, {}, False, SearchType.Notion) update_telemetry_state( request=request, @@ -512,10 +512,10 @@ async def indexer( success = await loop.run_in_executor( None, configure_content, + user, indexer_input.model_dump(), regenerate, t, - user, ) if not success: raise RuntimeError(f"Failed to {method} {t} data sent by {client} client into content index") diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 2d0bbe29..3a2cb5cf 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -703,7 +703,7 @@ async def generate_summary_from_files( if await EntryAdapters.aagent_has_entries(agent): file_names = await EntryAdapters.aget_agent_entry_filepaths(agent) if len(file_names) > 0: - file_objects = await FileObjectAdapters.async_get_file_objects_by_name(None, file_names.pop(), agent) + file_objects = await FileObjectAdapters.aget_file_objects_by_name(None, file_names.pop(), agent) if (file_objects and len(file_objects) == 0 and not query_files) or (not file_objects and not query_files): response_log = "Sorry, I couldn't find anything to summarize." @@ -1975,10 +1975,10 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False) def configure_content( + user: KhojUser, files: Optional[dict[str, dict[str, str]]], regenerate: bool = False, t: Optional[state.SearchType] = state.SearchType.All, - user: KhojUser = None, ) -> bool: success = True if t == None: diff --git a/src/khoj/routers/notion.py b/src/khoj/routers/notion.py index 7d5ed25d..acfd1e2e 100644 --- a/src/khoj/routers/notion.py +++ b/src/khoj/routers/notion.py @@ -80,6 +80,6 @@ async def notion_auth_callback(request: Request, background_tasks: BackgroundTas notion_redirect = str(request.app.url_path_for("config_page")) # Trigger an async job to configure_content. Let it run without blocking the response. - background_tasks.add_task(run_in_executor, configure_content, {}, False, SearchType.Notion, user) + background_tasks.add_task(run_in_executor, configure_content, user, {}, False, SearchType.Notion) return RedirectResponse(notion_redirect) diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py index eed72b51..6d7667e5 100644 --- a/src/khoj/search_type/text_search.py +++ b/src/khoj/search_type/text_search.py @@ -208,7 +208,7 @@ def setup( text_to_entries: Type[TextToEntries], files: dict[str, str], regenerate: bool, - user: KhojUser = None, + user: KhojUser, config=None, ) -> Tuple[int, int]: if config: diff --git a/src/khoj/utils/fs_syncer.py b/src/khoj/utils/fs_syncer.py index 475504f1..67e91bc9 100644 --- a/src/khoj/utils/fs_syncer.py +++ b/src/khoj/utils/fs_syncer.py @@ -8,6 +8,7 @@ from bs4 import BeautifulSoup from magika import Magika from khoj.database.models import ( + KhojUser, LocalMarkdownConfig, LocalOrgConfig, LocalPdfConfig, @@ -21,7 +22,7 @@ logger = logging.getLogger(__name__) magika = Magika() -def collect_files(search_type: Optional[SearchType] = SearchType.All, user=None) -> dict: +def collect_files(user: KhojUser, search_type: Optional[SearchType] = SearchType.All) -> dict: files: dict[str, dict] = {"docx": {}, "image": {}} if search_type == SearchType.All or search_type == SearchType.Org: diff --git a/tests/conftest.py b/tests/conftest.py index 54b4db86..b91af758 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -304,7 +304,7 @@ def chat_client_builder(search_config, user, index_content=True, require_auth=Fa # Index Markdown Content for Search all_files = fs_syncer.collect_files(user=user) - success = configure_content(all_files, user=user) + configure_content(user, all_files) # Initialize Processor from Config if os.getenv("OPENAI_API_KEY"): @@ -381,7 +381,7 @@ def client_offline_chat(search_config: SearchConfig, default_user2: KhojUser): ) all_files = fs_syncer.collect_files(user=default_user2) - configure_content(all_files, user=default_user2) + configure_content(default_user2, all_files) # Initialize Processor from Config ChatModelOptionsFactory( @@ -432,7 +432,7 @@ def pdf_configured_user1(default_user: KhojUser): ) # Index Markdown Content for Search all_files = fs_syncer.collect_files(user=default_user) - success = configure_content(all_files, user=default_user) + configure_content(default_user, all_files) @pytest.fixture(scope="function") diff --git a/tests/test_client.py b/tests/test_client.py index b8284e4b..f5ed320f 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -253,11 +253,11 @@ def test_regenerate_with_github_fails_without_pat(client): # ---------------------------------------------------------------------------------------------------- @pytest.mark.django_db -def test_get_configured_types_via_api(client, sample_org_data): +def test_get_configured_types_via_api(client, sample_org_data, default_user3: KhojUser): # Act - text_search.setup(OrgToEntries, sample_org_data, regenerate=False) + text_search.setup(OrgToEntries, sample_org_data, regenerate=False, user=default_user3) - enabled_types = EntryAdapters.get_unique_file_types(user=None).all().values_list("file_type", flat=True) + enabled_types = EntryAdapters.get_unique_file_types(user=default_user3).all().values_list("file_type", flat=True) # Assert assert list(enabled_types) == ["org"]