From ff5c10c221fa7e668d404293715c55388eb06c00 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Mon, 4 Nov 2024 19:45:28 -0800 Subject: [PATCH 1/3] Do not CRUD on entries, files & conversations in DB for null user Increase defense-in-depth by reducing paths to create, read, update or delete entries, files and conversations in DB when user is unset. --- src/khoj/configure.py | 10 ++- src/khoj/database/adapters/__init__.py | 64 ++++++++++++++++--- .../processor/content/docx/docx_to_entries.py | 4 +- .../content/github/github_to_entries.py | 4 +- .../content/images/image_to_entries.py | 4 +- .../content/markdown/markdown_to_entries.py | 4 +- .../content/notion/notion_to_entries.py | 4 +- .../content/org_mode/org_to_entries.py | 4 +- .../processor/content/pdf/pdf_to_entries.py | 4 +- .../content/plaintext/plaintext_to_entries.py | 4 +- src/khoj/processor/content/text_to_entries.py | 4 +- src/khoj/routers/api.py | 2 +- src/khoj/routers/api_content.py | 4 +- src/khoj/routers/helpers.py | 4 +- src/khoj/routers/notion.py | 2 +- src/khoj/search_type/text_search.py | 2 +- src/khoj/utils/fs_syncer.py | 3 +- tests/conftest.py | 6 +- tests/test_client.py | 6 +- 19 files changed, 92 insertions(+), 47 deletions(-) diff --git a/src/khoj/configure.py b/src/khoj/configure.py index a1f4a7db..002413b8 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -253,7 +253,7 @@ def configure_server( logger.info(message) if not init: - initialize_content(regenerate, search_type, user) + initialize_content(user, regenerate, search_type) except Exception as e: logger.error(f"Failed to load some search models: {e}", exc_info=True) @@ -263,17 +263,17 @@ def setup_default_agent(user: KhojUser): AgentAdapters.create_default_agent(user) -def initialize_content(regenerate: bool, search_type: Optional[SearchType] = None, user: KhojUser = None): +def initialize_content(user: KhojUser, regenerate: bool, search_type: Optional[SearchType] = None): # Initialize Content from Config if state.search_models: try: logger.info("📬 Updating content index...") all_files = collect_files(user=user) status = configure_content( + user, all_files, regenerate, search_type, - user=user, ) if not status: raise RuntimeError("Failed to update content index") @@ -338,9 +338,7 @@ def configure_middleware(app): def update_content_index(): for user in get_all_users(): all_files = collect_files(user=user) - success = configure_content(all_files, user=user) - all_files = collect_files(user=None) - success = configure_content(all_files, user=None) + success = configure_content(user, all_files) if not success: raise RuntimeError("Failed to update content index") logger.info("📪 Content index updated via Scheduler") diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 6676eefa..ad149d58 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -416,6 +416,11 @@ def get_all_users() -> BaseManager[KhojUser]: return KhojUser.objects.all() +def check_valid_user(user: KhojUser | None): + if not user: + raise ValueError("User not found") + + def get_user_github_config(user: KhojUser): config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first() return config @@ -815,6 +820,7 @@ class ConversationAdapters: def get_conversation_by_user( user: KhojUser, client_application: ClientApplication = None, conversation_id: str = None ) -> Optional[Conversation]: + check_valid_user(user) if conversation_id: conversation = ( Conversation.objects.filter(user=user, client=client_application, id=conversation_id) @@ -831,6 +837,7 @@ class ConversationAdapters: @staticmethod def get_conversation_sessions(user: KhojUser, client_application: ClientApplication = None): + check_valid_user(user) return ( Conversation.objects.filter(user=user, client=client_application) .prefetch_related("agent") @@ -841,6 +848,7 @@ class ConversationAdapters: async def aset_conversation_title( user: KhojUser, client_application: ClientApplication, conversation_id: str, title: str ): + check_valid_user(user) conversation = await Conversation.objects.filter( user=user, client=client_application, id=conversation_id ).afirst() @@ -858,6 +866,7 @@ class ConversationAdapters: async def acreate_conversation_session( user: KhojUser, client_application: ClientApplication = None, agent_slug: str = None, title: str = None ): + check_valid_user(user) if agent_slug: agent = await AgentAdapters.aget_readonly_agent_by_slug(agent_slug, user) if agent is None: @@ -874,6 +883,7 @@ class ConversationAdapters: def create_conversation_session( user: KhojUser, client_application: ClientApplication = None, agent_slug: str = None, title: str = None ): + check_valid_user(user) if agent_slug: agent = AgentAdapters.aget_readonly_agent_by_slug(agent_slug, user) if agent is None: @@ -890,6 +900,7 @@ class ConversationAdapters: title: str = None, create_new: bool = False, ) -> Optional[Conversation]: + check_valid_user(user) if create_new: return await ConversationAdapters.acreate_conversation_session(user, client_application) @@ -910,12 +921,14 @@ class ConversationAdapters: async def adelete_conversation_by_user( user: KhojUser, client_application: ClientApplication = None, conversation_id: str = None ): + check_valid_user(user) if conversation_id: return await Conversation.objects.filter(user=user, client=client_application, id=conversation_id).adelete() return await Conversation.objects.filter(user=user, client=client_application).adelete() @staticmethod def has_any_conversation_config(user: KhojUser): + check_valid_user(user) return ChatModelOptions.objects.filter(user=user).exists() @staticmethod @@ -953,7 +966,7 @@ class ConversationAdapters: @staticmethod async def aset_user_conversation_processor(user: KhojUser, conversation_processor_config_id: int): config = await ChatModelOptions.objects.filter(id=conversation_processor_config_id).afirst() - if not config: + if not config or user is None: return None new_config = await UserConversationConfig.objects.aupdate_or_create(user=user, defaults={"setting": config}) return new_config @@ -961,7 +974,7 @@ class ConversationAdapters: @staticmethod async def aset_user_voice_model(user: KhojUser, model_id: str): config = await VoiceModelOption.objects.filter(model_id=model_id).afirst() - if not config: + if not config or user is None: return None new_config = await UserVoiceModelConfig.objects.aupdate_or_create(user=user, defaults={"setting": config}) return new_config @@ -1146,6 +1159,7 @@ class ConversationAdapters: def create_conversation_from_public_conversation( user: KhojUser, public_conversation: PublicConversation, client_app: ClientApplication ): + check_valid_user(user) scrubbed_title = public_conversation.title if public_conversation.title else public_conversation.slug if scrubbed_title: scrubbed_title = scrubbed_title.replace("-", " ") @@ -1166,6 +1180,7 @@ class ConversationAdapters: conversation_id: str = None, user_message: str = None, ): + check_valid_user(user) slug = user_message.strip()[:200] if user_message else None if conversation_id: conversation = Conversation.objects.filter(user=user, client=client_application, id=conversation_id).first() @@ -1209,6 +1224,7 @@ class ConversationAdapters: @staticmethod async def aget_conversation_starters(user: KhojUser, max_results=3): + check_valid_user(user) all_questions = [] if await ReflectiveQuestion.objects.filter(user=user).aexists(): all_questions = await sync_to_async(ReflectiveQuestion.objects.filter(user=user).values_list)( @@ -1338,6 +1354,7 @@ class ConversationAdapters: @staticmethod def delete_message_by_turn_id(user: KhojUser, conversation_id: str, turn_id: str): + check_valid_user(user) conversation = ConversationAdapters.get_conversation_by_user(user, conversation_id=conversation_id) if not conversation or not conversation.conversation_log or not conversation.conversation_log.get("chat"): return False @@ -1356,51 +1373,62 @@ class FileObjectAdapters: @staticmethod def create_file_object(user: KhojUser, file_name: str, raw_text: str): + check_valid_user(user) return FileObject.objects.create(user=user, file_name=file_name, raw_text=raw_text) @staticmethod def get_file_object_by_name(user: KhojUser, file_name: str): + check_valid_user(user) return FileObject.objects.filter(user=user, file_name=file_name).first() @staticmethod def get_all_file_objects(user: KhojUser): + check_valid_user(user) return FileObject.objects.filter(user=user).all() @staticmethod def delete_file_object_by_name(user: KhojUser, file_name: str): + check_valid_user(user) return FileObject.objects.filter(user=user, file_name=file_name).delete() @staticmethod def delete_all_file_objects(user: KhojUser): + check_valid_user(user) return FileObject.objects.filter(user=user).delete() @staticmethod - async def async_update_raw_text(file_object: FileObject, new_raw_text: str): + async def aupdate_raw_text(file_object: FileObject, new_raw_text: str): file_object.raw_text = new_raw_text await file_object.asave() @staticmethod - async def async_create_file_object(user: KhojUser, file_name: str, raw_text: str): + async def acreate_file_object(user: KhojUser, file_name: str, raw_text: str): + check_valid_user(user) return await FileObject.objects.acreate(user=user, file_name=file_name, raw_text=raw_text) @staticmethod - async def async_get_file_objects_by_name(user: KhojUser, file_name: str, agent: Agent = None): + async def aget_file_objects_by_name(user: KhojUser, file_name: str, agent: Agent = None): + check_valid_user(user) return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name=file_name, agent=agent)) @staticmethod - async def async_get_file_objects_by_names(user: KhojUser, file_names: List[str]): + async def aget_file_objects_by_names(user: KhojUser, file_names: List[str]): + check_valid_user(user) return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name__in=file_names)) @staticmethod - async def async_get_all_file_objects(user: KhojUser): + async def aget_all_file_objects(user: KhojUser): + check_valid_user(user) return await sync_to_async(list)(FileObject.objects.filter(user=user)) @staticmethod - async def async_delete_file_object_by_name(user: KhojUser, file_name: str): + async def adelete_file_object_by_name(user: KhojUser, file_name: str): + check_valid_user(user) return await FileObject.objects.filter(user=user, file_name=file_name).adelete() @staticmethod - async def async_delete_all_file_objects(user: KhojUser): + async def adelete_all_file_objects(user: KhojUser): + check_valid_user(user) return await FileObject.objects.filter(user=user).adelete() @@ -1411,15 +1439,18 @@ class EntryAdapters: @staticmethod def does_entry_exist(user: KhojUser, hashed_value: str) -> bool: + check_valid_user(user) return Entry.objects.filter(user=user, hashed_value=hashed_value).exists() @staticmethod def delete_entry_by_file(user: KhojUser, file_path: str): + check_valid_user(user) deleted_count, _ = Entry.objects.filter(user=user, file_path=file_path).delete() return deleted_count @staticmethod def get_filtered_entries(user: KhojUser, file_type: str = None, file_source: str = None): + check_valid_user(user) queryset = Entry.objects.filter(user=user) if file_type is not None: @@ -1432,6 +1463,7 @@ class EntryAdapters: @staticmethod def delete_all_entries(user: KhojUser, file_type: str = None, file_source: str = None, batch_size=1000): + check_valid_user(user) deleted_count = 0 queryset = EntryAdapters.get_filtered_entries(user, file_type, file_source) while queryset.exists(): @@ -1443,6 +1475,7 @@ class EntryAdapters: @staticmethod async def adelete_all_entries(user: KhojUser, file_type: str = None, file_source: str = None, batch_size=1000): + check_valid_user(user) deleted_count = 0 queryset = EntryAdapters.get_filtered_entries(user, file_type, file_source) while await queryset.aexists(): @@ -1454,10 +1487,12 @@ class EntryAdapters: @staticmethod def get_existing_entry_hashes_by_file(user: KhojUser, file_path: str): + check_valid_user(user) return Entry.objects.filter(user=user, file_path=file_path).values_list("hashed_value", flat=True) @staticmethod def delete_entry_by_hash(user: KhojUser, hashed_values: List[str]): + check_valid_user(user) Entry.objects.filter(user=user, hashed_value__in=hashed_values).delete() @staticmethod @@ -1469,6 +1504,7 @@ class EntryAdapters: @staticmethod def user_has_entries(user: KhojUser): + check_valid_user(user) return Entry.objects.filter(user=user).exists() @staticmethod @@ -1477,6 +1513,7 @@ class EntryAdapters: @staticmethod async def auser_has_entries(user: KhojUser): + check_valid_user(user) return await Entry.objects.filter(user=user).aexists() @staticmethod @@ -1487,10 +1524,12 @@ class EntryAdapters: @staticmethod async def adelete_entry_by_file(user: KhojUser, file_path: str): + check_valid_user(user) return await Entry.objects.filter(user=user, file_path=file_path).adelete() @staticmethod async def adelete_entries_by_filenames(user: KhojUser, filenames: List[str], batch_size=1000): + check_valid_user(user) deleted_count = 0 for i in range(0, len(filenames), batch_size): batch = filenames[i : i + batch_size] @@ -1509,6 +1548,7 @@ class EntryAdapters: @staticmethod def get_all_filenames_by_source(user: KhojUser, file_source: str): + check_valid_user(user) return ( Entry.objects.filter(user=user, file_source=file_source) .distinct("file_path") @@ -1517,6 +1557,7 @@ class EntryAdapters: @staticmethod def get_size_of_indexed_data_in_mb(user: KhojUser): + check_valid_user(user) entries = Entry.objects.filter(user=user).iterator() total_size = sum(sys.getsizeof(entry.compiled) for entry in entries) return total_size / 1024 / 1024 @@ -1536,6 +1577,9 @@ class EntryAdapters: if agent != None: owner_filter |= Q(agent=agent) + if owner_filter == Q(): + return Entry.objects.none() + if len(word_filters) == 0 and len(file_filters) == 0 and len(date_filters) == 0: return Entry.objects.filter(owner_filter) @@ -1611,10 +1655,12 @@ class EntryAdapters: @staticmethod def get_unique_file_types(user: KhojUser): + check_valid_user(user) return Entry.objects.filter(user=user).values_list("file_type", flat=True).distinct() @staticmethod def get_unique_file_sources(user: KhojUser): + check_valid_user(user) return Entry.objects.filter(user=user).values_list("file_source", flat=True).distinct().all() diff --git a/src/khoj/processor/content/docx/docx_to_entries.py b/src/khoj/processor/content/docx/docx_to_entries.py index 19d9ba13..35c634f7 100644 --- a/src/khoj/processor/content/docx/docx_to_entries.py +++ b/src/khoj/processor/content/docx/docx_to_entries.py @@ -18,7 +18,7 @@ class DocxToEntries(TextToEntries): super().__init__() # Define Functions - def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: + def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]: # Extract required fields from config deletion_file_names = set([file for file in files if files[file] == b""]) files_to_process = set(files) - deletion_file_names @@ -35,13 +35,13 @@ class DocxToEntries(TextToEntries): # Identify, mark and merge any new entries with previous entries with timer("Identify new or updated entries", logger): num_new_embeddings, num_deleted_embeddings = self.update_embeddings( + user, current_entries, DbEntry.EntryType.DOCX, DbEntry.EntrySource.COMPUTER, "compiled", logger, deletion_file_names, - user, regenerate=regenerate, file_to_text_map=file_to_text_map, ) diff --git a/src/khoj/processor/content/github/github_to_entries.py b/src/khoj/processor/content/github/github_to_entries.py index 1f3dea00..2381bea8 100644 --- a/src/khoj/processor/content/github/github_to_entries.py +++ b/src/khoj/processor/content/github/github_to_entries.py @@ -48,7 +48,7 @@ class GithubToEntries(TextToEntries): else: return - def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: + def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]: if self.config.pat_token is None or self.config.pat_token == "": logger.error(f"Github PAT token is not set. Skipping github content") raise ValueError("Github PAT token is not set. Skipping github content") @@ -101,12 +101,12 @@ class GithubToEntries(TextToEntries): # Identify, mark and merge any new entries with previous entries with timer("Identify new or updated entries", logger): num_new_embeddings, num_deleted_embeddings = self.update_embeddings( + user, current_entries, DbEntry.EntryType.GITHUB, DbEntry.EntrySource.GITHUB, key="compiled", logger=logger, - user=user, ) return num_new_embeddings, num_deleted_embeddings diff --git a/src/khoj/processor/content/images/image_to_entries.py b/src/khoj/processor/content/images/image_to_entries.py index 87b9a009..134cca52 100644 --- a/src/khoj/processor/content/images/image_to_entries.py +++ b/src/khoj/processor/content/images/image_to_entries.py @@ -18,7 +18,7 @@ class ImageToEntries(TextToEntries): super().__init__() # Define Functions - def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: + def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]: # Extract required fields from config deletion_file_names = set([file for file in files if files[file] == b""]) files_to_process = set(files) - deletion_file_names @@ -35,13 +35,13 @@ class ImageToEntries(TextToEntries): # Identify, mark and merge any new entries with previous entries with timer("Identify new or updated entries", logger): num_new_embeddings, num_deleted_embeddings = self.update_embeddings( + user, current_entries, DbEntry.EntryType.IMAGE, DbEntry.EntrySource.COMPUTER, "compiled", logger, deletion_file_names, - user, regenerate=regenerate, file_to_text_map=file_to_text_map, ) diff --git a/src/khoj/processor/content/markdown/markdown_to_entries.py b/src/khoj/processor/content/markdown/markdown_to_entries.py index fdb0c549..c4ee03ef 100644 --- a/src/khoj/processor/content/markdown/markdown_to_entries.py +++ b/src/khoj/processor/content/markdown/markdown_to_entries.py @@ -19,7 +19,7 @@ class MarkdownToEntries(TextToEntries): super().__init__() # Define Functions - def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: + def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]: # Extract required fields from config deletion_file_names = set([file for file in files if files[file] == ""]) files_to_process = set(files) - deletion_file_names @@ -37,13 +37,13 @@ class MarkdownToEntries(TextToEntries): # Identify, mark and merge any new entries with previous entries with timer("Identify new or updated entries", logger): num_new_embeddings, num_deleted_embeddings = self.update_embeddings( + user, current_entries, DbEntry.EntryType.MARKDOWN, DbEntry.EntrySource.COMPUTER, "compiled", logger, deletion_file_names, - user, regenerate=regenerate, file_to_text_map=file_to_text_map, ) diff --git a/src/khoj/processor/content/notion/notion_to_entries.py b/src/khoj/processor/content/notion/notion_to_entries.py index fc6e296f..1e1ab4d3 100644 --- a/src/khoj/processor/content/notion/notion_to_entries.py +++ b/src/khoj/processor/content/notion/notion_to_entries.py @@ -79,7 +79,7 @@ class NotionToEntries(TextToEntries): self.body_params = {"page_size": 100} - def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: + def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]: current_entries = [] # Get all pages @@ -248,12 +248,12 @@ class NotionToEntries(TextToEntries): # Identify, mark and merge any new entries with previous entries with timer("Identify new or updated entries", logger): num_new_embeddings, num_deleted_embeddings = self.update_embeddings( + user, current_entries, DbEntry.EntryType.NOTION, DbEntry.EntrySource.NOTION, key="compiled", logger=logger, - user=user, ) return num_new_embeddings, num_deleted_embeddings diff --git a/src/khoj/processor/content/org_mode/org_to_entries.py b/src/khoj/processor/content/org_mode/org_to_entries.py index 1272da11..cfc17cc0 100644 --- a/src/khoj/processor/content/org_mode/org_to_entries.py +++ b/src/khoj/processor/content/org_mode/org_to_entries.py @@ -20,7 +20,7 @@ class OrgToEntries(TextToEntries): super().__init__() # Define Functions - def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: + def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]: deletion_file_names = set([file for file in files if files[file] == ""]) files_to_process = set(files) - deletion_file_names files = {file: files[file] for file in files_to_process} @@ -36,13 +36,13 @@ class OrgToEntries(TextToEntries): # Identify, mark and merge any new entries with previous entries with timer("Identify new or updated entries", logger): num_new_embeddings, num_deleted_embeddings = self.update_embeddings( + user, current_entries, DbEntry.EntryType.ORG, DbEntry.EntrySource.COMPUTER, "compiled", logger, deletion_file_names, - user, regenerate=regenerate, file_to_text_map=file_to_text_map, ) diff --git a/src/khoj/processor/content/pdf/pdf_to_entries.py b/src/khoj/processor/content/pdf/pdf_to_entries.py index f1ac5104..7d2bd384 100644 --- a/src/khoj/processor/content/pdf/pdf_to_entries.py +++ b/src/khoj/processor/content/pdf/pdf_to_entries.py @@ -19,7 +19,7 @@ class PdfToEntries(TextToEntries): super().__init__() # Define Functions - def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: + def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]: # Extract required fields from config deletion_file_names = set([file for file in files if files[file] == b""]) files_to_process = set(files) - deletion_file_names @@ -36,13 +36,13 @@ class PdfToEntries(TextToEntries): # Identify, mark and merge any new entries with previous entries with timer("Identify new or updated entries", logger): num_new_embeddings, num_deleted_embeddings = self.update_embeddings( + user, current_entries, DbEntry.EntryType.PDF, DbEntry.EntrySource.COMPUTER, "compiled", logger, deletion_file_names, - user, regenerate=regenerate, file_to_text_map=file_to_text_map, ) diff --git a/src/khoj/processor/content/plaintext/plaintext_to_entries.py b/src/khoj/processor/content/plaintext/plaintext_to_entries.py index 483e752f..64470c08 100644 --- a/src/khoj/processor/content/plaintext/plaintext_to_entries.py +++ b/src/khoj/processor/content/plaintext/plaintext_to_entries.py @@ -20,7 +20,7 @@ class PlaintextToEntries(TextToEntries): super().__init__() # Define Functions - def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: + def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]: deletion_file_names = set([file for file in files if files[file] == ""]) files_to_process = set(files) - deletion_file_names files = {file: files[file] for file in files_to_process} @@ -36,13 +36,13 @@ class PlaintextToEntries(TextToEntries): # Identify, mark and merge any new entries with previous entries with timer("Identify new or updated entries", logger): num_new_embeddings, num_deleted_embeddings = self.update_embeddings( + user, current_entries, DbEntry.EntryType.PLAINTEXT, DbEntry.EntrySource.COMPUTER, key="compiled", logger=logger, deletion_filenames=deletion_file_names, - user=user, regenerate=regenerate, file_to_text_map=file_to_text_map, ) diff --git a/src/khoj/processor/content/text_to_entries.py b/src/khoj/processor/content/text_to_entries.py index 181eb199..f013b28c 100644 --- a/src/khoj/processor/content/text_to_entries.py +++ b/src/khoj/processor/content/text_to_entries.py @@ -31,7 +31,7 @@ class TextToEntries(ABC): self.date_filter = DateFilter() @abstractmethod - def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: + def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]: ... @staticmethod @@ -114,13 +114,13 @@ class TextToEntries(ABC): def update_embeddings( self, + user: KhojUser, current_entries: List[Entry], file_type: str, file_source: str, key="compiled", logger: logging.Logger = None, deletion_filenames: Set[str] = None, - user: KhojUser = None, regenerate: bool = False, file_to_text_map: dict[str, str] = None, ): diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index f66fbce8..fc7dfe27 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -212,7 +212,7 @@ def update( logger.warning(error_msg) raise HTTPException(status_code=500, detail=error_msg) try: - initialize_content(regenerate=force, search_type=t, user=user) + initialize_content(user=user, regenerate=force, search_type=t) except Exception as e: error_msg = f"🚨 Failed to update server via API: {e}" logger.error(error_msg, exc_info=True) diff --git a/src/khoj/routers/api_content.py b/src/khoj/routers/api_content.py index 40a1fb78..9ac0db47 100644 --- a/src/khoj/routers/api_content.py +++ b/src/khoj/routers/api_content.py @@ -239,7 +239,7 @@ async def set_content_notion( if updated_config.token: # Trigger an async job to configure_content. Let it run without blocking the response. - background_tasks.add_task(run_in_executor, configure_content, {}, False, SearchType.Notion, user) + background_tasks.add_task(run_in_executor, configure_content, user, {}, False, SearchType.Notion) update_telemetry_state( request=request, @@ -512,10 +512,10 @@ async def indexer( success = await loop.run_in_executor( None, configure_content, + user, indexer_input.model_dump(), regenerate, t, - user, ) if not success: raise RuntimeError(f"Failed to {method} {t} data sent by {client} client into content index") diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 2d0bbe29..3a2cb5cf 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -703,7 +703,7 @@ async def generate_summary_from_files( if await EntryAdapters.aagent_has_entries(agent): file_names = await EntryAdapters.aget_agent_entry_filepaths(agent) if len(file_names) > 0: - file_objects = await FileObjectAdapters.async_get_file_objects_by_name(None, file_names.pop(), agent) + file_objects = await FileObjectAdapters.aget_file_objects_by_name(None, file_names.pop(), agent) if (file_objects and len(file_objects) == 0 and not query_files) or (not file_objects and not query_files): response_log = "Sorry, I couldn't find anything to summarize." @@ -1975,10 +1975,10 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False) def configure_content( + user: KhojUser, files: Optional[dict[str, dict[str, str]]], regenerate: bool = False, t: Optional[state.SearchType] = state.SearchType.All, - user: KhojUser = None, ) -> bool: success = True if t == None: diff --git a/src/khoj/routers/notion.py b/src/khoj/routers/notion.py index 7d5ed25d..acfd1e2e 100644 --- a/src/khoj/routers/notion.py +++ b/src/khoj/routers/notion.py @@ -80,6 +80,6 @@ async def notion_auth_callback(request: Request, background_tasks: BackgroundTas notion_redirect = str(request.app.url_path_for("config_page")) # Trigger an async job to configure_content. Let it run without blocking the response. - background_tasks.add_task(run_in_executor, configure_content, {}, False, SearchType.Notion, user) + background_tasks.add_task(run_in_executor, configure_content, user, {}, False, SearchType.Notion) return RedirectResponse(notion_redirect) diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py index eed72b51..6d7667e5 100644 --- a/src/khoj/search_type/text_search.py +++ b/src/khoj/search_type/text_search.py @@ -208,7 +208,7 @@ def setup( text_to_entries: Type[TextToEntries], files: dict[str, str], regenerate: bool, - user: KhojUser = None, + user: KhojUser, config=None, ) -> Tuple[int, int]: if config: diff --git a/src/khoj/utils/fs_syncer.py b/src/khoj/utils/fs_syncer.py index 475504f1..67e91bc9 100644 --- a/src/khoj/utils/fs_syncer.py +++ b/src/khoj/utils/fs_syncer.py @@ -8,6 +8,7 @@ from bs4 import BeautifulSoup from magika import Magika from khoj.database.models import ( + KhojUser, LocalMarkdownConfig, LocalOrgConfig, LocalPdfConfig, @@ -21,7 +22,7 @@ logger = logging.getLogger(__name__) magika = Magika() -def collect_files(search_type: Optional[SearchType] = SearchType.All, user=None) -> dict: +def collect_files(user: KhojUser, search_type: Optional[SearchType] = SearchType.All) -> dict: files: dict[str, dict] = {"docx": {}, "image": {}} if search_type == SearchType.All or search_type == SearchType.Org: diff --git a/tests/conftest.py b/tests/conftest.py index 54b4db86..b91af758 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -304,7 +304,7 @@ def chat_client_builder(search_config, user, index_content=True, require_auth=Fa # Index Markdown Content for Search all_files = fs_syncer.collect_files(user=user) - success = configure_content(all_files, user=user) + configure_content(user, all_files) # Initialize Processor from Config if os.getenv("OPENAI_API_KEY"): @@ -381,7 +381,7 @@ def client_offline_chat(search_config: SearchConfig, default_user2: KhojUser): ) all_files = fs_syncer.collect_files(user=default_user2) - configure_content(all_files, user=default_user2) + configure_content(default_user2, all_files) # Initialize Processor from Config ChatModelOptionsFactory( @@ -432,7 +432,7 @@ def pdf_configured_user1(default_user: KhojUser): ) # Index Markdown Content for Search all_files = fs_syncer.collect_files(user=default_user) - success = configure_content(all_files, user=default_user) + configure_content(default_user, all_files) @pytest.fixture(scope="function") diff --git a/tests/test_client.py b/tests/test_client.py index b8284e4b..f5ed320f 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -253,11 +253,11 @@ def test_regenerate_with_github_fails_without_pat(client): # ---------------------------------------------------------------------------------------------------- @pytest.mark.django_db -def test_get_configured_types_via_api(client, sample_org_data): +def test_get_configured_types_via_api(client, sample_org_data, default_user3: KhojUser): # Act - text_search.setup(OrgToEntries, sample_org_data, regenerate=False) + text_search.setup(OrgToEntries, sample_org_data, regenerate=False, user=default_user3) - enabled_types = EntryAdapters.get_unique_file_types(user=None).all().values_list("file_type", flat=True) + enabled_types = EntryAdapters.get_unique_file_types(user=default_user3).all().values_list("file_type", flat=True) # Assert assert list(enabled_types) == ["org"] From 10bca6fa8f2308d0570de8732d61f20f95235ac8 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Sun, 10 Nov 2024 17:49:55 -0800 Subject: [PATCH 2/3] Convert required user param check into decorator. Use with more adapters --- src/khoj/database/adapters/__init__.py | 159 ++++++++++++++++++------- 1 file changed, 113 insertions(+), 46 deletions(-) diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index ad149d58..4bb5c6c1 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -8,7 +8,17 @@ import secrets import sys from datetime import date, datetime, timedelta, timezone from enum import Enum -from typing import Callable, Iterable, List, Optional, Type +from functools import wraps +from typing import ( + Any, + Callable, + Coroutine, + Iterable, + List, + Optional, + ParamSpec, + TypeVar, +) import cron_descriptor from apscheduler.job import Job @@ -80,6 +90,45 @@ class SubscriptionState(Enum): INVALID = "invalid" +P = ParamSpec("P") +T = TypeVar("T") + + +def require_valid_user(func: Callable[P, T]) -> Callable[P, T]: + @wraps(func) + def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + # Extract user from args/kwargs + user = next((arg for arg in args if isinstance(arg, KhojUser)), None) + if not user: + user = next((val for val in kwargs.values() if isinstance(val, KhojUser)), None) + + # Throw error if user is not found + if not user: + raise ValueError("Khoj user argument required but not provided.") + + return func(*args, **kwargs) + + return sync_wrapper + + +def arequire_valid_user(func: Callable[P, Coroutine[Any, Any, T]]) -> Callable[P, Coroutine[Any, Any, T]]: + @wraps(func) + async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + # Extract user from args/kwargs + user = next((arg for arg in args if isinstance(arg, KhojUser)), None) + if not user: + user = next((v for v in kwargs.values() if isinstance(v, KhojUser)), None) + + # Throw error if user is not found + if not user: + raise ValueError("Khoj user argument required but not provided.") + + return await func(*args, **kwargs) + + return async_wrapper + + +@arequire_valid_user async def set_notion_config(token: str, user: KhojUser): notion_config = await NotionConfig.objects.filter(user=user).afirst() if not notion_config: @@ -90,6 +139,7 @@ async def set_notion_config(token: str, user: KhojUser): return notion_config +@require_valid_user def create_khoj_token(user: KhojUser, name=None): "Create Khoj API key for user" token = f"kk-{secrets.token_urlsafe(32)}" @@ -97,6 +147,7 @@ def create_khoj_token(user: KhojUser, name=None): return KhojApiUser.objects.create(token=token, user=user, name=name) +@arequire_valid_user async def acreate_khoj_token(user: KhojUser, name=None): "Create Khoj API key for user" token = f"kk-{secrets.token_urlsafe(32)}" @@ -104,11 +155,13 @@ async def acreate_khoj_token(user: KhojUser, name=None): return await KhojApiUser.objects.acreate(token=token, user=user, name=name) +@require_valid_user def get_khoj_tokens(user: KhojUser): "Get all Khoj API keys for user" return list(KhojApiUser.objects.filter(user=user)) +@arequire_valid_user async def delete_khoj_token(user: KhojUser, token: str): "Delete Khoj API Key for user" await KhojApiUser.objects.filter(token=token, user=user).adelete() @@ -132,6 +185,7 @@ async def aget_or_create_user_by_phone_number(phone_number: str) -> tuple[KhojUs return user, is_new +@arequire_valid_user async def aset_user_phone_number(user: KhojUser, phone_number: str) -> KhojUser: if is_none_or_empty(phone_number): return None @@ -155,6 +209,7 @@ async def aset_user_phone_number(user: KhojUser, phone_number: str) -> KhojUser: return user +@arequire_valid_user async def aremove_phone_number(user: KhojUser) -> KhojUser: user.phone_number = None user.verified_phone_number = False @@ -192,6 +247,7 @@ async def aget_or_create_user_by_email(email: str) -> tuple[KhojUser, bool]: return user, is_new +@arequire_valid_user async def astart_trial_subscription(user: KhojUser) -> Subscription: subscription = await Subscription.objects.filter(user=user).afirst() if not subscription: @@ -246,6 +302,7 @@ async def create_user_by_google_token(token: dict) -> KhojUser: return user +@require_valid_user def set_user_name(user: KhojUser, first_name: str, last_name: str) -> KhojUser: user.first_name = first_name user.last_name = last_name @@ -253,6 +310,7 @@ def set_user_name(user: KhojUser, first_name: str, last_name: str) -> KhojUser: return user +@require_valid_user def get_user_name(user: KhojUser): full_name = user.get_full_name() if not is_none_or_empty(full_name): @@ -264,6 +322,7 @@ def get_user_name(user: KhojUser): return None +@require_valid_user def get_user_photo(user: KhojUser): google_profile: GoogleUser = GoogleUser.objects.filter(user=user).first() if google_profile: @@ -327,6 +386,7 @@ def get_user_subscription_state(email: str) -> str: return subscription_to_state(user_subscription) +@arequire_valid_user async def aget_user_subscription_state(user: KhojUser) -> str: """Get subscription state of user Valid state transitions: trial -> subscribed <-> unsubscribed OR expired @@ -335,6 +395,7 @@ async def aget_user_subscription_state(user: KhojUser) -> str: return await sync_to_async(subscription_to_state)(user_subscription) +@arequire_valid_user async def ais_user_subscribed(user: KhojUser) -> bool: """ Get whether the user is subscribed @@ -351,6 +412,7 @@ async def ais_user_subscribed(user: KhojUser) -> bool: return subscribed +@require_valid_user def is_user_subscribed(user: KhojUser) -> bool: """ Get whether the user is subscribed @@ -416,16 +478,13 @@ def get_all_users() -> BaseManager[KhojUser]: return KhojUser.objects.all() -def check_valid_user(user: KhojUser | None): - if not user: - raise ValueError("User not found") - - +@require_valid_user def get_user_github_config(user: KhojUser): config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first() return config +@require_valid_user def get_user_notion_config(user: KhojUser): config = NotionConfig.objects.filter(user=user).first() return config @@ -435,6 +494,7 @@ def delete_user_requests(window: timedelta = timedelta(days=1)): return UserRequests.objects.filter(created_at__lte=datetime.now(tz=timezone.utc) - window).delete() +@arequire_valid_user async def aget_user_name(user: KhojUser): full_name = user.get_full_name() if not is_none_or_empty(full_name): @@ -458,6 +518,7 @@ async def set_text_content_config(user: KhojUser, object: Type[models.Model], up ) +@arequire_valid_user async def set_user_github_config(user: KhojUser, pat_token: str, repos: list): config = await GithubConfig.objects.filter(user=user).afirst() @@ -592,8 +653,11 @@ class AgentAdapters: ) @staticmethod + @arequire_valid_user async def adelete_agent_by_slug(agent_slug: str, user: KhojUser): agent = await AgentAdapters.aget_agent_by_slug(agent_slug, user) + if agent.creator != user: + return False async for entry in Entry.objects.filter(agent=agent).aiterator(): await entry.adelete() @@ -717,6 +781,7 @@ class AgentAdapters: return await Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).afirst() @staticmethod + @arequire_valid_user async def aupdate_agent( user: KhojUser, name: str, @@ -817,10 +882,10 @@ class ConversationAdapters: ) @staticmethod + @require_valid_user def get_conversation_by_user( user: KhojUser, client_application: ClientApplication = None, conversation_id: str = None ) -> Optional[Conversation]: - check_valid_user(user) if conversation_id: conversation = ( Conversation.objects.filter(user=user, client=client_application, id=conversation_id) @@ -836,8 +901,8 @@ class ConversationAdapters: return conversation @staticmethod + @require_valid_user def get_conversation_sessions(user: KhojUser, client_application: ClientApplication = None): - check_valid_user(user) return ( Conversation.objects.filter(user=user, client=client_application) .prefetch_related("agent") @@ -845,10 +910,10 @@ class ConversationAdapters: ) @staticmethod + @arequire_valid_user async def aset_conversation_title( user: KhojUser, client_application: ClientApplication, conversation_id: str, title: str ): - check_valid_user(user) conversation = await Conversation.objects.filter( user=user, client=client_application, id=conversation_id ).afirst() @@ -863,10 +928,10 @@ class ConversationAdapters: return Conversation.objects.filter(id=conversation_id).first() @staticmethod + @arequire_valid_user async def acreate_conversation_session( user: KhojUser, client_application: ClientApplication = None, agent_slug: str = None, title: str = None ): - check_valid_user(user) if agent_slug: agent = await AgentAdapters.aget_readonly_agent_by_slug(agent_slug, user) if agent is None: @@ -880,10 +945,10 @@ class ConversationAdapters: ) @staticmethod + @require_valid_user def create_conversation_session( user: KhojUser, client_application: ClientApplication = None, agent_slug: str = None, title: str = None ): - check_valid_user(user) if agent_slug: agent = AgentAdapters.aget_readonly_agent_by_slug(agent_slug, user) if agent is None: @@ -893,6 +958,7 @@ class ConversationAdapters: return Conversation.objects.create(user=user, client=client_application, agent=agent, title=title) @staticmethod + @arequire_valid_user async def aget_conversation_by_user( user: KhojUser, client_application: ClientApplication = None, @@ -900,7 +966,6 @@ class ConversationAdapters: title: str = None, create_new: bool = False, ) -> Optional[Conversation]: - check_valid_user(user) if create_new: return await ConversationAdapters.acreate_conversation_session(user, client_application) @@ -918,17 +983,17 @@ class ConversationAdapters: ) @staticmethod + @arequire_valid_user async def adelete_conversation_by_user( user: KhojUser, client_application: ClientApplication = None, conversation_id: str = None ): - check_valid_user(user) if conversation_id: return await Conversation.objects.filter(user=user, client=client_application, id=conversation_id).adelete() return await Conversation.objects.filter(user=user, client=client_application).adelete() @staticmethod + @require_valid_user def has_any_conversation_config(user: KhojUser): - check_valid_user(user) return ChatModelOptions.objects.filter(user=user).exists() @staticmethod @@ -964,17 +1029,19 @@ class ConversationAdapters: return OpenAIProcessorConversationConfig.objects.filter().exists() @staticmethod + @arequire_valid_user async def aset_user_conversation_processor(user: KhojUser, conversation_processor_config_id: int): config = await ChatModelOptions.objects.filter(id=conversation_processor_config_id).afirst() - if not config or user is None: + if not config: return None new_config = await UserConversationConfig.objects.aupdate_or_create(user=user, defaults={"setting": config}) return new_config @staticmethod + @arequire_valid_user async def aset_user_voice_model(user: KhojUser, model_id: str): config = await VoiceModelOption.objects.filter(model_id=model_id).afirst() - if not config or user is None: + if not config: return None new_config = await UserVoiceModelConfig.objects.aupdate_or_create(user=user, defaults={"setting": config}) return new_config @@ -1156,10 +1223,10 @@ class ConversationAdapters: return enabled_scrapers @staticmethod + @require_valid_user def create_conversation_from_public_conversation( user: KhojUser, public_conversation: PublicConversation, client_app: ClientApplication ): - check_valid_user(user) scrubbed_title = public_conversation.title if public_conversation.title else public_conversation.slug if scrubbed_title: scrubbed_title = scrubbed_title.replace("-", " ") @@ -1173,6 +1240,7 @@ class ConversationAdapters: ) @staticmethod + @require_valid_user def save_conversation( user: KhojUser, conversation_log: dict, @@ -1180,7 +1248,6 @@ class ConversationAdapters: conversation_id: str = None, user_message: str = None, ): - check_valid_user(user) slug = user_message.strip()[:200] if user_message else None if conversation_id: conversation = Conversation.objects.filter(user=user, client=client_application, id=conversation_id).first() @@ -1223,8 +1290,8 @@ class ConversationAdapters: return await SpeechToTextModelOptions.objects.filter().afirst() @staticmethod + @arequire_valid_user async def aget_conversation_starters(user: KhojUser, max_results=3): - check_valid_user(user) all_questions = [] if await ReflectiveQuestion.objects.filter(user=user).aexists(): all_questions = await sync_to_async(ReflectiveQuestion.objects.filter(user=user).values_list)( @@ -1353,8 +1420,8 @@ class ConversationAdapters: return conversation.file_filters @staticmethod + @require_valid_user def delete_message_by_turn_id(user: KhojUser, conversation_id: str, turn_id: str): - check_valid_user(user) conversation = ConversationAdapters.get_conversation_by_user(user, conversation_id=conversation_id) if not conversation or not conversation.conversation_log or not conversation.conversation_log.get("chat"): return False @@ -1372,28 +1439,28 @@ class FileObjectAdapters: file_object.save() @staticmethod + @require_valid_user def create_file_object(user: KhojUser, file_name: str, raw_text: str): - check_valid_user(user) return FileObject.objects.create(user=user, file_name=file_name, raw_text=raw_text) @staticmethod + @require_valid_user def get_file_object_by_name(user: KhojUser, file_name: str): - check_valid_user(user) return FileObject.objects.filter(user=user, file_name=file_name).first() @staticmethod + @require_valid_user def get_all_file_objects(user: KhojUser): - check_valid_user(user) return FileObject.objects.filter(user=user).all() @staticmethod + @require_valid_user def delete_file_object_by_name(user: KhojUser, file_name: str): - check_valid_user(user) return FileObject.objects.filter(user=user, file_name=file_name).delete() @staticmethod + @require_valid_user def delete_all_file_objects(user: KhojUser): - check_valid_user(user) return FileObject.objects.filter(user=user).delete() @staticmethod @@ -1402,33 +1469,33 @@ class FileObjectAdapters: await file_object.asave() @staticmethod + @arequire_valid_user async def acreate_file_object(user: KhojUser, file_name: str, raw_text: str): - check_valid_user(user) return await FileObject.objects.acreate(user=user, file_name=file_name, raw_text=raw_text) @staticmethod + @arequire_valid_user async def aget_file_objects_by_name(user: KhojUser, file_name: str, agent: Agent = None): - check_valid_user(user) return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name=file_name, agent=agent)) @staticmethod + @arequire_valid_user async def aget_file_objects_by_names(user: KhojUser, file_names: List[str]): - check_valid_user(user) return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name__in=file_names)) @staticmethod + @arequire_valid_user async def aget_all_file_objects(user: KhojUser): - check_valid_user(user) return await sync_to_async(list)(FileObject.objects.filter(user=user)) @staticmethod + @arequire_valid_user async def adelete_file_object_by_name(user: KhojUser, file_name: str): - check_valid_user(user) return await FileObject.objects.filter(user=user, file_name=file_name).adelete() @staticmethod + @arequire_valid_user async def adelete_all_file_objects(user: KhojUser): - check_valid_user(user) return await FileObject.objects.filter(user=user).adelete() @@ -1438,19 +1505,19 @@ class EntryAdapters: date_filter = DateFilter() @staticmethod + @require_valid_user def does_entry_exist(user: KhojUser, hashed_value: str) -> bool: - check_valid_user(user) return Entry.objects.filter(user=user, hashed_value=hashed_value).exists() @staticmethod + @require_valid_user def delete_entry_by_file(user: KhojUser, file_path: str): - check_valid_user(user) deleted_count, _ = Entry.objects.filter(user=user, file_path=file_path).delete() return deleted_count @staticmethod + @require_valid_user def get_filtered_entries(user: KhojUser, file_type: str = None, file_source: str = None): - check_valid_user(user) queryset = Entry.objects.filter(user=user) if file_type is not None: @@ -1462,8 +1529,8 @@ class EntryAdapters: return queryset @staticmethod + @require_valid_user def delete_all_entries(user: KhojUser, file_type: str = None, file_source: str = None, batch_size=1000): - check_valid_user(user) deleted_count = 0 queryset = EntryAdapters.get_filtered_entries(user, file_type, file_source) while queryset.exists(): @@ -1474,8 +1541,8 @@ class EntryAdapters: return deleted_count @staticmethod + @arequire_valid_user async def adelete_all_entries(user: KhojUser, file_type: str = None, file_source: str = None, batch_size=1000): - check_valid_user(user) deleted_count = 0 queryset = EntryAdapters.get_filtered_entries(user, file_type, file_source) while await queryset.aexists(): @@ -1486,13 +1553,13 @@ class EntryAdapters: return deleted_count @staticmethod + @require_valid_user def get_existing_entry_hashes_by_file(user: KhojUser, file_path: str): - check_valid_user(user) return Entry.objects.filter(user=user, file_path=file_path).values_list("hashed_value", flat=True) @staticmethod + @require_valid_user def delete_entry_by_hash(user: KhojUser, hashed_values: List[str]): - check_valid_user(user) Entry.objects.filter(user=user, hashed_value__in=hashed_values).delete() @staticmethod @@ -1503,8 +1570,8 @@ class EntryAdapters: ) @staticmethod + @require_valid_user def user_has_entries(user: KhojUser): - check_valid_user(user) return Entry.objects.filter(user=user).exists() @staticmethod @@ -1512,8 +1579,8 @@ class EntryAdapters: return Entry.objects.filter(agent=agent).exists() @staticmethod + @arequire_valid_user async def auser_has_entries(user: KhojUser): - check_valid_user(user) return await Entry.objects.filter(user=user).aexists() @staticmethod @@ -1523,13 +1590,13 @@ class EntryAdapters: return await Entry.objects.filter(agent=agent).aexists() @staticmethod + @arequire_valid_user async def adelete_entry_by_file(user: KhojUser, file_path: str): - check_valid_user(user) return await Entry.objects.filter(user=user, file_path=file_path).adelete() @staticmethod + @arequire_valid_user async def adelete_entries_by_filenames(user: KhojUser, filenames: List[str], batch_size=1000): - check_valid_user(user) deleted_count = 0 for i in range(0, len(filenames), batch_size): batch = filenames[i : i + batch_size] @@ -1547,8 +1614,8 @@ class EntryAdapters: ) @staticmethod + @require_valid_user def get_all_filenames_by_source(user: KhojUser, file_source: str): - check_valid_user(user) return ( Entry.objects.filter(user=user, file_source=file_source) .distinct("file_path") @@ -1556,8 +1623,8 @@ class EntryAdapters: ) @staticmethod + @require_valid_user def get_size_of_indexed_data_in_mb(user: KhojUser): - check_valid_user(user) entries = Entry.objects.filter(user=user).iterator() total_size = sum(sys.getsizeof(entry.compiled) for entry in entries) return total_size / 1024 / 1024 @@ -1654,13 +1721,13 @@ class EntryAdapters: return relevant_entries[:max_results] @staticmethod + @require_valid_user def get_unique_file_types(user: KhojUser): - check_valid_user(user) return Entry.objects.filter(user=user).values_list("file_type", flat=True).distinct() @staticmethod + @require_valid_user def get_unique_file_sources(user: KhojUser): - check_valid_user(user) return Entry.objects.filter(user=user).values_list("file_source", flat=True).distinct().all() From 536fe994be7ef5b6f467b6647481503102b4cb2c Mon Sep 17 00:00:00 2001 From: Debanjum Date: Sun, 10 Nov 2024 17:51:39 -0800 Subject: [PATCH 3/3] Remove unused db adapter methods, like for fact checker data store --- src/khoj/database/adapters/__init__.py | 27 -------------------------- 1 file changed, 27 deletions(-) diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 4bb5c6c1..8538e217 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -24,7 +24,6 @@ import cron_descriptor from apscheduler.job import Job from asgiref.sync import sync_to_async from django.contrib.sessions.backends.db import SessionStore -from django.db import models from django.db.models import Prefetch, Q from django.db.models.manager import BaseManager from django.db.utils import IntegrityError @@ -38,7 +37,6 @@ from khoj.database.models import ( ChatModelOptions, ClientApplication, Conversation, - DataStore, Entry, FileObject, GithubConfig, @@ -506,18 +504,6 @@ async def aget_user_name(user: KhojUser): return None -async def set_text_content_config(user: KhojUser, object: Type[models.Model], updated_config): - deduped_files = list(set(updated_config.input_files)) if updated_config.input_files else None - deduped_filters = list(set(updated_config.input_filter)) if updated_config.input_filter else None - await object.objects.filter(user=user).adelete() - await object.objects.acreate( - input_files=deduped_files, - input_filter=deduped_filters, - index_heading_entries=updated_config.index_heading_entries, - user=user, - ) - - @arequire_valid_user async def set_user_github_config(user: KhojUser, pat_token: str, repos: list): config = await GithubConfig.objects.filter(user=user).afirst() @@ -857,19 +843,6 @@ class PublicConversationAdapters: return f"/share/chat/{public_conversation.slug}/" -class DataStoreAdapters: - @staticmethod - async def astore_data(data: dict, key: str, user: KhojUser, private: bool = True): - if await DataStore.objects.filter(key=key).aexists(): - return key - await DataStore.objects.acreate(value=data, key=key, owner=user, private=private) - return key - - @staticmethod - async def aretrieve_public_data(key: str): - return await DataStore.objects.filter(key=key, private=False).afirst() - - class ConversationAdapters: @staticmethod def make_public_conversation_copy(conversation: Conversation):