diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index ad149d58..4bb5c6c1 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -8,7 +8,17 @@ import secrets import sys from datetime import date, datetime, timedelta, timezone from enum import Enum -from typing import Callable, Iterable, List, Optional, Type +from functools import wraps +from typing import ( + Any, + Callable, + Coroutine, + Iterable, + List, + Optional, + ParamSpec, + TypeVar, +) import cron_descriptor from apscheduler.job import Job @@ -80,6 +90,45 @@ class SubscriptionState(Enum): INVALID = "invalid" +P = ParamSpec("P") +T = TypeVar("T") + + +def require_valid_user(func: Callable[P, T]) -> Callable[P, T]: + @wraps(func) + def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + # Extract user from args/kwargs + user = next((arg for arg in args if isinstance(arg, KhojUser)), None) + if not user: + user = next((val for val in kwargs.values() if isinstance(val, KhojUser)), None) + + # Throw error if user is not found + if not user: + raise ValueError("Khoj user argument required but not provided.") + + return func(*args, **kwargs) + + return sync_wrapper + + +def arequire_valid_user(func: Callable[P, Coroutine[Any, Any, T]]) -> Callable[P, Coroutine[Any, Any, T]]: + @wraps(func) + async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + # Extract user from args/kwargs + user = next((arg for arg in args if isinstance(arg, KhojUser)), None) + if not user: + user = next((v for v in kwargs.values() if isinstance(v, KhojUser)), None) + + # Throw error if user is not found + if not user: + raise ValueError("Khoj user argument required but not provided.") + + return await func(*args, **kwargs) + + return async_wrapper + + +@arequire_valid_user async def set_notion_config(token: str, user: KhojUser): notion_config = await NotionConfig.objects.filter(user=user).afirst() if not notion_config: @@ -90,6 +139,7 @@ async def set_notion_config(token: str, user: KhojUser): return notion_config +@require_valid_user def create_khoj_token(user: KhojUser, name=None): "Create Khoj API key for user" token = f"kk-{secrets.token_urlsafe(32)}" @@ -97,6 +147,7 @@ def create_khoj_token(user: KhojUser, name=None): return KhojApiUser.objects.create(token=token, user=user, name=name) +@arequire_valid_user async def acreate_khoj_token(user: KhojUser, name=None): "Create Khoj API key for user" token = f"kk-{secrets.token_urlsafe(32)}" @@ -104,11 +155,13 @@ async def acreate_khoj_token(user: KhojUser, name=None): return await KhojApiUser.objects.acreate(token=token, user=user, name=name) +@require_valid_user def get_khoj_tokens(user: KhojUser): "Get all Khoj API keys for user" return list(KhojApiUser.objects.filter(user=user)) +@arequire_valid_user async def delete_khoj_token(user: KhojUser, token: str): "Delete Khoj API Key for user" await KhojApiUser.objects.filter(token=token, user=user).adelete() @@ -132,6 +185,7 @@ async def aget_or_create_user_by_phone_number(phone_number: str) -> tuple[KhojUs return user, is_new +@arequire_valid_user async def aset_user_phone_number(user: KhojUser, phone_number: str) -> KhojUser: if is_none_or_empty(phone_number): return None @@ -155,6 +209,7 @@ async def aset_user_phone_number(user: KhojUser, phone_number: str) -> KhojUser: return user +@arequire_valid_user async def aremove_phone_number(user: KhojUser) -> KhojUser: user.phone_number = None user.verified_phone_number = False @@ -192,6 +247,7 @@ async def aget_or_create_user_by_email(email: str) -> tuple[KhojUser, bool]: return user, is_new +@arequire_valid_user async def astart_trial_subscription(user: KhojUser) -> Subscription: subscription = await Subscription.objects.filter(user=user).afirst() if not subscription: @@ -246,6 +302,7 @@ async def create_user_by_google_token(token: dict) -> KhojUser: return user +@require_valid_user def set_user_name(user: KhojUser, first_name: str, last_name: str) -> KhojUser: user.first_name = first_name user.last_name = last_name @@ -253,6 +310,7 @@ def set_user_name(user: KhojUser, first_name: str, last_name: str) -> KhojUser: return user +@require_valid_user def get_user_name(user: KhojUser): full_name = user.get_full_name() if not is_none_or_empty(full_name): @@ -264,6 +322,7 @@ def get_user_name(user: KhojUser): return None +@require_valid_user def get_user_photo(user: KhojUser): google_profile: GoogleUser = GoogleUser.objects.filter(user=user).first() if google_profile: @@ -327,6 +386,7 @@ def get_user_subscription_state(email: str) -> str: return subscription_to_state(user_subscription) +@arequire_valid_user async def aget_user_subscription_state(user: KhojUser) -> str: """Get subscription state of user Valid state transitions: trial -> subscribed <-> unsubscribed OR expired @@ -335,6 +395,7 @@ async def aget_user_subscription_state(user: KhojUser) -> str: return await sync_to_async(subscription_to_state)(user_subscription) +@arequire_valid_user async def ais_user_subscribed(user: KhojUser) -> bool: """ Get whether the user is subscribed @@ -351,6 +412,7 @@ async def ais_user_subscribed(user: KhojUser) -> bool: return subscribed +@require_valid_user def is_user_subscribed(user: KhojUser) -> bool: """ Get whether the user is subscribed @@ -416,16 +478,13 @@ def get_all_users() -> BaseManager[KhojUser]: return KhojUser.objects.all() -def check_valid_user(user: KhojUser | None): - if not user: - raise ValueError("User not found") - - +@require_valid_user def get_user_github_config(user: KhojUser): config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first() return config +@require_valid_user def get_user_notion_config(user: KhojUser): config = NotionConfig.objects.filter(user=user).first() return config @@ -435,6 +494,7 @@ def delete_user_requests(window: timedelta = timedelta(days=1)): return UserRequests.objects.filter(created_at__lte=datetime.now(tz=timezone.utc) - window).delete() +@arequire_valid_user async def aget_user_name(user: KhojUser): full_name = user.get_full_name() if not is_none_or_empty(full_name): @@ -458,6 +518,7 @@ async def set_text_content_config(user: KhojUser, object: Type[models.Model], up ) +@arequire_valid_user async def set_user_github_config(user: KhojUser, pat_token: str, repos: list): config = await GithubConfig.objects.filter(user=user).afirst() @@ -592,8 +653,11 @@ class AgentAdapters: ) @staticmethod + @arequire_valid_user async def adelete_agent_by_slug(agent_slug: str, user: KhojUser): agent = await AgentAdapters.aget_agent_by_slug(agent_slug, user) + if agent.creator != user: + return False async for entry in Entry.objects.filter(agent=agent).aiterator(): await entry.adelete() @@ -717,6 +781,7 @@ class AgentAdapters: return await Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).afirst() @staticmethod + @arequire_valid_user async def aupdate_agent( user: KhojUser, name: str, @@ -817,10 +882,10 @@ class ConversationAdapters: ) @staticmethod + @require_valid_user def get_conversation_by_user( user: KhojUser, client_application: ClientApplication = None, conversation_id: str = None ) -> Optional[Conversation]: - check_valid_user(user) if conversation_id: conversation = ( Conversation.objects.filter(user=user, client=client_application, id=conversation_id) @@ -836,8 +901,8 @@ class ConversationAdapters: return conversation @staticmethod + @require_valid_user def get_conversation_sessions(user: KhojUser, client_application: ClientApplication = None): - check_valid_user(user) return ( Conversation.objects.filter(user=user, client=client_application) .prefetch_related("agent") @@ -845,10 +910,10 @@ class ConversationAdapters: ) @staticmethod + @arequire_valid_user async def aset_conversation_title( user: KhojUser, client_application: ClientApplication, conversation_id: str, title: str ): - check_valid_user(user) conversation = await Conversation.objects.filter( user=user, client=client_application, id=conversation_id ).afirst() @@ -863,10 +928,10 @@ class ConversationAdapters: return Conversation.objects.filter(id=conversation_id).first() @staticmethod + @arequire_valid_user async def acreate_conversation_session( user: KhojUser, client_application: ClientApplication = None, agent_slug: str = None, title: str = None ): - check_valid_user(user) if agent_slug: agent = await AgentAdapters.aget_readonly_agent_by_slug(agent_slug, user) if agent is None: @@ -880,10 +945,10 @@ class ConversationAdapters: ) @staticmethod + @require_valid_user def create_conversation_session( user: KhojUser, client_application: ClientApplication = None, agent_slug: str = None, title: str = None ): - check_valid_user(user) if agent_slug: agent = AgentAdapters.aget_readonly_agent_by_slug(agent_slug, user) if agent is None: @@ -893,6 +958,7 @@ class ConversationAdapters: return Conversation.objects.create(user=user, client=client_application, agent=agent, title=title) @staticmethod + @arequire_valid_user async def aget_conversation_by_user( user: KhojUser, client_application: ClientApplication = None, @@ -900,7 +966,6 @@ class ConversationAdapters: title: str = None, create_new: bool = False, ) -> Optional[Conversation]: - check_valid_user(user) if create_new: return await ConversationAdapters.acreate_conversation_session(user, client_application) @@ -918,17 +983,17 @@ class ConversationAdapters: ) @staticmethod + @arequire_valid_user async def adelete_conversation_by_user( user: KhojUser, client_application: ClientApplication = None, conversation_id: str = None ): - check_valid_user(user) if conversation_id: return await Conversation.objects.filter(user=user, client=client_application, id=conversation_id).adelete() return await Conversation.objects.filter(user=user, client=client_application).adelete() @staticmethod + @require_valid_user def has_any_conversation_config(user: KhojUser): - check_valid_user(user) return ChatModelOptions.objects.filter(user=user).exists() @staticmethod @@ -964,17 +1029,19 @@ class ConversationAdapters: return OpenAIProcessorConversationConfig.objects.filter().exists() @staticmethod + @arequire_valid_user async def aset_user_conversation_processor(user: KhojUser, conversation_processor_config_id: int): config = await ChatModelOptions.objects.filter(id=conversation_processor_config_id).afirst() - if not config or user is None: + if not config: return None new_config = await UserConversationConfig.objects.aupdate_or_create(user=user, defaults={"setting": config}) return new_config @staticmethod + @arequire_valid_user async def aset_user_voice_model(user: KhojUser, model_id: str): config = await VoiceModelOption.objects.filter(model_id=model_id).afirst() - if not config or user is None: + if not config: return None new_config = await UserVoiceModelConfig.objects.aupdate_or_create(user=user, defaults={"setting": config}) return new_config @@ -1156,10 +1223,10 @@ class ConversationAdapters: return enabled_scrapers @staticmethod + @require_valid_user def create_conversation_from_public_conversation( user: KhojUser, public_conversation: PublicConversation, client_app: ClientApplication ): - check_valid_user(user) scrubbed_title = public_conversation.title if public_conversation.title else public_conversation.slug if scrubbed_title: scrubbed_title = scrubbed_title.replace("-", " ") @@ -1173,6 +1240,7 @@ class ConversationAdapters: ) @staticmethod + @require_valid_user def save_conversation( user: KhojUser, conversation_log: dict, @@ -1180,7 +1248,6 @@ class ConversationAdapters: conversation_id: str = None, user_message: str = None, ): - check_valid_user(user) slug = user_message.strip()[:200] if user_message else None if conversation_id: conversation = Conversation.objects.filter(user=user, client=client_application, id=conversation_id).first() @@ -1223,8 +1290,8 @@ class ConversationAdapters: return await SpeechToTextModelOptions.objects.filter().afirst() @staticmethod + @arequire_valid_user async def aget_conversation_starters(user: KhojUser, max_results=3): - check_valid_user(user) all_questions = [] if await ReflectiveQuestion.objects.filter(user=user).aexists(): all_questions = await sync_to_async(ReflectiveQuestion.objects.filter(user=user).values_list)( @@ -1353,8 +1420,8 @@ class ConversationAdapters: return conversation.file_filters @staticmethod + @require_valid_user def delete_message_by_turn_id(user: KhojUser, conversation_id: str, turn_id: str): - check_valid_user(user) conversation = ConversationAdapters.get_conversation_by_user(user, conversation_id=conversation_id) if not conversation or not conversation.conversation_log or not conversation.conversation_log.get("chat"): return False @@ -1372,28 +1439,28 @@ class FileObjectAdapters: file_object.save() @staticmethod + @require_valid_user def create_file_object(user: KhojUser, file_name: str, raw_text: str): - check_valid_user(user) return FileObject.objects.create(user=user, file_name=file_name, raw_text=raw_text) @staticmethod + @require_valid_user def get_file_object_by_name(user: KhojUser, file_name: str): - check_valid_user(user) return FileObject.objects.filter(user=user, file_name=file_name).first() @staticmethod + @require_valid_user def get_all_file_objects(user: KhojUser): - check_valid_user(user) return FileObject.objects.filter(user=user).all() @staticmethod + @require_valid_user def delete_file_object_by_name(user: KhojUser, file_name: str): - check_valid_user(user) return FileObject.objects.filter(user=user, file_name=file_name).delete() @staticmethod + @require_valid_user def delete_all_file_objects(user: KhojUser): - check_valid_user(user) return FileObject.objects.filter(user=user).delete() @staticmethod @@ -1402,33 +1469,33 @@ class FileObjectAdapters: await file_object.asave() @staticmethod + @arequire_valid_user async def acreate_file_object(user: KhojUser, file_name: str, raw_text: str): - check_valid_user(user) return await FileObject.objects.acreate(user=user, file_name=file_name, raw_text=raw_text) @staticmethod + @arequire_valid_user async def aget_file_objects_by_name(user: KhojUser, file_name: str, agent: Agent = None): - check_valid_user(user) return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name=file_name, agent=agent)) @staticmethod + @arequire_valid_user async def aget_file_objects_by_names(user: KhojUser, file_names: List[str]): - check_valid_user(user) return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name__in=file_names)) @staticmethod + @arequire_valid_user async def aget_all_file_objects(user: KhojUser): - check_valid_user(user) return await sync_to_async(list)(FileObject.objects.filter(user=user)) @staticmethod + @arequire_valid_user async def adelete_file_object_by_name(user: KhojUser, file_name: str): - check_valid_user(user) return await FileObject.objects.filter(user=user, file_name=file_name).adelete() @staticmethod + @arequire_valid_user async def adelete_all_file_objects(user: KhojUser): - check_valid_user(user) return await FileObject.objects.filter(user=user).adelete() @@ -1438,19 +1505,19 @@ class EntryAdapters: date_filter = DateFilter() @staticmethod + @require_valid_user def does_entry_exist(user: KhojUser, hashed_value: str) -> bool: - check_valid_user(user) return Entry.objects.filter(user=user, hashed_value=hashed_value).exists() @staticmethod + @require_valid_user def delete_entry_by_file(user: KhojUser, file_path: str): - check_valid_user(user) deleted_count, _ = Entry.objects.filter(user=user, file_path=file_path).delete() return deleted_count @staticmethod + @require_valid_user def get_filtered_entries(user: KhojUser, file_type: str = None, file_source: str = None): - check_valid_user(user) queryset = Entry.objects.filter(user=user) if file_type is not None: @@ -1462,8 +1529,8 @@ class EntryAdapters: return queryset @staticmethod + @require_valid_user def delete_all_entries(user: KhojUser, file_type: str = None, file_source: str = None, batch_size=1000): - check_valid_user(user) deleted_count = 0 queryset = EntryAdapters.get_filtered_entries(user, file_type, file_source) while queryset.exists(): @@ -1474,8 +1541,8 @@ class EntryAdapters: return deleted_count @staticmethod + @arequire_valid_user async def adelete_all_entries(user: KhojUser, file_type: str = None, file_source: str = None, batch_size=1000): - check_valid_user(user) deleted_count = 0 queryset = EntryAdapters.get_filtered_entries(user, file_type, file_source) while await queryset.aexists(): @@ -1486,13 +1553,13 @@ class EntryAdapters: return deleted_count @staticmethod + @require_valid_user def get_existing_entry_hashes_by_file(user: KhojUser, file_path: str): - check_valid_user(user) return Entry.objects.filter(user=user, file_path=file_path).values_list("hashed_value", flat=True) @staticmethod + @require_valid_user def delete_entry_by_hash(user: KhojUser, hashed_values: List[str]): - check_valid_user(user) Entry.objects.filter(user=user, hashed_value__in=hashed_values).delete() @staticmethod @@ -1503,8 +1570,8 @@ class EntryAdapters: ) @staticmethod + @require_valid_user def user_has_entries(user: KhojUser): - check_valid_user(user) return Entry.objects.filter(user=user).exists() @staticmethod @@ -1512,8 +1579,8 @@ class EntryAdapters: return Entry.objects.filter(agent=agent).exists() @staticmethod + @arequire_valid_user async def auser_has_entries(user: KhojUser): - check_valid_user(user) return await Entry.objects.filter(user=user).aexists() @staticmethod @@ -1523,13 +1590,13 @@ class EntryAdapters: return await Entry.objects.filter(agent=agent).aexists() @staticmethod + @arequire_valid_user async def adelete_entry_by_file(user: KhojUser, file_path: str): - check_valid_user(user) return await Entry.objects.filter(user=user, file_path=file_path).adelete() @staticmethod + @arequire_valid_user async def adelete_entries_by_filenames(user: KhojUser, filenames: List[str], batch_size=1000): - check_valid_user(user) deleted_count = 0 for i in range(0, len(filenames), batch_size): batch = filenames[i : i + batch_size] @@ -1547,8 +1614,8 @@ class EntryAdapters: ) @staticmethod + @require_valid_user def get_all_filenames_by_source(user: KhojUser, file_source: str): - check_valid_user(user) return ( Entry.objects.filter(user=user, file_source=file_source) .distinct("file_path") @@ -1556,8 +1623,8 @@ class EntryAdapters: ) @staticmethod + @require_valid_user def get_size_of_indexed_data_in_mb(user: KhojUser): - check_valid_user(user) entries = Entry.objects.filter(user=user).iterator() total_size = sum(sys.getsizeof(entry.compiled) for entry in entries) return total_size / 1024 / 1024 @@ -1654,13 +1721,13 @@ class EntryAdapters: return relevant_entries[:max_results] @staticmethod + @require_valid_user def get_unique_file_types(user: KhojUser): - check_valid_user(user) return Entry.objects.filter(user=user).values_list("file_type", flat=True).distinct() @staticmethod + @require_valid_user def get_unique_file_sources(user: KhojUser): - check_valid_user(user) return Entry.objects.filter(user=user).values_list("file_source", flat=True).distinct().all()