diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 2ea9e9af..3a21a919 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -1012,27 +1012,35 @@ class EntryAdapters: return deleted_count @staticmethod - def delete_all_entries_by_type(user: KhojUser, file_type: str = None): - if file_type is None: - deleted_count, _ = Entry.objects.filter(user=user).delete() - else: - deleted_count, _ = Entry.objects.filter(user=user, file_type=file_type).delete() + def get_entries_by_batch(user: KhojUser, batch_size: int, file_type: str = None, file_source: str = None): + queryset = Entry.objects.filter(user=user) + + if file_type is not None: + queryset = queryset.filter(file_type=file_type) + + if file_source is not None: + queryset = queryset.filter(file_source=file_source) + + while queryset.exists(): + batch_ids = list(queryset.values_list("id", flat=True)[:batch_size]) + yield Entry.objects.filter(id__in=batch_ids) + + @staticmethod + def delete_all_entries(user: KhojUser, file_type: str = None, file_source: str = None, batch_size=1000): + deleted_count = 0 + for batch in EntryAdapters.get_entries_by_batch(user, batch_size, file_type, file_source): + count, _ = batch.delete() + deleted_count += count return deleted_count @staticmethod - def delete_all_entries(user: KhojUser, file_source: str = None): - if file_source is None: - deleted_count, _ = Entry.objects.filter(user=user).delete() - else: - deleted_count, _ = Entry.objects.filter(user=user, file_source=file_source).delete() + async def adelete_all_entries(user: KhojUser, file_type: str = None, file_source: str = None, batch_size=1000): + deleted_count = 0 + async for batch in EntryAdapters.get_entries_by_batch(user, batch_size, file_type, file_source): + count, _ = await batch.adelete() + deleted_count += count return deleted_count - @staticmethod - async def adelete_all_entries(user: KhojUser, file_source: str = None): - if file_source is None: - return await Entry.objects.filter(user=user).adelete() - return await Entry.objects.filter(user=user, file_source=file_source).adelete() - @staticmethod def get_existing_entry_hashes_by_file(user: KhojUser, file_path: str): return Entry.objects.filter(user=user, file_path=file_path).values_list("hashed_value", flat=True) diff --git a/src/khoj/interface/web/content_source_computer_input.html b/src/khoj/interface/web/content_source_computer_input.html index 77816f35..77ce2287 100644 --- a/src/khoj/interface/web/content_source_computer_input.html +++ b/src/khoj/interface/web/content_source_computer_input.html @@ -12,7 +12,7 @@