mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 23:48:56 +01:00
Do not CRUD on entries, files & conversations in DB for null user (#958)
Increase defense-in-depth by reducing paths to create, read, update or delete entries, files and conversations in DB when user is unset.
This commit is contained in:
commit
ba2471dc02
19 changed files with 158 additions and 73 deletions
|
@ -253,7 +253,7 @@ def configure_server(
|
||||||
logger.info(message)
|
logger.info(message)
|
||||||
|
|
||||||
if not init:
|
if not init:
|
||||||
initialize_content(regenerate, search_type, user)
|
initialize_content(user, regenerate, search_type)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to load some search models: {e}", exc_info=True)
|
logger.error(f"Failed to load some search models: {e}", exc_info=True)
|
||||||
|
@ -263,17 +263,17 @@ def setup_default_agent(user: KhojUser):
|
||||||
AgentAdapters.create_default_agent(user)
|
AgentAdapters.create_default_agent(user)
|
||||||
|
|
||||||
|
|
||||||
def initialize_content(regenerate: bool, search_type: Optional[SearchType] = None, user: KhojUser = None):
|
def initialize_content(user: KhojUser, regenerate: bool, search_type: Optional[SearchType] = None):
|
||||||
# Initialize Content from Config
|
# Initialize Content from Config
|
||||||
if state.search_models:
|
if state.search_models:
|
||||||
try:
|
try:
|
||||||
logger.info("📬 Updating content index...")
|
logger.info("📬 Updating content index...")
|
||||||
all_files = collect_files(user=user)
|
all_files = collect_files(user=user)
|
||||||
status = configure_content(
|
status = configure_content(
|
||||||
|
user,
|
||||||
all_files,
|
all_files,
|
||||||
regenerate,
|
regenerate,
|
||||||
search_type,
|
search_type,
|
||||||
user=user,
|
|
||||||
)
|
)
|
||||||
if not status:
|
if not status:
|
||||||
raise RuntimeError("Failed to update content index")
|
raise RuntimeError("Failed to update content index")
|
||||||
|
@ -338,9 +338,7 @@ def configure_middleware(app):
|
||||||
def update_content_index():
|
def update_content_index():
|
||||||
for user in get_all_users():
|
for user in get_all_users():
|
||||||
all_files = collect_files(user=user)
|
all_files = collect_files(user=user)
|
||||||
success = configure_content(all_files, user=user)
|
success = configure_content(user, all_files)
|
||||||
all_files = collect_files(user=None)
|
|
||||||
success = configure_content(all_files, user=None)
|
|
||||||
if not success:
|
if not success:
|
||||||
raise RuntimeError("Failed to update content index")
|
raise RuntimeError("Failed to update content index")
|
||||||
logger.info("📪 Content index updated via Scheduler")
|
logger.info("📪 Content index updated via Scheduler")
|
||||||
|
|
|
@ -8,13 +8,22 @@ import secrets
|
||||||
import sys
|
import sys
|
||||||
from datetime import date, datetime, timedelta, timezone
|
from datetime import date, datetime, timedelta, timezone
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Callable, Iterable, List, Optional, Type
|
from functools import wraps
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Coroutine,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
ParamSpec,
|
||||||
|
TypeVar,
|
||||||
|
)
|
||||||
|
|
||||||
import cron_descriptor
|
import cron_descriptor
|
||||||
from apscheduler.job import Job
|
from apscheduler.job import Job
|
||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
from django.contrib.sessions.backends.db import SessionStore
|
from django.contrib.sessions.backends.db import SessionStore
|
||||||
from django.db import models
|
|
||||||
from django.db.models import Prefetch, Q
|
from django.db.models import Prefetch, Q
|
||||||
from django.db.models.manager import BaseManager
|
from django.db.models.manager import BaseManager
|
||||||
from django.db.utils import IntegrityError
|
from django.db.utils import IntegrityError
|
||||||
|
@ -28,7 +37,6 @@ from khoj.database.models import (
|
||||||
ChatModelOptions,
|
ChatModelOptions,
|
||||||
ClientApplication,
|
ClientApplication,
|
||||||
Conversation,
|
Conversation,
|
||||||
DataStore,
|
|
||||||
Entry,
|
Entry,
|
||||||
FileObject,
|
FileObject,
|
||||||
GithubConfig,
|
GithubConfig,
|
||||||
|
@ -80,6 +88,45 @@ class SubscriptionState(Enum):
|
||||||
INVALID = "invalid"
|
INVALID = "invalid"
|
||||||
|
|
||||||
|
|
||||||
|
P = ParamSpec("P")
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
def require_valid_user(func: Callable[P, T]) -> Callable[P, T]:
|
||||||
|
@wraps(func)
|
||||||
|
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
||||||
|
# Extract user from args/kwargs
|
||||||
|
user = next((arg for arg in args if isinstance(arg, KhojUser)), None)
|
||||||
|
if not user:
|
||||||
|
user = next((val for val in kwargs.values() if isinstance(val, KhojUser)), None)
|
||||||
|
|
||||||
|
# Throw error if user is not found
|
||||||
|
if not user:
|
||||||
|
raise ValueError("Khoj user argument required but not provided.")
|
||||||
|
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
return sync_wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def arequire_valid_user(func: Callable[P, Coroutine[Any, Any, T]]) -> Callable[P, Coroutine[Any, Any, T]]:
|
||||||
|
@wraps(func)
|
||||||
|
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
||||||
|
# Extract user from args/kwargs
|
||||||
|
user = next((arg for arg in args if isinstance(arg, KhojUser)), None)
|
||||||
|
if not user:
|
||||||
|
user = next((v for v in kwargs.values() if isinstance(v, KhojUser)), None)
|
||||||
|
|
||||||
|
# Throw error if user is not found
|
||||||
|
if not user:
|
||||||
|
raise ValueError("Khoj user argument required but not provided.")
|
||||||
|
|
||||||
|
return await func(*args, **kwargs)
|
||||||
|
|
||||||
|
return async_wrapper
|
||||||
|
|
||||||
|
|
||||||
|
@arequire_valid_user
|
||||||
async def set_notion_config(token: str, user: KhojUser):
|
async def set_notion_config(token: str, user: KhojUser):
|
||||||
notion_config = await NotionConfig.objects.filter(user=user).afirst()
|
notion_config = await NotionConfig.objects.filter(user=user).afirst()
|
||||||
if not notion_config:
|
if not notion_config:
|
||||||
|
@ -90,6 +137,7 @@ async def set_notion_config(token: str, user: KhojUser):
|
||||||
return notion_config
|
return notion_config
|
||||||
|
|
||||||
|
|
||||||
|
@require_valid_user
|
||||||
def create_khoj_token(user: KhojUser, name=None):
|
def create_khoj_token(user: KhojUser, name=None):
|
||||||
"Create Khoj API key for user"
|
"Create Khoj API key for user"
|
||||||
token = f"kk-{secrets.token_urlsafe(32)}"
|
token = f"kk-{secrets.token_urlsafe(32)}"
|
||||||
|
@ -97,6 +145,7 @@ def create_khoj_token(user: KhojUser, name=None):
|
||||||
return KhojApiUser.objects.create(token=token, user=user, name=name)
|
return KhojApiUser.objects.create(token=token, user=user, name=name)
|
||||||
|
|
||||||
|
|
||||||
|
@arequire_valid_user
|
||||||
async def acreate_khoj_token(user: KhojUser, name=None):
|
async def acreate_khoj_token(user: KhojUser, name=None):
|
||||||
"Create Khoj API key for user"
|
"Create Khoj API key for user"
|
||||||
token = f"kk-{secrets.token_urlsafe(32)}"
|
token = f"kk-{secrets.token_urlsafe(32)}"
|
||||||
|
@ -104,11 +153,13 @@ async def acreate_khoj_token(user: KhojUser, name=None):
|
||||||
return await KhojApiUser.objects.acreate(token=token, user=user, name=name)
|
return await KhojApiUser.objects.acreate(token=token, user=user, name=name)
|
||||||
|
|
||||||
|
|
||||||
|
@require_valid_user
|
||||||
def get_khoj_tokens(user: KhojUser):
|
def get_khoj_tokens(user: KhojUser):
|
||||||
"Get all Khoj API keys for user"
|
"Get all Khoj API keys for user"
|
||||||
return list(KhojApiUser.objects.filter(user=user))
|
return list(KhojApiUser.objects.filter(user=user))
|
||||||
|
|
||||||
|
|
||||||
|
@arequire_valid_user
|
||||||
async def delete_khoj_token(user: KhojUser, token: str):
|
async def delete_khoj_token(user: KhojUser, token: str):
|
||||||
"Delete Khoj API Key for user"
|
"Delete Khoj API Key for user"
|
||||||
await KhojApiUser.objects.filter(token=token, user=user).adelete()
|
await KhojApiUser.objects.filter(token=token, user=user).adelete()
|
||||||
|
@ -132,6 +183,7 @@ async def aget_or_create_user_by_phone_number(phone_number: str) -> tuple[KhojUs
|
||||||
return user, is_new
|
return user, is_new
|
||||||
|
|
||||||
|
|
||||||
|
@arequire_valid_user
|
||||||
async def aset_user_phone_number(user: KhojUser, phone_number: str) -> KhojUser:
|
async def aset_user_phone_number(user: KhojUser, phone_number: str) -> KhojUser:
|
||||||
if is_none_or_empty(phone_number):
|
if is_none_or_empty(phone_number):
|
||||||
return None
|
return None
|
||||||
|
@ -155,6 +207,7 @@ async def aset_user_phone_number(user: KhojUser, phone_number: str) -> KhojUser:
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
@arequire_valid_user
|
||||||
async def aremove_phone_number(user: KhojUser) -> KhojUser:
|
async def aremove_phone_number(user: KhojUser) -> KhojUser:
|
||||||
user.phone_number = None
|
user.phone_number = None
|
||||||
user.verified_phone_number = False
|
user.verified_phone_number = False
|
||||||
|
@ -192,6 +245,7 @@ async def aget_or_create_user_by_email(email: str) -> tuple[KhojUser, bool]:
|
||||||
return user, is_new
|
return user, is_new
|
||||||
|
|
||||||
|
|
||||||
|
@arequire_valid_user
|
||||||
async def astart_trial_subscription(user: KhojUser) -> Subscription:
|
async def astart_trial_subscription(user: KhojUser) -> Subscription:
|
||||||
subscription = await Subscription.objects.filter(user=user).afirst()
|
subscription = await Subscription.objects.filter(user=user).afirst()
|
||||||
if not subscription:
|
if not subscription:
|
||||||
|
@ -246,6 +300,7 @@ async def create_user_by_google_token(token: dict) -> KhojUser:
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
@require_valid_user
|
||||||
def set_user_name(user: KhojUser, first_name: str, last_name: str) -> KhojUser:
|
def set_user_name(user: KhojUser, first_name: str, last_name: str) -> KhojUser:
|
||||||
user.first_name = first_name
|
user.first_name = first_name
|
||||||
user.last_name = last_name
|
user.last_name = last_name
|
||||||
|
@ -253,6 +308,7 @@ def set_user_name(user: KhojUser, first_name: str, last_name: str) -> KhojUser:
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
@require_valid_user
|
||||||
def get_user_name(user: KhojUser):
|
def get_user_name(user: KhojUser):
|
||||||
full_name = user.get_full_name()
|
full_name = user.get_full_name()
|
||||||
if not is_none_or_empty(full_name):
|
if not is_none_or_empty(full_name):
|
||||||
|
@ -264,6 +320,7 @@ def get_user_name(user: KhojUser):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@require_valid_user
|
||||||
def get_user_photo(user: KhojUser):
|
def get_user_photo(user: KhojUser):
|
||||||
google_profile: GoogleUser = GoogleUser.objects.filter(user=user).first()
|
google_profile: GoogleUser = GoogleUser.objects.filter(user=user).first()
|
||||||
if google_profile:
|
if google_profile:
|
||||||
|
@ -327,6 +384,7 @@ def get_user_subscription_state(email: str) -> str:
|
||||||
return subscription_to_state(user_subscription)
|
return subscription_to_state(user_subscription)
|
||||||
|
|
||||||
|
|
||||||
|
@arequire_valid_user
|
||||||
async def aget_user_subscription_state(user: KhojUser) -> str:
|
async def aget_user_subscription_state(user: KhojUser) -> str:
|
||||||
"""Get subscription state of user
|
"""Get subscription state of user
|
||||||
Valid state transitions: trial -> subscribed <-> unsubscribed OR expired
|
Valid state transitions: trial -> subscribed <-> unsubscribed OR expired
|
||||||
|
@ -335,6 +393,7 @@ async def aget_user_subscription_state(user: KhojUser) -> str:
|
||||||
return await sync_to_async(subscription_to_state)(user_subscription)
|
return await sync_to_async(subscription_to_state)(user_subscription)
|
||||||
|
|
||||||
|
|
||||||
|
@arequire_valid_user
|
||||||
async def ais_user_subscribed(user: KhojUser) -> bool:
|
async def ais_user_subscribed(user: KhojUser) -> bool:
|
||||||
"""
|
"""
|
||||||
Get whether the user is subscribed
|
Get whether the user is subscribed
|
||||||
|
@ -351,6 +410,7 @@ async def ais_user_subscribed(user: KhojUser) -> bool:
|
||||||
return subscribed
|
return subscribed
|
||||||
|
|
||||||
|
|
||||||
|
@require_valid_user
|
||||||
def is_user_subscribed(user: KhojUser) -> bool:
|
def is_user_subscribed(user: KhojUser) -> bool:
|
||||||
"""
|
"""
|
||||||
Get whether the user is subscribed
|
Get whether the user is subscribed
|
||||||
|
@ -416,11 +476,13 @@ def get_all_users() -> BaseManager[KhojUser]:
|
||||||
return KhojUser.objects.all()
|
return KhojUser.objects.all()
|
||||||
|
|
||||||
|
|
||||||
|
@require_valid_user
|
||||||
def get_user_github_config(user: KhojUser):
|
def get_user_github_config(user: KhojUser):
|
||||||
config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first()
|
config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first()
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
@require_valid_user
|
||||||
def get_user_notion_config(user: KhojUser):
|
def get_user_notion_config(user: KhojUser):
|
||||||
config = NotionConfig.objects.filter(user=user).first()
|
config = NotionConfig.objects.filter(user=user).first()
|
||||||
return config
|
return config
|
||||||
|
@ -430,6 +492,7 @@ def delete_user_requests(window: timedelta = timedelta(days=1)):
|
||||||
return UserRequests.objects.filter(created_at__lte=datetime.now(tz=timezone.utc) - window).delete()
|
return UserRequests.objects.filter(created_at__lte=datetime.now(tz=timezone.utc) - window).delete()
|
||||||
|
|
||||||
|
|
||||||
|
@arequire_valid_user
|
||||||
async def aget_user_name(user: KhojUser):
|
async def aget_user_name(user: KhojUser):
|
||||||
full_name = user.get_full_name()
|
full_name = user.get_full_name()
|
||||||
if not is_none_or_empty(full_name):
|
if not is_none_or_empty(full_name):
|
||||||
|
@ -441,18 +504,7 @@ async def aget_user_name(user: KhojUser):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def set_text_content_config(user: KhojUser, object: Type[models.Model], updated_config):
|
@arequire_valid_user
|
||||||
deduped_files = list(set(updated_config.input_files)) if updated_config.input_files else None
|
|
||||||
deduped_filters = list(set(updated_config.input_filter)) if updated_config.input_filter else None
|
|
||||||
await object.objects.filter(user=user).adelete()
|
|
||||||
await object.objects.acreate(
|
|
||||||
input_files=deduped_files,
|
|
||||||
input_filter=deduped_filters,
|
|
||||||
index_heading_entries=updated_config.index_heading_entries,
|
|
||||||
user=user,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def set_user_github_config(user: KhojUser, pat_token: str, repos: list):
|
async def set_user_github_config(user: KhojUser, pat_token: str, repos: list):
|
||||||
config = await GithubConfig.objects.filter(user=user).afirst()
|
config = await GithubConfig.objects.filter(user=user).afirst()
|
||||||
|
|
||||||
|
@ -587,8 +639,11 @@ class AgentAdapters:
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@arequire_valid_user
|
||||||
async def adelete_agent_by_slug(agent_slug: str, user: KhojUser):
|
async def adelete_agent_by_slug(agent_slug: str, user: KhojUser):
|
||||||
agent = await AgentAdapters.aget_agent_by_slug(agent_slug, user)
|
agent = await AgentAdapters.aget_agent_by_slug(agent_slug, user)
|
||||||
|
if agent.creator != user:
|
||||||
|
return False
|
||||||
|
|
||||||
async for entry in Entry.objects.filter(agent=agent).aiterator():
|
async for entry in Entry.objects.filter(agent=agent).aiterator():
|
||||||
await entry.adelete()
|
await entry.adelete()
|
||||||
|
@ -712,6 +767,7 @@ class AgentAdapters:
|
||||||
return await Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).afirst()
|
return await Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).afirst()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@arequire_valid_user
|
||||||
async def aupdate_agent(
|
async def aupdate_agent(
|
||||||
user: KhojUser,
|
user: KhojUser,
|
||||||
name: str,
|
name: str,
|
||||||
|
@ -787,19 +843,6 @@ class PublicConversationAdapters:
|
||||||
return f"/share/chat/{public_conversation.slug}/"
|
return f"/share/chat/{public_conversation.slug}/"
|
||||||
|
|
||||||
|
|
||||||
class DataStoreAdapters:
|
|
||||||
@staticmethod
|
|
||||||
async def astore_data(data: dict, key: str, user: KhojUser, private: bool = True):
|
|
||||||
if await DataStore.objects.filter(key=key).aexists():
|
|
||||||
return key
|
|
||||||
await DataStore.objects.acreate(value=data, key=key, owner=user, private=private)
|
|
||||||
return key
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def aretrieve_public_data(key: str):
|
|
||||||
return await DataStore.objects.filter(key=key, private=False).afirst()
|
|
||||||
|
|
||||||
|
|
||||||
class ConversationAdapters:
|
class ConversationAdapters:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def make_public_conversation_copy(conversation: Conversation):
|
def make_public_conversation_copy(conversation: Conversation):
|
||||||
|
@ -812,6 +855,7 @@ class ConversationAdapters:
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@require_valid_user
|
||||||
def get_conversation_by_user(
|
def get_conversation_by_user(
|
||||||
user: KhojUser, client_application: ClientApplication = None, conversation_id: str = None
|
user: KhojUser, client_application: ClientApplication = None, conversation_id: str = None
|
||||||
) -> Optional[Conversation]:
|
) -> Optional[Conversation]:
|
||||||
|
@ -830,6 +874,7 @@ class ConversationAdapters:
|
||||||
return conversation
|
return conversation
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@require_valid_user
|
||||||
def get_conversation_sessions(user: KhojUser, client_application: ClientApplication = None):
|
def get_conversation_sessions(user: KhojUser, client_application: ClientApplication = None):
|
||||||
return (
|
return (
|
||||||
Conversation.objects.filter(user=user, client=client_application)
|
Conversation.objects.filter(user=user, client=client_application)
|
||||||
|
@ -838,6 +883,7 @@ class ConversationAdapters:
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@arequire_valid_user
|
||||||
async def aset_conversation_title(
|
async def aset_conversation_title(
|
||||||
user: KhojUser, client_application: ClientApplication, conversation_id: str, title: str
|
user: KhojUser, client_application: ClientApplication, conversation_id: str, title: str
|
||||||
):
|
):
|
||||||
|
@ -855,6 +901,7 @@ class ConversationAdapters:
|
||||||
return Conversation.objects.filter(id=conversation_id).first()
|
return Conversation.objects.filter(id=conversation_id).first()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@arequire_valid_user
|
||||||
async def acreate_conversation_session(
|
async def acreate_conversation_session(
|
||||||
user: KhojUser, client_application: ClientApplication = None, agent_slug: str = None, title: str = None
|
user: KhojUser, client_application: ClientApplication = None, agent_slug: str = None, title: str = None
|
||||||
):
|
):
|
||||||
|
@ -871,6 +918,7 @@ class ConversationAdapters:
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@require_valid_user
|
||||||
def create_conversation_session(
|
def create_conversation_session(
|
||||||
user: KhojUser, client_application: ClientApplication = None, agent_slug: str = None, title: str = None
|
user: KhojUser, client_application: ClientApplication = None, agent_slug: str = None, title: str = None
|
||||||
):
|
):
|
||||||
|
@ -883,6 +931,7 @@ class ConversationAdapters:
|
||||||
return Conversation.objects.create(user=user, client=client_application, agent=agent, title=title)
|
return Conversation.objects.create(user=user, client=client_application, agent=agent, title=title)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@arequire_valid_user
|
||||||
async def aget_conversation_by_user(
|
async def aget_conversation_by_user(
|
||||||
user: KhojUser,
|
user: KhojUser,
|
||||||
client_application: ClientApplication = None,
|
client_application: ClientApplication = None,
|
||||||
|
@ -907,6 +956,7 @@ class ConversationAdapters:
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@arequire_valid_user
|
||||||
async def adelete_conversation_by_user(
|
async def adelete_conversation_by_user(
|
||||||
user: KhojUser, client_application: ClientApplication = None, conversation_id: str = None
|
user: KhojUser, client_application: ClientApplication = None, conversation_id: str = None
|
||||||
):
|
):
|
||||||
|
@ -915,6 +965,7 @@ class ConversationAdapters:
|
||||||
return await Conversation.objects.filter(user=user, client=client_application).adelete()
|
return await Conversation.objects.filter(user=user, client=client_application).adelete()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@require_valid_user
|
||||||
def has_any_conversation_config(user: KhojUser):
|
def has_any_conversation_config(user: KhojUser):
|
||||||
return ChatModelOptions.objects.filter(user=user).exists()
|
return ChatModelOptions.objects.filter(user=user).exists()
|
||||||
|
|
||||||
|
@ -951,6 +1002,7 @@ class ConversationAdapters:
|
||||||
return OpenAIProcessorConversationConfig.objects.filter().exists()
|
return OpenAIProcessorConversationConfig.objects.filter().exists()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@arequire_valid_user
|
||||||
async def aset_user_conversation_processor(user: KhojUser, conversation_processor_config_id: int):
|
async def aset_user_conversation_processor(user: KhojUser, conversation_processor_config_id: int):
|
||||||
config = await ChatModelOptions.objects.filter(id=conversation_processor_config_id).afirst()
|
config = await ChatModelOptions.objects.filter(id=conversation_processor_config_id).afirst()
|
||||||
if not config:
|
if not config:
|
||||||
|
@ -959,6 +1011,7 @@ class ConversationAdapters:
|
||||||
return new_config
|
return new_config
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@arequire_valid_user
|
||||||
async def aset_user_voice_model(user: KhojUser, model_id: str):
|
async def aset_user_voice_model(user: KhojUser, model_id: str):
|
||||||
config = await VoiceModelOption.objects.filter(model_id=model_id).afirst()
|
config = await VoiceModelOption.objects.filter(model_id=model_id).afirst()
|
||||||
if not config:
|
if not config:
|
||||||
|
@ -1143,6 +1196,7 @@ class ConversationAdapters:
|
||||||
return enabled_scrapers
|
return enabled_scrapers
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@require_valid_user
|
||||||
def create_conversation_from_public_conversation(
|
def create_conversation_from_public_conversation(
|
||||||
user: KhojUser, public_conversation: PublicConversation, client_app: ClientApplication
|
user: KhojUser, public_conversation: PublicConversation, client_app: ClientApplication
|
||||||
):
|
):
|
||||||
|
@ -1159,6 +1213,7 @@ class ConversationAdapters:
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@require_valid_user
|
||||||
def save_conversation(
|
def save_conversation(
|
||||||
user: KhojUser,
|
user: KhojUser,
|
||||||
conversation_log: dict,
|
conversation_log: dict,
|
||||||
|
@ -1208,6 +1263,7 @@ class ConversationAdapters:
|
||||||
return await SpeechToTextModelOptions.objects.filter().afirst()
|
return await SpeechToTextModelOptions.objects.filter().afirst()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@arequire_valid_user
|
||||||
async def aget_conversation_starters(user: KhojUser, max_results=3):
|
async def aget_conversation_starters(user: KhojUser, max_results=3):
|
||||||
all_questions = []
|
all_questions = []
|
||||||
if await ReflectiveQuestion.objects.filter(user=user).aexists():
|
if await ReflectiveQuestion.objects.filter(user=user).aexists():
|
||||||
|
@ -1337,6 +1393,7 @@ class ConversationAdapters:
|
||||||
return conversation.file_filters
|
return conversation.file_filters
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@require_valid_user
|
||||||
def delete_message_by_turn_id(user: KhojUser, conversation_id: str, turn_id: str):
|
def delete_message_by_turn_id(user: KhojUser, conversation_id: str, turn_id: str):
|
||||||
conversation = ConversationAdapters.get_conversation_by_user(user, conversation_id=conversation_id)
|
conversation = ConversationAdapters.get_conversation_by_user(user, conversation_id=conversation_id)
|
||||||
if not conversation or not conversation.conversation_log or not conversation.conversation_log.get("chat"):
|
if not conversation or not conversation.conversation_log or not conversation.conversation_log.get("chat"):
|
||||||
|
@ -1355,52 +1412,63 @@ class FileObjectAdapters:
|
||||||
file_object.save()
|
file_object.save()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@require_valid_user
|
||||||
def create_file_object(user: KhojUser, file_name: str, raw_text: str):
|
def create_file_object(user: KhojUser, file_name: str, raw_text: str):
|
||||||
return FileObject.objects.create(user=user, file_name=file_name, raw_text=raw_text)
|
return FileObject.objects.create(user=user, file_name=file_name, raw_text=raw_text)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@require_valid_user
|
||||||
def get_file_object_by_name(user: KhojUser, file_name: str):
|
def get_file_object_by_name(user: KhojUser, file_name: str):
|
||||||
return FileObject.objects.filter(user=user, file_name=file_name).first()
|
return FileObject.objects.filter(user=user, file_name=file_name).first()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@require_valid_user
|
||||||
def get_all_file_objects(user: KhojUser):
|
def get_all_file_objects(user: KhojUser):
|
||||||
return FileObject.objects.filter(user=user).all()
|
return FileObject.objects.filter(user=user).all()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@require_valid_user
|
||||||
def delete_file_object_by_name(user: KhojUser, file_name: str):
|
def delete_file_object_by_name(user: KhojUser, file_name: str):
|
||||||
return FileObject.objects.filter(user=user, file_name=file_name).delete()
|
return FileObject.objects.filter(user=user, file_name=file_name).delete()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@require_valid_user
|
||||||
def delete_all_file_objects(user: KhojUser):
|
def delete_all_file_objects(user: KhojUser):
|
||||||
return FileObject.objects.filter(user=user).delete()
|
return FileObject.objects.filter(user=user).delete()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def async_update_raw_text(file_object: FileObject, new_raw_text: str):
|
async def aupdate_raw_text(file_object: FileObject, new_raw_text: str):
|
||||||
file_object.raw_text = new_raw_text
|
file_object.raw_text = new_raw_text
|
||||||
await file_object.asave()
|
await file_object.asave()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def async_create_file_object(user: KhojUser, file_name: str, raw_text: str):
|
@arequire_valid_user
|
||||||
|
async def acreate_file_object(user: KhojUser, file_name: str, raw_text: str):
|
||||||
return await FileObject.objects.acreate(user=user, file_name=file_name, raw_text=raw_text)
|
return await FileObject.objects.acreate(user=user, file_name=file_name, raw_text=raw_text)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def async_get_file_objects_by_name(user: KhojUser, file_name: str, agent: Agent = None):
|
@arequire_valid_user
|
||||||
|
async def aget_file_objects_by_name(user: KhojUser, file_name: str, agent: Agent = None):
|
||||||
return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name=file_name, agent=agent))
|
return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name=file_name, agent=agent))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def async_get_file_objects_by_names(user: KhojUser, file_names: List[str]):
|
@arequire_valid_user
|
||||||
|
async def aget_file_objects_by_names(user: KhojUser, file_names: List[str]):
|
||||||
return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name__in=file_names))
|
return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name__in=file_names))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def async_get_all_file_objects(user: KhojUser):
|
@arequire_valid_user
|
||||||
|
async def aget_all_file_objects(user: KhojUser):
|
||||||
return await sync_to_async(list)(FileObject.objects.filter(user=user))
|
return await sync_to_async(list)(FileObject.objects.filter(user=user))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def async_delete_file_object_by_name(user: KhojUser, file_name: str):
|
@arequire_valid_user
|
||||||
|
async def adelete_file_object_by_name(user: KhojUser, file_name: str):
|
||||||
return await FileObject.objects.filter(user=user, file_name=file_name).adelete()
|
return await FileObject.objects.filter(user=user, file_name=file_name).adelete()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def async_delete_all_file_objects(user: KhojUser):
|
@arequire_valid_user
|
||||||
|
async def adelete_all_file_objects(user: KhojUser):
|
||||||
return await FileObject.objects.filter(user=user).adelete()
|
return await FileObject.objects.filter(user=user).adelete()
|
||||||
|
|
||||||
|
|
||||||
|
@ -1410,15 +1478,18 @@ class EntryAdapters:
|
||||||
date_filter = DateFilter()
|
date_filter = DateFilter()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@require_valid_user
|
||||||
def does_entry_exist(user: KhojUser, hashed_value: str) -> bool:
|
def does_entry_exist(user: KhojUser, hashed_value: str) -> bool:
|
||||||
return Entry.objects.filter(user=user, hashed_value=hashed_value).exists()
|
return Entry.objects.filter(user=user, hashed_value=hashed_value).exists()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@require_valid_user
|
||||||
def delete_entry_by_file(user: KhojUser, file_path: str):
|
def delete_entry_by_file(user: KhojUser, file_path: str):
|
||||||
deleted_count, _ = Entry.objects.filter(user=user, file_path=file_path).delete()
|
deleted_count, _ = Entry.objects.filter(user=user, file_path=file_path).delete()
|
||||||
return deleted_count
|
return deleted_count
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@require_valid_user
|
||||||
def get_filtered_entries(user: KhojUser, file_type: str = None, file_source: str = None):
|
def get_filtered_entries(user: KhojUser, file_type: str = None, file_source: str = None):
|
||||||
queryset = Entry.objects.filter(user=user)
|
queryset = Entry.objects.filter(user=user)
|
||||||
|
|
||||||
|
@ -1431,6 +1502,7 @@ class EntryAdapters:
|
||||||
return queryset
|
return queryset
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@require_valid_user
|
||||||
def delete_all_entries(user: KhojUser, file_type: str = None, file_source: str = None, batch_size=1000):
|
def delete_all_entries(user: KhojUser, file_type: str = None, file_source: str = None, batch_size=1000):
|
||||||
deleted_count = 0
|
deleted_count = 0
|
||||||
queryset = EntryAdapters.get_filtered_entries(user, file_type, file_source)
|
queryset = EntryAdapters.get_filtered_entries(user, file_type, file_source)
|
||||||
|
@ -1442,6 +1514,7 @@ class EntryAdapters:
|
||||||
return deleted_count
|
return deleted_count
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@arequire_valid_user
|
||||||
async def adelete_all_entries(user: KhojUser, file_type: str = None, file_source: str = None, batch_size=1000):
|
async def adelete_all_entries(user: KhojUser, file_type: str = None, file_source: str = None, batch_size=1000):
|
||||||
deleted_count = 0
|
deleted_count = 0
|
||||||
queryset = EntryAdapters.get_filtered_entries(user, file_type, file_source)
|
queryset = EntryAdapters.get_filtered_entries(user, file_type, file_source)
|
||||||
|
@ -1453,10 +1526,12 @@ class EntryAdapters:
|
||||||
return deleted_count
|
return deleted_count
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@require_valid_user
|
||||||
def get_existing_entry_hashes_by_file(user: KhojUser, file_path: str):
|
def get_existing_entry_hashes_by_file(user: KhojUser, file_path: str):
|
||||||
return Entry.objects.filter(user=user, file_path=file_path).values_list("hashed_value", flat=True)
|
return Entry.objects.filter(user=user, file_path=file_path).values_list("hashed_value", flat=True)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@require_valid_user
|
||||||
def delete_entry_by_hash(user: KhojUser, hashed_values: List[str]):
|
def delete_entry_by_hash(user: KhojUser, hashed_values: List[str]):
|
||||||
Entry.objects.filter(user=user, hashed_value__in=hashed_values).delete()
|
Entry.objects.filter(user=user, hashed_value__in=hashed_values).delete()
|
||||||
|
|
||||||
|
@ -1468,6 +1543,7 @@ class EntryAdapters:
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@require_valid_user
|
||||||
def user_has_entries(user: KhojUser):
|
def user_has_entries(user: KhojUser):
|
||||||
return Entry.objects.filter(user=user).exists()
|
return Entry.objects.filter(user=user).exists()
|
||||||
|
|
||||||
|
@ -1476,6 +1552,7 @@ class EntryAdapters:
|
||||||
return Entry.objects.filter(agent=agent).exists()
|
return Entry.objects.filter(agent=agent).exists()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@arequire_valid_user
|
||||||
async def auser_has_entries(user: KhojUser):
|
async def auser_has_entries(user: KhojUser):
|
||||||
return await Entry.objects.filter(user=user).aexists()
|
return await Entry.objects.filter(user=user).aexists()
|
||||||
|
|
||||||
|
@ -1486,10 +1563,12 @@ class EntryAdapters:
|
||||||
return await Entry.objects.filter(agent=agent).aexists()
|
return await Entry.objects.filter(agent=agent).aexists()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@arequire_valid_user
|
||||||
async def adelete_entry_by_file(user: KhojUser, file_path: str):
|
async def adelete_entry_by_file(user: KhojUser, file_path: str):
|
||||||
return await Entry.objects.filter(user=user, file_path=file_path).adelete()
|
return await Entry.objects.filter(user=user, file_path=file_path).adelete()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@arequire_valid_user
|
||||||
async def adelete_entries_by_filenames(user: KhojUser, filenames: List[str], batch_size=1000):
|
async def adelete_entries_by_filenames(user: KhojUser, filenames: List[str], batch_size=1000):
|
||||||
deleted_count = 0
|
deleted_count = 0
|
||||||
for i in range(0, len(filenames), batch_size):
|
for i in range(0, len(filenames), batch_size):
|
||||||
|
@ -1508,6 +1587,7 @@ class EntryAdapters:
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@require_valid_user
|
||||||
def get_all_filenames_by_source(user: KhojUser, file_source: str):
|
def get_all_filenames_by_source(user: KhojUser, file_source: str):
|
||||||
return (
|
return (
|
||||||
Entry.objects.filter(user=user, file_source=file_source)
|
Entry.objects.filter(user=user, file_source=file_source)
|
||||||
|
@ -1516,6 +1596,7 @@ class EntryAdapters:
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@require_valid_user
|
||||||
def get_size_of_indexed_data_in_mb(user: KhojUser):
|
def get_size_of_indexed_data_in_mb(user: KhojUser):
|
||||||
entries = Entry.objects.filter(user=user).iterator()
|
entries = Entry.objects.filter(user=user).iterator()
|
||||||
total_size = sum(sys.getsizeof(entry.compiled) for entry in entries)
|
total_size = sum(sys.getsizeof(entry.compiled) for entry in entries)
|
||||||
|
@ -1536,6 +1617,9 @@ class EntryAdapters:
|
||||||
if agent != None:
|
if agent != None:
|
||||||
owner_filter |= Q(agent=agent)
|
owner_filter |= Q(agent=agent)
|
||||||
|
|
||||||
|
if owner_filter == Q():
|
||||||
|
return Entry.objects.none()
|
||||||
|
|
||||||
if len(word_filters) == 0 and len(file_filters) == 0 and len(date_filters) == 0:
|
if len(word_filters) == 0 and len(file_filters) == 0 and len(date_filters) == 0:
|
||||||
return Entry.objects.filter(owner_filter)
|
return Entry.objects.filter(owner_filter)
|
||||||
|
|
||||||
|
@ -1610,10 +1694,12 @@ class EntryAdapters:
|
||||||
return relevant_entries[:max_results]
|
return relevant_entries[:max_results]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@require_valid_user
|
||||||
def get_unique_file_types(user: KhojUser):
|
def get_unique_file_types(user: KhojUser):
|
||||||
return Entry.objects.filter(user=user).values_list("file_type", flat=True).distinct()
|
return Entry.objects.filter(user=user).values_list("file_type", flat=True).distinct()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@require_valid_user
|
||||||
def get_unique_file_sources(user: KhojUser):
|
def get_unique_file_sources(user: KhojUser):
|
||||||
return Entry.objects.filter(user=user).values_list("file_source", flat=True).distinct().all()
|
return Entry.objects.filter(user=user).values_list("file_source", flat=True).distinct().all()
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,7 @@ class DocxToEntries(TextToEntries):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# Define Functions
|
# Define Functions
|
||||||
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
|
def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
|
||||||
# Extract required fields from config
|
# Extract required fields from config
|
||||||
deletion_file_names = set([file for file in files if files[file] == b""])
|
deletion_file_names = set([file for file in files if files[file] == b""])
|
||||||
files_to_process = set(files) - deletion_file_names
|
files_to_process = set(files) - deletion_file_names
|
||||||
|
@ -35,13 +35,13 @@ class DocxToEntries(TextToEntries):
|
||||||
# Identify, mark and merge any new entries with previous entries
|
# Identify, mark and merge any new entries with previous entries
|
||||||
with timer("Identify new or updated entries", logger):
|
with timer("Identify new or updated entries", logger):
|
||||||
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||||
|
user,
|
||||||
current_entries,
|
current_entries,
|
||||||
DbEntry.EntryType.DOCX,
|
DbEntry.EntryType.DOCX,
|
||||||
DbEntry.EntrySource.COMPUTER,
|
DbEntry.EntrySource.COMPUTER,
|
||||||
"compiled",
|
"compiled",
|
||||||
logger,
|
logger,
|
||||||
deletion_file_names,
|
deletion_file_names,
|
||||||
user,
|
|
||||||
regenerate=regenerate,
|
regenerate=regenerate,
|
||||||
file_to_text_map=file_to_text_map,
|
file_to_text_map=file_to_text_map,
|
||||||
)
|
)
|
||||||
|
|
|
@ -48,7 +48,7 @@ class GithubToEntries(TextToEntries):
|
||||||
else:
|
else:
|
||||||
return
|
return
|
||||||
|
|
||||||
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
|
def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
|
||||||
if self.config.pat_token is None or self.config.pat_token == "":
|
if self.config.pat_token is None or self.config.pat_token == "":
|
||||||
logger.error(f"Github PAT token is not set. Skipping github content")
|
logger.error(f"Github PAT token is not set. Skipping github content")
|
||||||
raise ValueError("Github PAT token is not set. Skipping github content")
|
raise ValueError("Github PAT token is not set. Skipping github content")
|
||||||
|
@ -101,12 +101,12 @@ class GithubToEntries(TextToEntries):
|
||||||
# Identify, mark and merge any new entries with previous entries
|
# Identify, mark and merge any new entries with previous entries
|
||||||
with timer("Identify new or updated entries", logger):
|
with timer("Identify new or updated entries", logger):
|
||||||
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||||
|
user,
|
||||||
current_entries,
|
current_entries,
|
||||||
DbEntry.EntryType.GITHUB,
|
DbEntry.EntryType.GITHUB,
|
||||||
DbEntry.EntrySource.GITHUB,
|
DbEntry.EntrySource.GITHUB,
|
||||||
key="compiled",
|
key="compiled",
|
||||||
logger=logger,
|
logger=logger,
|
||||||
user=user,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return num_new_embeddings, num_deleted_embeddings
|
return num_new_embeddings, num_deleted_embeddings
|
||||||
|
|
|
@ -18,7 +18,7 @@ class ImageToEntries(TextToEntries):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# Define Functions
|
# Define Functions
|
||||||
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
|
def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
|
||||||
# Extract required fields from config
|
# Extract required fields from config
|
||||||
deletion_file_names = set([file for file in files if files[file] == b""])
|
deletion_file_names = set([file for file in files if files[file] == b""])
|
||||||
files_to_process = set(files) - deletion_file_names
|
files_to_process = set(files) - deletion_file_names
|
||||||
|
@ -35,13 +35,13 @@ class ImageToEntries(TextToEntries):
|
||||||
# Identify, mark and merge any new entries with previous entries
|
# Identify, mark and merge any new entries with previous entries
|
||||||
with timer("Identify new or updated entries", logger):
|
with timer("Identify new or updated entries", logger):
|
||||||
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||||
|
user,
|
||||||
current_entries,
|
current_entries,
|
||||||
DbEntry.EntryType.IMAGE,
|
DbEntry.EntryType.IMAGE,
|
||||||
DbEntry.EntrySource.COMPUTER,
|
DbEntry.EntrySource.COMPUTER,
|
||||||
"compiled",
|
"compiled",
|
||||||
logger,
|
logger,
|
||||||
deletion_file_names,
|
deletion_file_names,
|
||||||
user,
|
|
||||||
regenerate=regenerate,
|
regenerate=regenerate,
|
||||||
file_to_text_map=file_to_text_map,
|
file_to_text_map=file_to_text_map,
|
||||||
)
|
)
|
||||||
|
|
|
@ -19,7 +19,7 @@ class MarkdownToEntries(TextToEntries):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# Define Functions
|
# Define Functions
|
||||||
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
|
def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
|
||||||
# Extract required fields from config
|
# Extract required fields from config
|
||||||
deletion_file_names = set([file for file in files if files[file] == ""])
|
deletion_file_names = set([file for file in files if files[file] == ""])
|
||||||
files_to_process = set(files) - deletion_file_names
|
files_to_process = set(files) - deletion_file_names
|
||||||
|
@ -37,13 +37,13 @@ class MarkdownToEntries(TextToEntries):
|
||||||
# Identify, mark and merge any new entries with previous entries
|
# Identify, mark and merge any new entries with previous entries
|
||||||
with timer("Identify new or updated entries", logger):
|
with timer("Identify new or updated entries", logger):
|
||||||
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||||
|
user,
|
||||||
current_entries,
|
current_entries,
|
||||||
DbEntry.EntryType.MARKDOWN,
|
DbEntry.EntryType.MARKDOWN,
|
||||||
DbEntry.EntrySource.COMPUTER,
|
DbEntry.EntrySource.COMPUTER,
|
||||||
"compiled",
|
"compiled",
|
||||||
logger,
|
logger,
|
||||||
deletion_file_names,
|
deletion_file_names,
|
||||||
user,
|
|
||||||
regenerate=regenerate,
|
regenerate=regenerate,
|
||||||
file_to_text_map=file_to_text_map,
|
file_to_text_map=file_to_text_map,
|
||||||
)
|
)
|
||||||
|
|
|
@ -79,7 +79,7 @@ class NotionToEntries(TextToEntries):
|
||||||
|
|
||||||
self.body_params = {"page_size": 100}
|
self.body_params = {"page_size": 100}
|
||||||
|
|
||||||
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
|
def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
|
||||||
current_entries = []
|
current_entries = []
|
||||||
|
|
||||||
# Get all pages
|
# Get all pages
|
||||||
|
@ -248,12 +248,12 @@ class NotionToEntries(TextToEntries):
|
||||||
# Identify, mark and merge any new entries with previous entries
|
# Identify, mark and merge any new entries with previous entries
|
||||||
with timer("Identify new or updated entries", logger):
|
with timer("Identify new or updated entries", logger):
|
||||||
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||||
|
user,
|
||||||
current_entries,
|
current_entries,
|
||||||
DbEntry.EntryType.NOTION,
|
DbEntry.EntryType.NOTION,
|
||||||
DbEntry.EntrySource.NOTION,
|
DbEntry.EntrySource.NOTION,
|
||||||
key="compiled",
|
key="compiled",
|
||||||
logger=logger,
|
logger=logger,
|
||||||
user=user,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return num_new_embeddings, num_deleted_embeddings
|
return num_new_embeddings, num_deleted_embeddings
|
||||||
|
|
|
@ -20,7 +20,7 @@ class OrgToEntries(TextToEntries):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# Define Functions
|
# Define Functions
|
||||||
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
|
def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
|
||||||
deletion_file_names = set([file for file in files if files[file] == ""])
|
deletion_file_names = set([file for file in files if files[file] == ""])
|
||||||
files_to_process = set(files) - deletion_file_names
|
files_to_process = set(files) - deletion_file_names
|
||||||
files = {file: files[file] for file in files_to_process}
|
files = {file: files[file] for file in files_to_process}
|
||||||
|
@ -36,13 +36,13 @@ class OrgToEntries(TextToEntries):
|
||||||
# Identify, mark and merge any new entries with previous entries
|
# Identify, mark and merge any new entries with previous entries
|
||||||
with timer("Identify new or updated entries", logger):
|
with timer("Identify new or updated entries", logger):
|
||||||
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||||
|
user,
|
||||||
current_entries,
|
current_entries,
|
||||||
DbEntry.EntryType.ORG,
|
DbEntry.EntryType.ORG,
|
||||||
DbEntry.EntrySource.COMPUTER,
|
DbEntry.EntrySource.COMPUTER,
|
||||||
"compiled",
|
"compiled",
|
||||||
logger,
|
logger,
|
||||||
deletion_file_names,
|
deletion_file_names,
|
||||||
user,
|
|
||||||
regenerate=regenerate,
|
regenerate=regenerate,
|
||||||
file_to_text_map=file_to_text_map,
|
file_to_text_map=file_to_text_map,
|
||||||
)
|
)
|
||||||
|
|
|
@ -19,7 +19,7 @@ class PdfToEntries(TextToEntries):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# Define Functions
|
# Define Functions
|
||||||
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
|
def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
|
||||||
# Extract required fields from config
|
# Extract required fields from config
|
||||||
deletion_file_names = set([file for file in files if files[file] == b""])
|
deletion_file_names = set([file for file in files if files[file] == b""])
|
||||||
files_to_process = set(files) - deletion_file_names
|
files_to_process = set(files) - deletion_file_names
|
||||||
|
@ -36,13 +36,13 @@ class PdfToEntries(TextToEntries):
|
||||||
# Identify, mark and merge any new entries with previous entries
|
# Identify, mark and merge any new entries with previous entries
|
||||||
with timer("Identify new or updated entries", logger):
|
with timer("Identify new or updated entries", logger):
|
||||||
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||||
|
user,
|
||||||
current_entries,
|
current_entries,
|
||||||
DbEntry.EntryType.PDF,
|
DbEntry.EntryType.PDF,
|
||||||
DbEntry.EntrySource.COMPUTER,
|
DbEntry.EntrySource.COMPUTER,
|
||||||
"compiled",
|
"compiled",
|
||||||
logger,
|
logger,
|
||||||
deletion_file_names,
|
deletion_file_names,
|
||||||
user,
|
|
||||||
regenerate=regenerate,
|
regenerate=regenerate,
|
||||||
file_to_text_map=file_to_text_map,
|
file_to_text_map=file_to_text_map,
|
||||||
)
|
)
|
||||||
|
|
|
@ -20,7 +20,7 @@ class PlaintextToEntries(TextToEntries):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# Define Functions
|
# Define Functions
|
||||||
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
|
def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
|
||||||
deletion_file_names = set([file for file in files if files[file] == ""])
|
deletion_file_names = set([file for file in files if files[file] == ""])
|
||||||
files_to_process = set(files) - deletion_file_names
|
files_to_process = set(files) - deletion_file_names
|
||||||
files = {file: files[file] for file in files_to_process}
|
files = {file: files[file] for file in files_to_process}
|
||||||
|
@ -36,13 +36,13 @@ class PlaintextToEntries(TextToEntries):
|
||||||
# Identify, mark and merge any new entries with previous entries
|
# Identify, mark and merge any new entries with previous entries
|
||||||
with timer("Identify new or updated entries", logger):
|
with timer("Identify new or updated entries", logger):
|
||||||
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||||
|
user,
|
||||||
current_entries,
|
current_entries,
|
||||||
DbEntry.EntryType.PLAINTEXT,
|
DbEntry.EntryType.PLAINTEXT,
|
||||||
DbEntry.EntrySource.COMPUTER,
|
DbEntry.EntrySource.COMPUTER,
|
||||||
key="compiled",
|
key="compiled",
|
||||||
logger=logger,
|
logger=logger,
|
||||||
deletion_filenames=deletion_file_names,
|
deletion_filenames=deletion_file_names,
|
||||||
user=user,
|
|
||||||
regenerate=regenerate,
|
regenerate=regenerate,
|
||||||
file_to_text_map=file_to_text_map,
|
file_to_text_map=file_to_text_map,
|
||||||
)
|
)
|
||||||
|
|
|
@ -31,7 +31,7 @@ class TextToEntries(ABC):
|
||||||
self.date_filter = DateFilter()
|
self.date_filter = DateFilter()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
|
def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
|
||||||
...
|
...
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -114,13 +114,13 @@ class TextToEntries(ABC):
|
||||||
|
|
||||||
def update_embeddings(
|
def update_embeddings(
|
||||||
self,
|
self,
|
||||||
|
user: KhojUser,
|
||||||
current_entries: List[Entry],
|
current_entries: List[Entry],
|
||||||
file_type: str,
|
file_type: str,
|
||||||
file_source: str,
|
file_source: str,
|
||||||
key="compiled",
|
key="compiled",
|
||||||
logger: logging.Logger = None,
|
logger: logging.Logger = None,
|
||||||
deletion_filenames: Set[str] = None,
|
deletion_filenames: Set[str] = None,
|
||||||
user: KhojUser = None,
|
|
||||||
regenerate: bool = False,
|
regenerate: bool = False,
|
||||||
file_to_text_map: dict[str, str] = None,
|
file_to_text_map: dict[str, str] = None,
|
||||||
):
|
):
|
||||||
|
|
|
@ -212,7 +212,7 @@ def update(
|
||||||
logger.warning(error_msg)
|
logger.warning(error_msg)
|
||||||
raise HTTPException(status_code=500, detail=error_msg)
|
raise HTTPException(status_code=500, detail=error_msg)
|
||||||
try:
|
try:
|
||||||
initialize_content(regenerate=force, search_type=t, user=user)
|
initialize_content(user=user, regenerate=force, search_type=t)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"🚨 Failed to update server via API: {e}"
|
error_msg = f"🚨 Failed to update server via API: {e}"
|
||||||
logger.error(error_msg, exc_info=True)
|
logger.error(error_msg, exc_info=True)
|
||||||
|
|
|
@ -239,7 +239,7 @@ async def set_content_notion(
|
||||||
|
|
||||||
if updated_config.token:
|
if updated_config.token:
|
||||||
# Trigger an async job to configure_content. Let it run without blocking the response.
|
# Trigger an async job to configure_content. Let it run without blocking the response.
|
||||||
background_tasks.add_task(run_in_executor, configure_content, {}, False, SearchType.Notion, user)
|
background_tasks.add_task(run_in_executor, configure_content, user, {}, False, SearchType.Notion)
|
||||||
|
|
||||||
update_telemetry_state(
|
update_telemetry_state(
|
||||||
request=request,
|
request=request,
|
||||||
|
@ -512,10 +512,10 @@ async def indexer(
|
||||||
success = await loop.run_in_executor(
|
success = await loop.run_in_executor(
|
||||||
None,
|
None,
|
||||||
configure_content,
|
configure_content,
|
||||||
|
user,
|
||||||
indexer_input.model_dump(),
|
indexer_input.model_dump(),
|
||||||
regenerate,
|
regenerate,
|
||||||
t,
|
t,
|
||||||
user,
|
|
||||||
)
|
)
|
||||||
if not success:
|
if not success:
|
||||||
raise RuntimeError(f"Failed to {method} {t} data sent by {client} client into content index")
|
raise RuntimeError(f"Failed to {method} {t} data sent by {client} client into content index")
|
||||||
|
|
|
@ -703,7 +703,7 @@ async def generate_summary_from_files(
|
||||||
if await EntryAdapters.aagent_has_entries(agent):
|
if await EntryAdapters.aagent_has_entries(agent):
|
||||||
file_names = await EntryAdapters.aget_agent_entry_filepaths(agent)
|
file_names = await EntryAdapters.aget_agent_entry_filepaths(agent)
|
||||||
if len(file_names) > 0:
|
if len(file_names) > 0:
|
||||||
file_objects = await FileObjectAdapters.async_get_file_objects_by_name(None, file_names.pop(), agent)
|
file_objects = await FileObjectAdapters.aget_file_objects_by_name(None, file_names.pop(), agent)
|
||||||
|
|
||||||
if (file_objects and len(file_objects) == 0 and not query_files) or (not file_objects and not query_files):
|
if (file_objects and len(file_objects) == 0 and not query_files) or (not file_objects and not query_files):
|
||||||
response_log = "Sorry, I couldn't find anything to summarize."
|
response_log = "Sorry, I couldn't find anything to summarize."
|
||||||
|
@ -1975,10 +1975,10 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
|
||||||
|
|
||||||
|
|
||||||
def configure_content(
|
def configure_content(
|
||||||
|
user: KhojUser,
|
||||||
files: Optional[dict[str, dict[str, str]]],
|
files: Optional[dict[str, dict[str, str]]],
|
||||||
regenerate: bool = False,
|
regenerate: bool = False,
|
||||||
t: Optional[state.SearchType] = state.SearchType.All,
|
t: Optional[state.SearchType] = state.SearchType.All,
|
||||||
user: KhojUser = None,
|
|
||||||
) -> bool:
|
) -> bool:
|
||||||
success = True
|
success = True
|
||||||
if t == None:
|
if t == None:
|
||||||
|
|
|
@ -80,6 +80,6 @@ async def notion_auth_callback(request: Request, background_tasks: BackgroundTas
|
||||||
notion_redirect = str(request.app.url_path_for("config_page"))
|
notion_redirect = str(request.app.url_path_for("config_page"))
|
||||||
|
|
||||||
# Trigger an async job to configure_content. Let it run without blocking the response.
|
# Trigger an async job to configure_content. Let it run without blocking the response.
|
||||||
background_tasks.add_task(run_in_executor, configure_content, {}, False, SearchType.Notion, user)
|
background_tasks.add_task(run_in_executor, configure_content, user, {}, False, SearchType.Notion)
|
||||||
|
|
||||||
return RedirectResponse(notion_redirect)
|
return RedirectResponse(notion_redirect)
|
||||||
|
|
|
@ -208,7 +208,7 @@ def setup(
|
||||||
text_to_entries: Type[TextToEntries],
|
text_to_entries: Type[TextToEntries],
|
||||||
files: dict[str, str],
|
files: dict[str, str],
|
||||||
regenerate: bool,
|
regenerate: bool,
|
||||||
user: KhojUser = None,
|
user: KhojUser,
|
||||||
config=None,
|
config=None,
|
||||||
) -> Tuple[int, int]:
|
) -> Tuple[int, int]:
|
||||||
if config:
|
if config:
|
||||||
|
|
|
@ -8,6 +8,7 @@ from bs4 import BeautifulSoup
|
||||||
from magika import Magika
|
from magika import Magika
|
||||||
|
|
||||||
from khoj.database.models import (
|
from khoj.database.models import (
|
||||||
|
KhojUser,
|
||||||
LocalMarkdownConfig,
|
LocalMarkdownConfig,
|
||||||
LocalOrgConfig,
|
LocalOrgConfig,
|
||||||
LocalPdfConfig,
|
LocalPdfConfig,
|
||||||
|
@ -21,7 +22,7 @@ logger = logging.getLogger(__name__)
|
||||||
magika = Magika()
|
magika = Magika()
|
||||||
|
|
||||||
|
|
||||||
def collect_files(search_type: Optional[SearchType] = SearchType.All, user=None) -> dict:
|
def collect_files(user: KhojUser, search_type: Optional[SearchType] = SearchType.All) -> dict:
|
||||||
files: dict[str, dict] = {"docx": {}, "image": {}}
|
files: dict[str, dict] = {"docx": {}, "image": {}}
|
||||||
|
|
||||||
if search_type == SearchType.All or search_type == SearchType.Org:
|
if search_type == SearchType.All or search_type == SearchType.Org:
|
||||||
|
|
|
@ -304,7 +304,7 @@ def chat_client_builder(search_config, user, index_content=True, require_auth=Fa
|
||||||
|
|
||||||
# Index Markdown Content for Search
|
# Index Markdown Content for Search
|
||||||
all_files = fs_syncer.collect_files(user=user)
|
all_files = fs_syncer.collect_files(user=user)
|
||||||
success = configure_content(all_files, user=user)
|
configure_content(user, all_files)
|
||||||
|
|
||||||
# Initialize Processor from Config
|
# Initialize Processor from Config
|
||||||
if os.getenv("OPENAI_API_KEY"):
|
if os.getenv("OPENAI_API_KEY"):
|
||||||
|
@ -381,7 +381,7 @@ def client_offline_chat(search_config: SearchConfig, default_user2: KhojUser):
|
||||||
)
|
)
|
||||||
|
|
||||||
all_files = fs_syncer.collect_files(user=default_user2)
|
all_files = fs_syncer.collect_files(user=default_user2)
|
||||||
configure_content(all_files, user=default_user2)
|
configure_content(default_user2, all_files)
|
||||||
|
|
||||||
# Initialize Processor from Config
|
# Initialize Processor from Config
|
||||||
ChatModelOptionsFactory(
|
ChatModelOptionsFactory(
|
||||||
|
@ -432,7 +432,7 @@ def pdf_configured_user1(default_user: KhojUser):
|
||||||
)
|
)
|
||||||
# Index Markdown Content for Search
|
# Index Markdown Content for Search
|
||||||
all_files = fs_syncer.collect_files(user=default_user)
|
all_files = fs_syncer.collect_files(user=default_user)
|
||||||
success = configure_content(all_files, user=default_user)
|
configure_content(default_user, all_files)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
@pytest.fixture(scope="function")
|
||||||
|
|
|
@ -253,11 +253,11 @@ def test_regenerate_with_github_fails_without_pat(client):
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
@pytest.mark.django_db
|
@pytest.mark.django_db
|
||||||
def test_get_configured_types_via_api(client, sample_org_data):
|
def test_get_configured_types_via_api(client, sample_org_data, default_user3: KhojUser):
|
||||||
# Act
|
# Act
|
||||||
text_search.setup(OrgToEntries, sample_org_data, regenerate=False)
|
text_search.setup(OrgToEntries, sample_org_data, regenerate=False, user=default_user3)
|
||||||
|
|
||||||
enabled_types = EntryAdapters.get_unique_file_types(user=None).all().values_list("file_type", flat=True)
|
enabled_types = EntryAdapters.get_unique_file_types(user=default_user3).all().values_list("file_type", flat=True)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert list(enabled_types) == ["org"]
|
assert list(enabled_types) == ["org"]
|
||||||
|
|
Loading…
Reference in a new issue