mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 23:48:56 +01:00
Convert required user param check into decorator. Use with more adapters
This commit is contained in:
parent
ff5c10c221
commit
10bca6fa8f
1 changed files with 113 additions and 46 deletions
|
@ -8,7 +8,17 @@ import secrets
|
||||||
import sys
|
import sys
|
||||||
from datetime import date, datetime, timedelta, timezone
|
from datetime import date, datetime, timedelta, timezone
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Callable, Iterable, List, Optional, Type
|
from functools import wraps
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Coroutine,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
ParamSpec,
|
||||||
|
TypeVar,
|
||||||
|
)
|
||||||
|
|
||||||
import cron_descriptor
|
import cron_descriptor
|
||||||
from apscheduler.job import Job
|
from apscheduler.job import Job
|
||||||
|
@ -80,6 +90,45 @@ class SubscriptionState(Enum):
|
||||||
INVALID = "invalid"
|
INVALID = "invalid"
|
||||||
|
|
||||||
|
|
||||||
|
P = ParamSpec("P")
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
def require_valid_user(func: Callable[P, T]) -> Callable[P, T]:
|
||||||
|
@wraps(func)
|
||||||
|
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
||||||
|
# Extract user from args/kwargs
|
||||||
|
user = next((arg for arg in args if isinstance(arg, KhojUser)), None)
|
||||||
|
if not user:
|
||||||
|
user = next((val for val in kwargs.values() if isinstance(val, KhojUser)), None)
|
||||||
|
|
||||||
|
# Throw error if user is not found
|
||||||
|
if not user:
|
||||||
|
raise ValueError("Khoj user argument required but not provided.")
|
||||||
|
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
return sync_wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def arequire_valid_user(func: Callable[P, Coroutine[Any, Any, T]]) -> Callable[P, Coroutine[Any, Any, T]]:
|
||||||
|
@wraps(func)
|
||||||
|
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
||||||
|
# Extract user from args/kwargs
|
||||||
|
user = next((arg for arg in args if isinstance(arg, KhojUser)), None)
|
||||||
|
if not user:
|
||||||
|
user = next((v for v in kwargs.values() if isinstance(v, KhojUser)), None)
|
||||||
|
|
||||||
|
# Throw error if user is not found
|
||||||
|
if not user:
|
||||||
|
raise ValueError("Khoj user argument required but not provided.")
|
||||||
|
|
||||||
|
return await func(*args, **kwargs)
|
||||||
|
|
||||||
|
return async_wrapper
|
||||||
|
|
||||||
|
|
||||||
|
@arequire_valid_user
|
||||||
async def set_notion_config(token: str, user: KhojUser):
|
async def set_notion_config(token: str, user: KhojUser):
|
||||||
notion_config = await NotionConfig.objects.filter(user=user).afirst()
|
notion_config = await NotionConfig.objects.filter(user=user).afirst()
|
||||||
if not notion_config:
|
if not notion_config:
|
||||||
|
@ -90,6 +139,7 @@ async def set_notion_config(token: str, user: KhojUser):
|
||||||
return notion_config
|
return notion_config
|
||||||
|
|
||||||
|
|
||||||
|
@require_valid_user
|
||||||
def create_khoj_token(user: KhojUser, name=None):
|
def create_khoj_token(user: KhojUser, name=None):
|
||||||
"Create Khoj API key for user"
|
"Create Khoj API key for user"
|
||||||
token = f"kk-{secrets.token_urlsafe(32)}"
|
token = f"kk-{secrets.token_urlsafe(32)}"
|
||||||
|
@ -97,6 +147,7 @@ def create_khoj_token(user: KhojUser, name=None):
|
||||||
return KhojApiUser.objects.create(token=token, user=user, name=name)
|
return KhojApiUser.objects.create(token=token, user=user, name=name)
|
||||||
|
|
||||||
|
|
||||||
|
@arequire_valid_user
|
||||||
async def acreate_khoj_token(user: KhojUser, name=None):
|
async def acreate_khoj_token(user: KhojUser, name=None):
|
||||||
"Create Khoj API key for user"
|
"Create Khoj API key for user"
|
||||||
token = f"kk-{secrets.token_urlsafe(32)}"
|
token = f"kk-{secrets.token_urlsafe(32)}"
|
||||||
|
@ -104,11 +155,13 @@ async def acreate_khoj_token(user: KhojUser, name=None):
|
||||||
return await KhojApiUser.objects.acreate(token=token, user=user, name=name)
|
return await KhojApiUser.objects.acreate(token=token, user=user, name=name)
|
||||||
|
|
||||||
|
|
||||||
|
@require_valid_user
|
||||||
def get_khoj_tokens(user: KhojUser):
|
def get_khoj_tokens(user: KhojUser):
|
||||||
"Get all Khoj API keys for user"
|
"Get all Khoj API keys for user"
|
||||||
return list(KhojApiUser.objects.filter(user=user))
|
return list(KhojApiUser.objects.filter(user=user))
|
||||||
|
|
||||||
|
|
||||||
|
@arequire_valid_user
|
||||||
async def delete_khoj_token(user: KhojUser, token: str):
|
async def delete_khoj_token(user: KhojUser, token: str):
|
||||||
"Delete Khoj API Key for user"
|
"Delete Khoj API Key for user"
|
||||||
await KhojApiUser.objects.filter(token=token, user=user).adelete()
|
await KhojApiUser.objects.filter(token=token, user=user).adelete()
|
||||||
|
@ -132,6 +185,7 @@ async def aget_or_create_user_by_phone_number(phone_number: str) -> tuple[KhojUs
|
||||||
return user, is_new
|
return user, is_new
|
||||||
|
|
||||||
|
|
||||||
|
@arequire_valid_user
|
||||||
async def aset_user_phone_number(user: KhojUser, phone_number: str) -> KhojUser:
|
async def aset_user_phone_number(user: KhojUser, phone_number: str) -> KhojUser:
|
||||||
if is_none_or_empty(phone_number):
|
if is_none_or_empty(phone_number):
|
||||||
return None
|
return None
|
||||||
|
@ -155,6 +209,7 @@ async def aset_user_phone_number(user: KhojUser, phone_number: str) -> KhojUser:
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
@arequire_valid_user
|
||||||
async def aremove_phone_number(user: KhojUser) -> KhojUser:
|
async def aremove_phone_number(user: KhojUser) -> KhojUser:
|
||||||
user.phone_number = None
|
user.phone_number = None
|
||||||
user.verified_phone_number = False
|
user.verified_phone_number = False
|
||||||
|
@ -192,6 +247,7 @@ async def aget_or_create_user_by_email(email: str) -> tuple[KhojUser, bool]:
|
||||||
return user, is_new
|
return user, is_new
|
||||||
|
|
||||||
|
|
||||||
|
@arequire_valid_user
|
||||||
async def astart_trial_subscription(user: KhojUser) -> Subscription:
|
async def astart_trial_subscription(user: KhojUser) -> Subscription:
|
||||||
subscription = await Subscription.objects.filter(user=user).afirst()
|
subscription = await Subscription.objects.filter(user=user).afirst()
|
||||||
if not subscription:
|
if not subscription:
|
||||||
|
@ -246,6 +302,7 @@ async def create_user_by_google_token(token: dict) -> KhojUser:
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
@require_valid_user
|
||||||
def set_user_name(user: KhojUser, first_name: str, last_name: str) -> KhojUser:
|
def set_user_name(user: KhojUser, first_name: str, last_name: str) -> KhojUser:
|
||||||
user.first_name = first_name
|
user.first_name = first_name
|
||||||
user.last_name = last_name
|
user.last_name = last_name
|
||||||
|
@ -253,6 +310,7 @@ def set_user_name(user: KhojUser, first_name: str, last_name: str) -> KhojUser:
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
@require_valid_user
|
||||||
def get_user_name(user: KhojUser):
|
def get_user_name(user: KhojUser):
|
||||||
full_name = user.get_full_name()
|
full_name = user.get_full_name()
|
||||||
if not is_none_or_empty(full_name):
|
if not is_none_or_empty(full_name):
|
||||||
|
@ -264,6 +322,7 @@ def get_user_name(user: KhojUser):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@require_valid_user
|
||||||
def get_user_photo(user: KhojUser):
|
def get_user_photo(user: KhojUser):
|
||||||
google_profile: GoogleUser = GoogleUser.objects.filter(user=user).first()
|
google_profile: GoogleUser = GoogleUser.objects.filter(user=user).first()
|
||||||
if google_profile:
|
if google_profile:
|
||||||
|
@ -327,6 +386,7 @@ def get_user_subscription_state(email: str) -> str:
|
||||||
return subscription_to_state(user_subscription)
|
return subscription_to_state(user_subscription)
|
||||||
|
|
||||||
|
|
||||||
|
@arequire_valid_user
|
||||||
async def aget_user_subscription_state(user: KhojUser) -> str:
|
async def aget_user_subscription_state(user: KhojUser) -> str:
|
||||||
"""Get subscription state of user
|
"""Get subscription state of user
|
||||||
Valid state transitions: trial -> subscribed <-> unsubscribed OR expired
|
Valid state transitions: trial -> subscribed <-> unsubscribed OR expired
|
||||||
|
@ -335,6 +395,7 @@ async def aget_user_subscription_state(user: KhojUser) -> str:
|
||||||
return await sync_to_async(subscription_to_state)(user_subscription)
|
return await sync_to_async(subscription_to_state)(user_subscription)
|
||||||
|
|
||||||
|
|
||||||
|
@arequire_valid_user
|
||||||
async def ais_user_subscribed(user: KhojUser) -> bool:
|
async def ais_user_subscribed(user: KhojUser) -> bool:
|
||||||
"""
|
"""
|
||||||
Get whether the user is subscribed
|
Get whether the user is subscribed
|
||||||
|
@ -351,6 +412,7 @@ async def ais_user_subscribed(user: KhojUser) -> bool:
|
||||||
return subscribed
|
return subscribed
|
||||||
|
|
||||||
|
|
||||||
|
@require_valid_user
|
||||||
def is_user_subscribed(user: KhojUser) -> bool:
|
def is_user_subscribed(user: KhojUser) -> bool:
|
||||||
"""
|
"""
|
||||||
Get whether the user is subscribed
|
Get whether the user is subscribed
|
||||||
|
@ -416,16 +478,13 @@ def get_all_users() -> BaseManager[KhojUser]:
|
||||||
return KhojUser.objects.all()
|
return KhojUser.objects.all()
|
||||||
|
|
||||||
|
|
||||||
def check_valid_user(user: KhojUser | None):
|
@require_valid_user
|
||||||
if not user:
|
|
||||||
raise ValueError("User not found")
|
|
||||||
|
|
||||||
|
|
||||||
def get_user_github_config(user: KhojUser):
|
def get_user_github_config(user: KhojUser):
|
||||||
config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first()
|
config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first()
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
@require_valid_user
|
||||||
def get_user_notion_config(user: KhojUser):
|
def get_user_notion_config(user: KhojUser):
|
||||||
config = NotionConfig.objects.filter(user=user).first()
|
config = NotionConfig.objects.filter(user=user).first()
|
||||||
return config
|
return config
|
||||||
|
@ -435,6 +494,7 @@ def delete_user_requests(window: timedelta = timedelta(days=1)):
|
||||||
return UserRequests.objects.filter(created_at__lte=datetime.now(tz=timezone.utc) - window).delete()
|
return UserRequests.objects.filter(created_at__lte=datetime.now(tz=timezone.utc) - window).delete()
|
||||||
|
|
||||||
|
|
||||||
|
@arequire_valid_user
|
||||||
async def aget_user_name(user: KhojUser):
|
async def aget_user_name(user: KhojUser):
|
||||||
full_name = user.get_full_name()
|
full_name = user.get_full_name()
|
||||||
if not is_none_or_empty(full_name):
|
if not is_none_or_empty(full_name):
|
||||||
|
@ -458,6 +518,7 @@ async def set_text_content_config(user: KhojUser, object: Type[models.Model], up
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@arequire_valid_user
|
||||||
async def set_user_github_config(user: KhojUser, pat_token: str, repos: list):
|
async def set_user_github_config(user: KhojUser, pat_token: str, repos: list):
|
||||||
config = await GithubConfig.objects.filter(user=user).afirst()
|
config = await GithubConfig.objects.filter(user=user).afirst()
|
||||||
|
|
||||||
|
@ -592,8 +653,11 @@ class AgentAdapters:
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@arequire_valid_user
|
||||||
async def adelete_agent_by_slug(agent_slug: str, user: KhojUser):
|
async def adelete_agent_by_slug(agent_slug: str, user: KhojUser):
|
||||||
agent = await AgentAdapters.aget_agent_by_slug(agent_slug, user)
|
agent = await AgentAdapters.aget_agent_by_slug(agent_slug, user)
|
||||||
|
if agent.creator != user:
|
||||||
|
return False
|
||||||
|
|
||||||
async for entry in Entry.objects.filter(agent=agent).aiterator():
|
async for entry in Entry.objects.filter(agent=agent).aiterator():
|
||||||
await entry.adelete()
|
await entry.adelete()
|
||||||
|
@ -717,6 +781,7 @@ class AgentAdapters:
|
||||||
return await Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).afirst()
|
return await Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).afirst()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@arequire_valid_user
|
||||||
async def aupdate_agent(
|
async def aupdate_agent(
|
||||||
user: KhojUser,
|
user: KhojUser,
|
||||||
name: str,
|
name: str,
|
||||||
|
@ -817,10 +882,10 @@ class ConversationAdapters:
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@require_valid_user
|
||||||
def get_conversation_by_user(
|
def get_conversation_by_user(
|
||||||
user: KhojUser, client_application: ClientApplication = None, conversation_id: str = None
|
user: KhojUser, client_application: ClientApplication = None, conversation_id: str = None
|
||||||
) -> Optional[Conversation]:
|
) -> Optional[Conversation]:
|
||||||
check_valid_user(user)
|
|
||||||
if conversation_id:
|
if conversation_id:
|
||||||
conversation = (
|
conversation = (
|
||||||
Conversation.objects.filter(user=user, client=client_application, id=conversation_id)
|
Conversation.objects.filter(user=user, client=client_application, id=conversation_id)
|
||||||
|
@ -836,8 +901,8 @@ class ConversationAdapters:
|
||||||
return conversation
|
return conversation
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@require_valid_user
|
||||||
def get_conversation_sessions(user: KhojUser, client_application: ClientApplication = None):
|
def get_conversation_sessions(user: KhojUser, client_application: ClientApplication = None):
|
||||||
check_valid_user(user)
|
|
||||||
return (
|
return (
|
||||||
Conversation.objects.filter(user=user, client=client_application)
|
Conversation.objects.filter(user=user, client=client_application)
|
||||||
.prefetch_related("agent")
|
.prefetch_related("agent")
|
||||||
|
@ -845,10 +910,10 @@ class ConversationAdapters:
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@arequire_valid_user
|
||||||
async def aset_conversation_title(
|
async def aset_conversation_title(
|
||||||
user: KhojUser, client_application: ClientApplication, conversation_id: str, title: str
|
user: KhojUser, client_application: ClientApplication, conversation_id: str, title: str
|
||||||
):
|
):
|
||||||
check_valid_user(user)
|
|
||||||
conversation = await Conversation.objects.filter(
|
conversation = await Conversation.objects.filter(
|
||||||
user=user, client=client_application, id=conversation_id
|
user=user, client=client_application, id=conversation_id
|
||||||
).afirst()
|
).afirst()
|
||||||
|
@ -863,10 +928,10 @@ class ConversationAdapters:
|
||||||
return Conversation.objects.filter(id=conversation_id).first()
|
return Conversation.objects.filter(id=conversation_id).first()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@arequire_valid_user
|
||||||
async def acreate_conversation_session(
|
async def acreate_conversation_session(
|
||||||
user: KhojUser, client_application: ClientApplication = None, agent_slug: str = None, title: str = None
|
user: KhojUser, client_application: ClientApplication = None, agent_slug: str = None, title: str = None
|
||||||
):
|
):
|
||||||
check_valid_user(user)
|
|
||||||
if agent_slug:
|
if agent_slug:
|
||||||
agent = await AgentAdapters.aget_readonly_agent_by_slug(agent_slug, user)
|
agent = await AgentAdapters.aget_readonly_agent_by_slug(agent_slug, user)
|
||||||
if agent is None:
|
if agent is None:
|
||||||
|
@ -880,10 +945,10 @@ class ConversationAdapters:
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@require_valid_user
|
||||||
def create_conversation_session(
|
def create_conversation_session(
|
||||||
user: KhojUser, client_application: ClientApplication = None, agent_slug: str = None, title: str = None
|
user: KhojUser, client_application: ClientApplication = None, agent_slug: str = None, title: str = None
|
||||||
):
|
):
|
||||||
check_valid_user(user)
|
|
||||||
if agent_slug:
|
if agent_slug:
|
||||||
agent = AgentAdapters.aget_readonly_agent_by_slug(agent_slug, user)
|
agent = AgentAdapters.aget_readonly_agent_by_slug(agent_slug, user)
|
||||||
if agent is None:
|
if agent is None:
|
||||||
|
@ -893,6 +958,7 @@ class ConversationAdapters:
|
||||||
return Conversation.objects.create(user=user, client=client_application, agent=agent, title=title)
|
return Conversation.objects.create(user=user, client=client_application, agent=agent, title=title)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@arequire_valid_user
|
||||||
async def aget_conversation_by_user(
|
async def aget_conversation_by_user(
|
||||||
user: KhojUser,
|
user: KhojUser,
|
||||||
client_application: ClientApplication = None,
|
client_application: ClientApplication = None,
|
||||||
|
@ -900,7 +966,6 @@ class ConversationAdapters:
|
||||||
title: str = None,
|
title: str = None,
|
||||||
create_new: bool = False,
|
create_new: bool = False,
|
||||||
) -> Optional[Conversation]:
|
) -> Optional[Conversation]:
|
||||||
check_valid_user(user)
|
|
||||||
if create_new:
|
if create_new:
|
||||||
return await ConversationAdapters.acreate_conversation_session(user, client_application)
|
return await ConversationAdapters.acreate_conversation_session(user, client_application)
|
||||||
|
|
||||||
|
@ -918,17 +983,17 @@ class ConversationAdapters:
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@arequire_valid_user
|
||||||
async def adelete_conversation_by_user(
|
async def adelete_conversation_by_user(
|
||||||
user: KhojUser, client_application: ClientApplication = None, conversation_id: str = None
|
user: KhojUser, client_application: ClientApplication = None, conversation_id: str = None
|
||||||
):
|
):
|
||||||
check_valid_user(user)
|
|
||||||
if conversation_id:
|
if conversation_id:
|
||||||
return await Conversation.objects.filter(user=user, client=client_application, id=conversation_id).adelete()
|
return await Conversation.objects.filter(user=user, client=client_application, id=conversation_id).adelete()
|
||||||
return await Conversation.objects.filter(user=user, client=client_application).adelete()
|
return await Conversation.objects.filter(user=user, client=client_application).adelete()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@require_valid_user
|
||||||
def has_any_conversation_config(user: KhojUser):
|
def has_any_conversation_config(user: KhojUser):
|
||||||
check_valid_user(user)
|
|
||||||
return ChatModelOptions.objects.filter(user=user).exists()
|
return ChatModelOptions.objects.filter(user=user).exists()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -964,17 +1029,19 @@ class ConversationAdapters:
|
||||||
return OpenAIProcessorConversationConfig.objects.filter().exists()
|
return OpenAIProcessorConversationConfig.objects.filter().exists()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@arequire_valid_user
|
||||||
async def aset_user_conversation_processor(user: KhojUser, conversation_processor_config_id: int):
|
async def aset_user_conversation_processor(user: KhojUser, conversation_processor_config_id: int):
|
||||||
config = await ChatModelOptions.objects.filter(id=conversation_processor_config_id).afirst()
|
config = await ChatModelOptions.objects.filter(id=conversation_processor_config_id).afirst()
|
||||||
if not config or user is None:
|
if not config:
|
||||||
return None
|
return None
|
||||||
new_config = await UserConversationConfig.objects.aupdate_or_create(user=user, defaults={"setting": config})
|
new_config = await UserConversationConfig.objects.aupdate_or_create(user=user, defaults={"setting": config})
|
||||||
return new_config
|
return new_config
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@arequire_valid_user
|
||||||
async def aset_user_voice_model(user: KhojUser, model_id: str):
|
async def aset_user_voice_model(user: KhojUser, model_id: str):
|
||||||
config = await VoiceModelOption.objects.filter(model_id=model_id).afirst()
|
config = await VoiceModelOption.objects.filter(model_id=model_id).afirst()
|
||||||
if not config or user is None:
|
if not config:
|
||||||
return None
|
return None
|
||||||
new_config = await UserVoiceModelConfig.objects.aupdate_or_create(user=user, defaults={"setting": config})
|
new_config = await UserVoiceModelConfig.objects.aupdate_or_create(user=user, defaults={"setting": config})
|
||||||
return new_config
|
return new_config
|
||||||
|
@ -1156,10 +1223,10 @@ class ConversationAdapters:
|
||||||
return enabled_scrapers
|
return enabled_scrapers
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@require_valid_user
|
||||||
def create_conversation_from_public_conversation(
|
def create_conversation_from_public_conversation(
|
||||||
user: KhojUser, public_conversation: PublicConversation, client_app: ClientApplication
|
user: KhojUser, public_conversation: PublicConversation, client_app: ClientApplication
|
||||||
):
|
):
|
||||||
check_valid_user(user)
|
|
||||||
scrubbed_title = public_conversation.title if public_conversation.title else public_conversation.slug
|
scrubbed_title = public_conversation.title if public_conversation.title else public_conversation.slug
|
||||||
if scrubbed_title:
|
if scrubbed_title:
|
||||||
scrubbed_title = scrubbed_title.replace("-", " ")
|
scrubbed_title = scrubbed_title.replace("-", " ")
|
||||||
|
@ -1173,6 +1240,7 @@ class ConversationAdapters:
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@require_valid_user
|
||||||
def save_conversation(
|
def save_conversation(
|
||||||
user: KhojUser,
|
user: KhojUser,
|
||||||
conversation_log: dict,
|
conversation_log: dict,
|
||||||
|
@ -1180,7 +1248,6 @@ class ConversationAdapters:
|
||||||
conversation_id: str = None,
|
conversation_id: str = None,
|
||||||
user_message: str = None,
|
user_message: str = None,
|
||||||
):
|
):
|
||||||
check_valid_user(user)
|
|
||||||
slug = user_message.strip()[:200] if user_message else None
|
slug = user_message.strip()[:200] if user_message else None
|
||||||
if conversation_id:
|
if conversation_id:
|
||||||
conversation = Conversation.objects.filter(user=user, client=client_application, id=conversation_id).first()
|
conversation = Conversation.objects.filter(user=user, client=client_application, id=conversation_id).first()
|
||||||
|
@ -1223,8 +1290,8 @@ class ConversationAdapters:
|
||||||
return await SpeechToTextModelOptions.objects.filter().afirst()
|
return await SpeechToTextModelOptions.objects.filter().afirst()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@arequire_valid_user
|
||||||
async def aget_conversation_starters(user: KhojUser, max_results=3):
|
async def aget_conversation_starters(user: KhojUser, max_results=3):
|
||||||
check_valid_user(user)
|
|
||||||
all_questions = []
|
all_questions = []
|
||||||
if await ReflectiveQuestion.objects.filter(user=user).aexists():
|
if await ReflectiveQuestion.objects.filter(user=user).aexists():
|
||||||
all_questions = await sync_to_async(ReflectiveQuestion.objects.filter(user=user).values_list)(
|
all_questions = await sync_to_async(ReflectiveQuestion.objects.filter(user=user).values_list)(
|
||||||
|
@ -1353,8 +1420,8 @@ class ConversationAdapters:
|
||||||
return conversation.file_filters
|
return conversation.file_filters
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@require_valid_user
|
||||||
def delete_message_by_turn_id(user: KhojUser, conversation_id: str, turn_id: str):
|
def delete_message_by_turn_id(user: KhojUser, conversation_id: str, turn_id: str):
|
||||||
check_valid_user(user)
|
|
||||||
conversation = ConversationAdapters.get_conversation_by_user(user, conversation_id=conversation_id)
|
conversation = ConversationAdapters.get_conversation_by_user(user, conversation_id=conversation_id)
|
||||||
if not conversation or not conversation.conversation_log or not conversation.conversation_log.get("chat"):
|
if not conversation or not conversation.conversation_log or not conversation.conversation_log.get("chat"):
|
||||||
return False
|
return False
|
||||||
|
@ -1372,28 +1439,28 @@ class FileObjectAdapters:
|
||||||
file_object.save()
|
file_object.save()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@require_valid_user
|
||||||
def create_file_object(user: KhojUser, file_name: str, raw_text: str):
|
def create_file_object(user: KhojUser, file_name: str, raw_text: str):
|
||||||
check_valid_user(user)
|
|
||||||
return FileObject.objects.create(user=user, file_name=file_name, raw_text=raw_text)
|
return FileObject.objects.create(user=user, file_name=file_name, raw_text=raw_text)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@require_valid_user
|
||||||
def get_file_object_by_name(user: KhojUser, file_name: str):
|
def get_file_object_by_name(user: KhojUser, file_name: str):
|
||||||
check_valid_user(user)
|
|
||||||
return FileObject.objects.filter(user=user, file_name=file_name).first()
|
return FileObject.objects.filter(user=user, file_name=file_name).first()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@require_valid_user
|
||||||
def get_all_file_objects(user: KhojUser):
|
def get_all_file_objects(user: KhojUser):
|
||||||
check_valid_user(user)
|
|
||||||
return FileObject.objects.filter(user=user).all()
|
return FileObject.objects.filter(user=user).all()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@require_valid_user
|
||||||
def delete_file_object_by_name(user: KhojUser, file_name: str):
|
def delete_file_object_by_name(user: KhojUser, file_name: str):
|
||||||
check_valid_user(user)
|
|
||||||
return FileObject.objects.filter(user=user, file_name=file_name).delete()
|
return FileObject.objects.filter(user=user, file_name=file_name).delete()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@require_valid_user
|
||||||
def delete_all_file_objects(user: KhojUser):
|
def delete_all_file_objects(user: KhojUser):
|
||||||
check_valid_user(user)
|
|
||||||
return FileObject.objects.filter(user=user).delete()
|
return FileObject.objects.filter(user=user).delete()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -1402,33 +1469,33 @@ class FileObjectAdapters:
|
||||||
await file_object.asave()
|
await file_object.asave()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@arequire_valid_user
|
||||||
async def acreate_file_object(user: KhojUser, file_name: str, raw_text: str):
|
async def acreate_file_object(user: KhojUser, file_name: str, raw_text: str):
|
||||||
check_valid_user(user)
|
|
||||||
return await FileObject.objects.acreate(user=user, file_name=file_name, raw_text=raw_text)
|
return await FileObject.objects.acreate(user=user, file_name=file_name, raw_text=raw_text)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@arequire_valid_user
|
||||||
async def aget_file_objects_by_name(user: KhojUser, file_name: str, agent: Agent = None):
|
async def aget_file_objects_by_name(user: KhojUser, file_name: str, agent: Agent = None):
|
||||||
check_valid_user(user)
|
|
||||||
return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name=file_name, agent=agent))
|
return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name=file_name, agent=agent))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@arequire_valid_user
|
||||||
async def aget_file_objects_by_names(user: KhojUser, file_names: List[str]):
|
async def aget_file_objects_by_names(user: KhojUser, file_names: List[str]):
|
||||||
check_valid_user(user)
|
|
||||||
return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name__in=file_names))
|
return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name__in=file_names))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@arequire_valid_user
|
||||||
async def aget_all_file_objects(user: KhojUser):
|
async def aget_all_file_objects(user: KhojUser):
|
||||||
check_valid_user(user)
|
|
||||||
return await sync_to_async(list)(FileObject.objects.filter(user=user))
|
return await sync_to_async(list)(FileObject.objects.filter(user=user))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@arequire_valid_user
|
||||||
async def adelete_file_object_by_name(user: KhojUser, file_name: str):
|
async def adelete_file_object_by_name(user: KhojUser, file_name: str):
|
||||||
check_valid_user(user)
|
|
||||||
return await FileObject.objects.filter(user=user, file_name=file_name).adelete()
|
return await FileObject.objects.filter(user=user, file_name=file_name).adelete()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@arequire_valid_user
|
||||||
async def adelete_all_file_objects(user: KhojUser):
|
async def adelete_all_file_objects(user: KhojUser):
|
||||||
check_valid_user(user)
|
|
||||||
return await FileObject.objects.filter(user=user).adelete()
|
return await FileObject.objects.filter(user=user).adelete()
|
||||||
|
|
||||||
|
|
||||||
|
@ -1438,19 +1505,19 @@ class EntryAdapters:
|
||||||
date_filter = DateFilter()
|
date_filter = DateFilter()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@require_valid_user
|
||||||
def does_entry_exist(user: KhojUser, hashed_value: str) -> bool:
|
def does_entry_exist(user: KhojUser, hashed_value: str) -> bool:
|
||||||
check_valid_user(user)
|
|
||||||
return Entry.objects.filter(user=user, hashed_value=hashed_value).exists()
|
return Entry.objects.filter(user=user, hashed_value=hashed_value).exists()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@require_valid_user
|
||||||
def delete_entry_by_file(user: KhojUser, file_path: str):
|
def delete_entry_by_file(user: KhojUser, file_path: str):
|
||||||
check_valid_user(user)
|
|
||||||
deleted_count, _ = Entry.objects.filter(user=user, file_path=file_path).delete()
|
deleted_count, _ = Entry.objects.filter(user=user, file_path=file_path).delete()
|
||||||
return deleted_count
|
return deleted_count
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@require_valid_user
|
||||||
def get_filtered_entries(user: KhojUser, file_type: str = None, file_source: str = None):
|
def get_filtered_entries(user: KhojUser, file_type: str = None, file_source: str = None):
|
||||||
check_valid_user(user)
|
|
||||||
queryset = Entry.objects.filter(user=user)
|
queryset = Entry.objects.filter(user=user)
|
||||||
|
|
||||||
if file_type is not None:
|
if file_type is not None:
|
||||||
|
@ -1462,8 +1529,8 @@ class EntryAdapters:
|
||||||
return queryset
|
return queryset
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@require_valid_user
|
||||||
def delete_all_entries(user: KhojUser, file_type: str = None, file_source: str = None, batch_size=1000):
|
def delete_all_entries(user: KhojUser, file_type: str = None, file_source: str = None, batch_size=1000):
|
||||||
check_valid_user(user)
|
|
||||||
deleted_count = 0
|
deleted_count = 0
|
||||||
queryset = EntryAdapters.get_filtered_entries(user, file_type, file_source)
|
queryset = EntryAdapters.get_filtered_entries(user, file_type, file_source)
|
||||||
while queryset.exists():
|
while queryset.exists():
|
||||||
|
@ -1474,8 +1541,8 @@ class EntryAdapters:
|
||||||
return deleted_count
|
return deleted_count
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@arequire_valid_user
|
||||||
async def adelete_all_entries(user: KhojUser, file_type: str = None, file_source: str = None, batch_size=1000):
|
async def adelete_all_entries(user: KhojUser, file_type: str = None, file_source: str = None, batch_size=1000):
|
||||||
check_valid_user(user)
|
|
||||||
deleted_count = 0
|
deleted_count = 0
|
||||||
queryset = EntryAdapters.get_filtered_entries(user, file_type, file_source)
|
queryset = EntryAdapters.get_filtered_entries(user, file_type, file_source)
|
||||||
while await queryset.aexists():
|
while await queryset.aexists():
|
||||||
|
@ -1486,13 +1553,13 @@ class EntryAdapters:
|
||||||
return deleted_count
|
return deleted_count
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@require_valid_user
|
||||||
def get_existing_entry_hashes_by_file(user: KhojUser, file_path: str):
|
def get_existing_entry_hashes_by_file(user: KhojUser, file_path: str):
|
||||||
check_valid_user(user)
|
|
||||||
return Entry.objects.filter(user=user, file_path=file_path).values_list("hashed_value", flat=True)
|
return Entry.objects.filter(user=user, file_path=file_path).values_list("hashed_value", flat=True)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@require_valid_user
|
||||||
def delete_entry_by_hash(user: KhojUser, hashed_values: List[str]):
|
def delete_entry_by_hash(user: KhojUser, hashed_values: List[str]):
|
||||||
check_valid_user(user)
|
|
||||||
Entry.objects.filter(user=user, hashed_value__in=hashed_values).delete()
|
Entry.objects.filter(user=user, hashed_value__in=hashed_values).delete()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -1503,8 +1570,8 @@ class EntryAdapters:
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@require_valid_user
|
||||||
def user_has_entries(user: KhojUser):
|
def user_has_entries(user: KhojUser):
|
||||||
check_valid_user(user)
|
|
||||||
return Entry.objects.filter(user=user).exists()
|
return Entry.objects.filter(user=user).exists()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -1512,8 +1579,8 @@ class EntryAdapters:
|
||||||
return Entry.objects.filter(agent=agent).exists()
|
return Entry.objects.filter(agent=agent).exists()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@arequire_valid_user
|
||||||
async def auser_has_entries(user: KhojUser):
|
async def auser_has_entries(user: KhojUser):
|
||||||
check_valid_user(user)
|
|
||||||
return await Entry.objects.filter(user=user).aexists()
|
return await Entry.objects.filter(user=user).aexists()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -1523,13 +1590,13 @@ class EntryAdapters:
|
||||||
return await Entry.objects.filter(agent=agent).aexists()
|
return await Entry.objects.filter(agent=agent).aexists()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@arequire_valid_user
|
||||||
async def adelete_entry_by_file(user: KhojUser, file_path: str):
|
async def adelete_entry_by_file(user: KhojUser, file_path: str):
|
||||||
check_valid_user(user)
|
|
||||||
return await Entry.objects.filter(user=user, file_path=file_path).adelete()
|
return await Entry.objects.filter(user=user, file_path=file_path).adelete()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@arequire_valid_user
|
||||||
async def adelete_entries_by_filenames(user: KhojUser, filenames: List[str], batch_size=1000):
|
async def adelete_entries_by_filenames(user: KhojUser, filenames: List[str], batch_size=1000):
|
||||||
check_valid_user(user)
|
|
||||||
deleted_count = 0
|
deleted_count = 0
|
||||||
for i in range(0, len(filenames), batch_size):
|
for i in range(0, len(filenames), batch_size):
|
||||||
batch = filenames[i : i + batch_size]
|
batch = filenames[i : i + batch_size]
|
||||||
|
@ -1547,8 +1614,8 @@ class EntryAdapters:
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@require_valid_user
|
||||||
def get_all_filenames_by_source(user: KhojUser, file_source: str):
|
def get_all_filenames_by_source(user: KhojUser, file_source: str):
|
||||||
check_valid_user(user)
|
|
||||||
return (
|
return (
|
||||||
Entry.objects.filter(user=user, file_source=file_source)
|
Entry.objects.filter(user=user, file_source=file_source)
|
||||||
.distinct("file_path")
|
.distinct("file_path")
|
||||||
|
@ -1556,8 +1623,8 @@ class EntryAdapters:
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@require_valid_user
|
||||||
def get_size_of_indexed_data_in_mb(user: KhojUser):
|
def get_size_of_indexed_data_in_mb(user: KhojUser):
|
||||||
check_valid_user(user)
|
|
||||||
entries = Entry.objects.filter(user=user).iterator()
|
entries = Entry.objects.filter(user=user).iterator()
|
||||||
total_size = sum(sys.getsizeof(entry.compiled) for entry in entries)
|
total_size = sum(sys.getsizeof(entry.compiled) for entry in entries)
|
||||||
return total_size / 1024 / 1024
|
return total_size / 1024 / 1024
|
||||||
|
@ -1654,13 +1721,13 @@ class EntryAdapters:
|
||||||
return relevant_entries[:max_results]
|
return relevant_entries[:max_results]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@require_valid_user
|
||||||
def get_unique_file_types(user: KhojUser):
|
def get_unique_file_types(user: KhojUser):
|
||||||
check_valid_user(user)
|
|
||||||
return Entry.objects.filter(user=user).values_list("file_type", flat=True).distinct()
|
return Entry.objects.filter(user=user).values_list("file_type", flat=True).distinct()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@require_valid_user
|
||||||
def get_unique_file_sources(user: KhojUser):
|
def get_unique_file_sources(user: KhojUser):
|
||||||
check_valid_user(user)
|
|
||||||
return Entry.objects.filter(user=user).values_list("file_source", flat=True).distinct().all()
|
return Entry.objects.filter(user=user).values_list("file_source", flat=True).distinct().all()
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue