Convert required user param check into decorator. Use with more adapters

This commit is contained in:
Debanjum 2024-11-10 17:49:55 -08:00
parent ff5c10c221
commit 10bca6fa8f

View file

@ -8,7 +8,17 @@ import secrets
import sys import sys
from datetime import date, datetime, timedelta, timezone from datetime import date, datetime, timedelta, timezone
from enum import Enum from enum import Enum
from typing import Callable, Iterable, List, Optional, Type from functools import wraps
from typing import (
Any,
Callable,
Coroutine,
Iterable,
List,
Optional,
ParamSpec,
TypeVar,
)
import cron_descriptor import cron_descriptor
from apscheduler.job import Job from apscheduler.job import Job
@ -80,6 +90,45 @@ class SubscriptionState(Enum):
INVALID = "invalid" INVALID = "invalid"
P = ParamSpec("P")
T = TypeVar("T")
def require_valid_user(func: Callable[P, T]) -> Callable[P, T]:
@wraps(func)
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
# Extract user from args/kwargs
user = next((arg for arg in args if isinstance(arg, KhojUser)), None)
if not user:
user = next((val for val in kwargs.values() if isinstance(val, KhojUser)), None)
# Throw error if user is not found
if not user:
raise ValueError("Khoj user argument required but not provided.")
return func(*args, **kwargs)
return sync_wrapper
def arequire_valid_user(func: Callable[P, Coroutine[Any, Any, T]]) -> Callable[P, Coroutine[Any, Any, T]]:
@wraps(func)
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
# Extract user from args/kwargs
user = next((arg for arg in args if isinstance(arg, KhojUser)), None)
if not user:
user = next((v for v in kwargs.values() if isinstance(v, KhojUser)), None)
# Throw error if user is not found
if not user:
raise ValueError("Khoj user argument required but not provided.")
return await func(*args, **kwargs)
return async_wrapper
@arequire_valid_user
async def set_notion_config(token: str, user: KhojUser): async def set_notion_config(token: str, user: KhojUser):
notion_config = await NotionConfig.objects.filter(user=user).afirst() notion_config = await NotionConfig.objects.filter(user=user).afirst()
if not notion_config: if not notion_config:
@ -90,6 +139,7 @@ async def set_notion_config(token: str, user: KhojUser):
return notion_config return notion_config
@require_valid_user
def create_khoj_token(user: KhojUser, name=None): def create_khoj_token(user: KhojUser, name=None):
"Create Khoj API key for user" "Create Khoj API key for user"
token = f"kk-{secrets.token_urlsafe(32)}" token = f"kk-{secrets.token_urlsafe(32)}"
@ -97,6 +147,7 @@ def create_khoj_token(user: KhojUser, name=None):
return KhojApiUser.objects.create(token=token, user=user, name=name) return KhojApiUser.objects.create(token=token, user=user, name=name)
@arequire_valid_user
async def acreate_khoj_token(user: KhojUser, name=None): async def acreate_khoj_token(user: KhojUser, name=None):
"Create Khoj API key for user" "Create Khoj API key for user"
token = f"kk-{secrets.token_urlsafe(32)}" token = f"kk-{secrets.token_urlsafe(32)}"
@ -104,11 +155,13 @@ async def acreate_khoj_token(user: KhojUser, name=None):
return await KhojApiUser.objects.acreate(token=token, user=user, name=name) return await KhojApiUser.objects.acreate(token=token, user=user, name=name)
@require_valid_user
def get_khoj_tokens(user: KhojUser): def get_khoj_tokens(user: KhojUser):
"Get all Khoj API keys for user" "Get all Khoj API keys for user"
return list(KhojApiUser.objects.filter(user=user)) return list(KhojApiUser.objects.filter(user=user))
@arequire_valid_user
async def delete_khoj_token(user: KhojUser, token: str): async def delete_khoj_token(user: KhojUser, token: str):
"Delete Khoj API Key for user" "Delete Khoj API Key for user"
await KhojApiUser.objects.filter(token=token, user=user).adelete() await KhojApiUser.objects.filter(token=token, user=user).adelete()
@ -132,6 +185,7 @@ async def aget_or_create_user_by_phone_number(phone_number: str) -> tuple[KhojUs
return user, is_new return user, is_new
@arequire_valid_user
async def aset_user_phone_number(user: KhojUser, phone_number: str) -> KhojUser: async def aset_user_phone_number(user: KhojUser, phone_number: str) -> KhojUser:
if is_none_or_empty(phone_number): if is_none_or_empty(phone_number):
return None return None
@ -155,6 +209,7 @@ async def aset_user_phone_number(user: KhojUser, phone_number: str) -> KhojUser:
return user return user
@arequire_valid_user
async def aremove_phone_number(user: KhojUser) -> KhojUser: async def aremove_phone_number(user: KhojUser) -> KhojUser:
user.phone_number = None user.phone_number = None
user.verified_phone_number = False user.verified_phone_number = False
@ -192,6 +247,7 @@ async def aget_or_create_user_by_email(email: str) -> tuple[KhojUser, bool]:
return user, is_new return user, is_new
@arequire_valid_user
async def astart_trial_subscription(user: KhojUser) -> Subscription: async def astart_trial_subscription(user: KhojUser) -> Subscription:
subscription = await Subscription.objects.filter(user=user).afirst() subscription = await Subscription.objects.filter(user=user).afirst()
if not subscription: if not subscription:
@ -246,6 +302,7 @@ async def create_user_by_google_token(token: dict) -> KhojUser:
return user return user
@require_valid_user
def set_user_name(user: KhojUser, first_name: str, last_name: str) -> KhojUser: def set_user_name(user: KhojUser, first_name: str, last_name: str) -> KhojUser:
user.first_name = first_name user.first_name = first_name
user.last_name = last_name user.last_name = last_name
@ -253,6 +310,7 @@ def set_user_name(user: KhojUser, first_name: str, last_name: str) -> KhojUser:
return user return user
@require_valid_user
def get_user_name(user: KhojUser): def get_user_name(user: KhojUser):
full_name = user.get_full_name() full_name = user.get_full_name()
if not is_none_or_empty(full_name): if not is_none_or_empty(full_name):
@ -264,6 +322,7 @@ def get_user_name(user: KhojUser):
return None return None
@require_valid_user
def get_user_photo(user: KhojUser): def get_user_photo(user: KhojUser):
google_profile: GoogleUser = GoogleUser.objects.filter(user=user).first() google_profile: GoogleUser = GoogleUser.objects.filter(user=user).first()
if google_profile: if google_profile:
@ -327,6 +386,7 @@ def get_user_subscription_state(email: str) -> str:
return subscription_to_state(user_subscription) return subscription_to_state(user_subscription)
@arequire_valid_user
async def aget_user_subscription_state(user: KhojUser) -> str: async def aget_user_subscription_state(user: KhojUser) -> str:
"""Get subscription state of user """Get subscription state of user
Valid state transitions: trial -> subscribed <-> unsubscribed OR expired Valid state transitions: trial -> subscribed <-> unsubscribed OR expired
@ -335,6 +395,7 @@ async def aget_user_subscription_state(user: KhojUser) -> str:
return await sync_to_async(subscription_to_state)(user_subscription) return await sync_to_async(subscription_to_state)(user_subscription)
@arequire_valid_user
async def ais_user_subscribed(user: KhojUser) -> bool: async def ais_user_subscribed(user: KhojUser) -> bool:
""" """
Get whether the user is subscribed Get whether the user is subscribed
@ -351,6 +412,7 @@ async def ais_user_subscribed(user: KhojUser) -> bool:
return subscribed return subscribed
@require_valid_user
def is_user_subscribed(user: KhojUser) -> bool: def is_user_subscribed(user: KhojUser) -> bool:
""" """
Get whether the user is subscribed Get whether the user is subscribed
@ -416,16 +478,13 @@ def get_all_users() -> BaseManager[KhojUser]:
return KhojUser.objects.all() return KhojUser.objects.all()
def check_valid_user(user: KhojUser | None): @require_valid_user
if not user:
raise ValueError("User not found")
def get_user_github_config(user: KhojUser): def get_user_github_config(user: KhojUser):
config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first() config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first()
return config return config
@require_valid_user
def get_user_notion_config(user: KhojUser): def get_user_notion_config(user: KhojUser):
config = NotionConfig.objects.filter(user=user).first() config = NotionConfig.objects.filter(user=user).first()
return config return config
@ -435,6 +494,7 @@ def delete_user_requests(window: timedelta = timedelta(days=1)):
return UserRequests.objects.filter(created_at__lte=datetime.now(tz=timezone.utc) - window).delete() return UserRequests.objects.filter(created_at__lte=datetime.now(tz=timezone.utc) - window).delete()
@arequire_valid_user
async def aget_user_name(user: KhojUser): async def aget_user_name(user: KhojUser):
full_name = user.get_full_name() full_name = user.get_full_name()
if not is_none_or_empty(full_name): if not is_none_or_empty(full_name):
@ -458,6 +518,7 @@ async def set_text_content_config(user: KhojUser, object: Type[models.Model], up
) )
@arequire_valid_user
async def set_user_github_config(user: KhojUser, pat_token: str, repos: list): async def set_user_github_config(user: KhojUser, pat_token: str, repos: list):
config = await GithubConfig.objects.filter(user=user).afirst() config = await GithubConfig.objects.filter(user=user).afirst()
@ -592,8 +653,11 @@ class AgentAdapters:
) )
@staticmethod @staticmethod
@arequire_valid_user
async def adelete_agent_by_slug(agent_slug: str, user: KhojUser): async def adelete_agent_by_slug(agent_slug: str, user: KhojUser):
agent = await AgentAdapters.aget_agent_by_slug(agent_slug, user) agent = await AgentAdapters.aget_agent_by_slug(agent_slug, user)
if agent.creator != user:
return False
async for entry in Entry.objects.filter(agent=agent).aiterator(): async for entry in Entry.objects.filter(agent=agent).aiterator():
await entry.adelete() await entry.adelete()
@ -717,6 +781,7 @@ class AgentAdapters:
return await Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).afirst() return await Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).afirst()
@staticmethod @staticmethod
@arequire_valid_user
async def aupdate_agent( async def aupdate_agent(
user: KhojUser, user: KhojUser,
name: str, name: str,
@ -817,10 +882,10 @@ class ConversationAdapters:
) )
@staticmethod @staticmethod
@require_valid_user
def get_conversation_by_user( def get_conversation_by_user(
user: KhojUser, client_application: ClientApplication = None, conversation_id: str = None user: KhojUser, client_application: ClientApplication = None, conversation_id: str = None
) -> Optional[Conversation]: ) -> Optional[Conversation]:
check_valid_user(user)
if conversation_id: if conversation_id:
conversation = ( conversation = (
Conversation.objects.filter(user=user, client=client_application, id=conversation_id) Conversation.objects.filter(user=user, client=client_application, id=conversation_id)
@ -836,8 +901,8 @@ class ConversationAdapters:
return conversation return conversation
@staticmethod @staticmethod
@require_valid_user
def get_conversation_sessions(user: KhojUser, client_application: ClientApplication = None): def get_conversation_sessions(user: KhojUser, client_application: ClientApplication = None):
check_valid_user(user)
return ( return (
Conversation.objects.filter(user=user, client=client_application) Conversation.objects.filter(user=user, client=client_application)
.prefetch_related("agent") .prefetch_related("agent")
@ -845,10 +910,10 @@ class ConversationAdapters:
) )
@staticmethod @staticmethod
@arequire_valid_user
async def aset_conversation_title( async def aset_conversation_title(
user: KhojUser, client_application: ClientApplication, conversation_id: str, title: str user: KhojUser, client_application: ClientApplication, conversation_id: str, title: str
): ):
check_valid_user(user)
conversation = await Conversation.objects.filter( conversation = await Conversation.objects.filter(
user=user, client=client_application, id=conversation_id user=user, client=client_application, id=conversation_id
).afirst() ).afirst()
@ -863,10 +928,10 @@ class ConversationAdapters:
return Conversation.objects.filter(id=conversation_id).first() return Conversation.objects.filter(id=conversation_id).first()
@staticmethod @staticmethod
@arequire_valid_user
async def acreate_conversation_session( async def acreate_conversation_session(
user: KhojUser, client_application: ClientApplication = None, agent_slug: str = None, title: str = None user: KhojUser, client_application: ClientApplication = None, agent_slug: str = None, title: str = None
): ):
check_valid_user(user)
if agent_slug: if agent_slug:
agent = await AgentAdapters.aget_readonly_agent_by_slug(agent_slug, user) agent = await AgentAdapters.aget_readonly_agent_by_slug(agent_slug, user)
if agent is None: if agent is None:
@ -880,10 +945,10 @@ class ConversationAdapters:
) )
@staticmethod @staticmethod
@require_valid_user
def create_conversation_session( def create_conversation_session(
user: KhojUser, client_application: ClientApplication = None, agent_slug: str = None, title: str = None user: KhojUser, client_application: ClientApplication = None, agent_slug: str = None, title: str = None
): ):
check_valid_user(user)
if agent_slug: if agent_slug:
agent = AgentAdapters.aget_readonly_agent_by_slug(agent_slug, user) agent = AgentAdapters.aget_readonly_agent_by_slug(agent_slug, user)
if agent is None: if agent is None:
@ -893,6 +958,7 @@ class ConversationAdapters:
return Conversation.objects.create(user=user, client=client_application, agent=agent, title=title) return Conversation.objects.create(user=user, client=client_application, agent=agent, title=title)
@staticmethod @staticmethod
@arequire_valid_user
async def aget_conversation_by_user( async def aget_conversation_by_user(
user: KhojUser, user: KhojUser,
client_application: ClientApplication = None, client_application: ClientApplication = None,
@ -900,7 +966,6 @@ class ConversationAdapters:
title: str = None, title: str = None,
create_new: bool = False, create_new: bool = False,
) -> Optional[Conversation]: ) -> Optional[Conversation]:
check_valid_user(user)
if create_new: if create_new:
return await ConversationAdapters.acreate_conversation_session(user, client_application) return await ConversationAdapters.acreate_conversation_session(user, client_application)
@ -918,17 +983,17 @@ class ConversationAdapters:
) )
@staticmethod @staticmethod
@arequire_valid_user
async def adelete_conversation_by_user( async def adelete_conversation_by_user(
user: KhojUser, client_application: ClientApplication = None, conversation_id: str = None user: KhojUser, client_application: ClientApplication = None, conversation_id: str = None
): ):
check_valid_user(user)
if conversation_id: if conversation_id:
return await Conversation.objects.filter(user=user, client=client_application, id=conversation_id).adelete() return await Conversation.objects.filter(user=user, client=client_application, id=conversation_id).adelete()
return await Conversation.objects.filter(user=user, client=client_application).adelete() return await Conversation.objects.filter(user=user, client=client_application).adelete()
@staticmethod @staticmethod
@require_valid_user
def has_any_conversation_config(user: KhojUser): def has_any_conversation_config(user: KhojUser):
check_valid_user(user)
return ChatModelOptions.objects.filter(user=user).exists() return ChatModelOptions.objects.filter(user=user).exists()
@staticmethod @staticmethod
@ -964,17 +1029,19 @@ class ConversationAdapters:
return OpenAIProcessorConversationConfig.objects.filter().exists() return OpenAIProcessorConversationConfig.objects.filter().exists()
@staticmethod @staticmethod
@arequire_valid_user
async def aset_user_conversation_processor(user: KhojUser, conversation_processor_config_id: int): async def aset_user_conversation_processor(user: KhojUser, conversation_processor_config_id: int):
config = await ChatModelOptions.objects.filter(id=conversation_processor_config_id).afirst() config = await ChatModelOptions.objects.filter(id=conversation_processor_config_id).afirst()
if not config or user is None: if not config:
return None return None
new_config = await UserConversationConfig.objects.aupdate_or_create(user=user, defaults={"setting": config}) new_config = await UserConversationConfig.objects.aupdate_or_create(user=user, defaults={"setting": config})
return new_config return new_config
@staticmethod @staticmethod
@arequire_valid_user
async def aset_user_voice_model(user: KhojUser, model_id: str): async def aset_user_voice_model(user: KhojUser, model_id: str):
config = await VoiceModelOption.objects.filter(model_id=model_id).afirst() config = await VoiceModelOption.objects.filter(model_id=model_id).afirst()
if not config or user is None: if not config:
return None return None
new_config = await UserVoiceModelConfig.objects.aupdate_or_create(user=user, defaults={"setting": config}) new_config = await UserVoiceModelConfig.objects.aupdate_or_create(user=user, defaults={"setting": config})
return new_config return new_config
@ -1156,10 +1223,10 @@ class ConversationAdapters:
return enabled_scrapers return enabled_scrapers
@staticmethod @staticmethod
@require_valid_user
def create_conversation_from_public_conversation( def create_conversation_from_public_conversation(
user: KhojUser, public_conversation: PublicConversation, client_app: ClientApplication user: KhojUser, public_conversation: PublicConversation, client_app: ClientApplication
): ):
check_valid_user(user)
scrubbed_title = public_conversation.title if public_conversation.title else public_conversation.slug scrubbed_title = public_conversation.title if public_conversation.title else public_conversation.slug
if scrubbed_title: if scrubbed_title:
scrubbed_title = scrubbed_title.replace("-", " ") scrubbed_title = scrubbed_title.replace("-", " ")
@ -1173,6 +1240,7 @@ class ConversationAdapters:
) )
@staticmethod @staticmethod
@require_valid_user
def save_conversation( def save_conversation(
user: KhojUser, user: KhojUser,
conversation_log: dict, conversation_log: dict,
@ -1180,7 +1248,6 @@ class ConversationAdapters:
conversation_id: str = None, conversation_id: str = None,
user_message: str = None, user_message: str = None,
): ):
check_valid_user(user)
slug = user_message.strip()[:200] if user_message else None slug = user_message.strip()[:200] if user_message else None
if conversation_id: if conversation_id:
conversation = Conversation.objects.filter(user=user, client=client_application, id=conversation_id).first() conversation = Conversation.objects.filter(user=user, client=client_application, id=conversation_id).first()
@ -1223,8 +1290,8 @@ class ConversationAdapters:
return await SpeechToTextModelOptions.objects.filter().afirst() return await SpeechToTextModelOptions.objects.filter().afirst()
@staticmethod @staticmethod
@arequire_valid_user
async def aget_conversation_starters(user: KhojUser, max_results=3): async def aget_conversation_starters(user: KhojUser, max_results=3):
check_valid_user(user)
all_questions = [] all_questions = []
if await ReflectiveQuestion.objects.filter(user=user).aexists(): if await ReflectiveQuestion.objects.filter(user=user).aexists():
all_questions = await sync_to_async(ReflectiveQuestion.objects.filter(user=user).values_list)( all_questions = await sync_to_async(ReflectiveQuestion.objects.filter(user=user).values_list)(
@ -1353,8 +1420,8 @@ class ConversationAdapters:
return conversation.file_filters return conversation.file_filters
@staticmethod @staticmethod
@require_valid_user
def delete_message_by_turn_id(user: KhojUser, conversation_id: str, turn_id: str): def delete_message_by_turn_id(user: KhojUser, conversation_id: str, turn_id: str):
check_valid_user(user)
conversation = ConversationAdapters.get_conversation_by_user(user, conversation_id=conversation_id) conversation = ConversationAdapters.get_conversation_by_user(user, conversation_id=conversation_id)
if not conversation or not conversation.conversation_log or not conversation.conversation_log.get("chat"): if not conversation or not conversation.conversation_log or not conversation.conversation_log.get("chat"):
return False return False
@ -1372,28 +1439,28 @@ class FileObjectAdapters:
file_object.save() file_object.save()
@staticmethod @staticmethod
@require_valid_user
def create_file_object(user: KhojUser, file_name: str, raw_text: str): def create_file_object(user: KhojUser, file_name: str, raw_text: str):
check_valid_user(user)
return FileObject.objects.create(user=user, file_name=file_name, raw_text=raw_text) return FileObject.objects.create(user=user, file_name=file_name, raw_text=raw_text)
@staticmethod @staticmethod
@require_valid_user
def get_file_object_by_name(user: KhojUser, file_name: str): def get_file_object_by_name(user: KhojUser, file_name: str):
check_valid_user(user)
return FileObject.objects.filter(user=user, file_name=file_name).first() return FileObject.objects.filter(user=user, file_name=file_name).first()
@staticmethod @staticmethod
@require_valid_user
def get_all_file_objects(user: KhojUser): def get_all_file_objects(user: KhojUser):
check_valid_user(user)
return FileObject.objects.filter(user=user).all() return FileObject.objects.filter(user=user).all()
@staticmethod @staticmethod
@require_valid_user
def delete_file_object_by_name(user: KhojUser, file_name: str): def delete_file_object_by_name(user: KhojUser, file_name: str):
check_valid_user(user)
return FileObject.objects.filter(user=user, file_name=file_name).delete() return FileObject.objects.filter(user=user, file_name=file_name).delete()
@staticmethod @staticmethod
@require_valid_user
def delete_all_file_objects(user: KhojUser): def delete_all_file_objects(user: KhojUser):
check_valid_user(user)
return FileObject.objects.filter(user=user).delete() return FileObject.objects.filter(user=user).delete()
@staticmethod @staticmethod
@ -1402,33 +1469,33 @@ class FileObjectAdapters:
await file_object.asave() await file_object.asave()
@staticmethod @staticmethod
@arequire_valid_user
async def acreate_file_object(user: KhojUser, file_name: str, raw_text: str): async def acreate_file_object(user: KhojUser, file_name: str, raw_text: str):
check_valid_user(user)
return await FileObject.objects.acreate(user=user, file_name=file_name, raw_text=raw_text) return await FileObject.objects.acreate(user=user, file_name=file_name, raw_text=raw_text)
@staticmethod @staticmethod
@arequire_valid_user
async def aget_file_objects_by_name(user: KhojUser, file_name: str, agent: Agent = None): async def aget_file_objects_by_name(user: KhojUser, file_name: str, agent: Agent = None):
check_valid_user(user)
return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name=file_name, agent=agent)) return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name=file_name, agent=agent))
@staticmethod @staticmethod
@arequire_valid_user
async def aget_file_objects_by_names(user: KhojUser, file_names: List[str]): async def aget_file_objects_by_names(user: KhojUser, file_names: List[str]):
check_valid_user(user)
return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name__in=file_names)) return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name__in=file_names))
@staticmethod @staticmethod
@arequire_valid_user
async def aget_all_file_objects(user: KhojUser): async def aget_all_file_objects(user: KhojUser):
check_valid_user(user)
return await sync_to_async(list)(FileObject.objects.filter(user=user)) return await sync_to_async(list)(FileObject.objects.filter(user=user))
@staticmethod @staticmethod
@arequire_valid_user
async def adelete_file_object_by_name(user: KhojUser, file_name: str): async def adelete_file_object_by_name(user: KhojUser, file_name: str):
check_valid_user(user)
return await FileObject.objects.filter(user=user, file_name=file_name).adelete() return await FileObject.objects.filter(user=user, file_name=file_name).adelete()
@staticmethod @staticmethod
@arequire_valid_user
async def adelete_all_file_objects(user: KhojUser): async def adelete_all_file_objects(user: KhojUser):
check_valid_user(user)
return await FileObject.objects.filter(user=user).adelete() return await FileObject.objects.filter(user=user).adelete()
@ -1438,19 +1505,19 @@ class EntryAdapters:
date_filter = DateFilter() date_filter = DateFilter()
@staticmethod @staticmethod
@require_valid_user
def does_entry_exist(user: KhojUser, hashed_value: str) -> bool: def does_entry_exist(user: KhojUser, hashed_value: str) -> bool:
check_valid_user(user)
return Entry.objects.filter(user=user, hashed_value=hashed_value).exists() return Entry.objects.filter(user=user, hashed_value=hashed_value).exists()
@staticmethod @staticmethod
@require_valid_user
def delete_entry_by_file(user: KhojUser, file_path: str): def delete_entry_by_file(user: KhojUser, file_path: str):
check_valid_user(user)
deleted_count, _ = Entry.objects.filter(user=user, file_path=file_path).delete() deleted_count, _ = Entry.objects.filter(user=user, file_path=file_path).delete()
return deleted_count return deleted_count
@staticmethod @staticmethod
@require_valid_user
def get_filtered_entries(user: KhojUser, file_type: str = None, file_source: str = None): def get_filtered_entries(user: KhojUser, file_type: str = None, file_source: str = None):
check_valid_user(user)
queryset = Entry.objects.filter(user=user) queryset = Entry.objects.filter(user=user)
if file_type is not None: if file_type is not None:
@ -1462,8 +1529,8 @@ class EntryAdapters:
return queryset return queryset
@staticmethod @staticmethod
@require_valid_user
def delete_all_entries(user: KhojUser, file_type: str = None, file_source: str = None, batch_size=1000): def delete_all_entries(user: KhojUser, file_type: str = None, file_source: str = None, batch_size=1000):
check_valid_user(user)
deleted_count = 0 deleted_count = 0
queryset = EntryAdapters.get_filtered_entries(user, file_type, file_source) queryset = EntryAdapters.get_filtered_entries(user, file_type, file_source)
while queryset.exists(): while queryset.exists():
@ -1474,8 +1541,8 @@ class EntryAdapters:
return deleted_count return deleted_count
@staticmethod @staticmethod
@arequire_valid_user
async def adelete_all_entries(user: KhojUser, file_type: str = None, file_source: str = None, batch_size=1000): async def adelete_all_entries(user: KhojUser, file_type: str = None, file_source: str = None, batch_size=1000):
check_valid_user(user)
deleted_count = 0 deleted_count = 0
queryset = EntryAdapters.get_filtered_entries(user, file_type, file_source) queryset = EntryAdapters.get_filtered_entries(user, file_type, file_source)
while await queryset.aexists(): while await queryset.aexists():
@ -1486,13 +1553,13 @@ class EntryAdapters:
return deleted_count return deleted_count
@staticmethod @staticmethod
@require_valid_user
def get_existing_entry_hashes_by_file(user: KhojUser, file_path: str): def get_existing_entry_hashes_by_file(user: KhojUser, file_path: str):
check_valid_user(user)
return Entry.objects.filter(user=user, file_path=file_path).values_list("hashed_value", flat=True) return Entry.objects.filter(user=user, file_path=file_path).values_list("hashed_value", flat=True)
@staticmethod @staticmethod
@require_valid_user
def delete_entry_by_hash(user: KhojUser, hashed_values: List[str]): def delete_entry_by_hash(user: KhojUser, hashed_values: List[str]):
check_valid_user(user)
Entry.objects.filter(user=user, hashed_value__in=hashed_values).delete() Entry.objects.filter(user=user, hashed_value__in=hashed_values).delete()
@staticmethod @staticmethod
@ -1503,8 +1570,8 @@ class EntryAdapters:
) )
@staticmethod @staticmethod
@require_valid_user
def user_has_entries(user: KhojUser): def user_has_entries(user: KhojUser):
check_valid_user(user)
return Entry.objects.filter(user=user).exists() return Entry.objects.filter(user=user).exists()
@staticmethod @staticmethod
@ -1512,8 +1579,8 @@ class EntryAdapters:
return Entry.objects.filter(agent=agent).exists() return Entry.objects.filter(agent=agent).exists()
@staticmethod @staticmethod
@arequire_valid_user
async def auser_has_entries(user: KhojUser): async def auser_has_entries(user: KhojUser):
check_valid_user(user)
return await Entry.objects.filter(user=user).aexists() return await Entry.objects.filter(user=user).aexists()
@staticmethod @staticmethod
@ -1523,13 +1590,13 @@ class EntryAdapters:
return await Entry.objects.filter(agent=agent).aexists() return await Entry.objects.filter(agent=agent).aexists()
@staticmethod @staticmethod
@arequire_valid_user
async def adelete_entry_by_file(user: KhojUser, file_path: str): async def adelete_entry_by_file(user: KhojUser, file_path: str):
check_valid_user(user)
return await Entry.objects.filter(user=user, file_path=file_path).adelete() return await Entry.objects.filter(user=user, file_path=file_path).adelete()
@staticmethod @staticmethod
@arequire_valid_user
async def adelete_entries_by_filenames(user: KhojUser, filenames: List[str], batch_size=1000): async def adelete_entries_by_filenames(user: KhojUser, filenames: List[str], batch_size=1000):
check_valid_user(user)
deleted_count = 0 deleted_count = 0
for i in range(0, len(filenames), batch_size): for i in range(0, len(filenames), batch_size):
batch = filenames[i : i + batch_size] batch = filenames[i : i + batch_size]
@ -1547,8 +1614,8 @@ class EntryAdapters:
) )
@staticmethod @staticmethod
@require_valid_user
def get_all_filenames_by_source(user: KhojUser, file_source: str): def get_all_filenames_by_source(user: KhojUser, file_source: str):
check_valid_user(user)
return ( return (
Entry.objects.filter(user=user, file_source=file_source) Entry.objects.filter(user=user, file_source=file_source)
.distinct("file_path") .distinct("file_path")
@ -1556,8 +1623,8 @@ class EntryAdapters:
) )
@staticmethod @staticmethod
@require_valid_user
def get_size_of_indexed_data_in_mb(user: KhojUser): def get_size_of_indexed_data_in_mb(user: KhojUser):
check_valid_user(user)
entries = Entry.objects.filter(user=user).iterator() entries = Entry.objects.filter(user=user).iterator()
total_size = sum(sys.getsizeof(entry.compiled) for entry in entries) total_size = sum(sys.getsizeof(entry.compiled) for entry in entries)
return total_size / 1024 / 1024 return total_size / 1024 / 1024
@ -1654,13 +1721,13 @@ class EntryAdapters:
return relevant_entries[:max_results] return relevant_entries[:max_results]
@staticmethod @staticmethod
@require_valid_user
def get_unique_file_types(user: KhojUser): def get_unique_file_types(user: KhojUser):
check_valid_user(user)
return Entry.objects.filter(user=user).values_list("file_type", flat=True).distinct() return Entry.objects.filter(user=user).values_list("file_type", flat=True).distinct()
@staticmethod @staticmethod
@require_valid_user
def get_unique_file_sources(user: KhojUser): def get_unique_file_sources(user: KhojUser):
check_valid_user(user)
return Entry.objects.filter(user=user).values_list("file_source", flat=True).distinct().all() return Entry.objects.filter(user=user).values_list("file_source", flat=True).distinct().all()