Do not CRUD on entries, files & conversations in DB for null user (#958)

Increase defense-in-depth by reducing paths to create, read, update or delete entries, files and conversations in DB when user is unset.
This commit is contained in:
Debanjum 2024-11-11 12:47:22 -08:00 committed by GitHub
commit ba2471dc02
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 158 additions and 73 deletions

View file

@ -253,7 +253,7 @@ def configure_server(
logger.info(message) logger.info(message)
if not init: if not init:
initialize_content(regenerate, search_type, user) initialize_content(user, regenerate, search_type)
except Exception as e: except Exception as e:
logger.error(f"Failed to load some search models: {e}", exc_info=True) logger.error(f"Failed to load some search models: {e}", exc_info=True)
@ -263,17 +263,17 @@ def setup_default_agent(user: KhojUser):
AgentAdapters.create_default_agent(user) AgentAdapters.create_default_agent(user)
def initialize_content(regenerate: bool, search_type: Optional[SearchType] = None, user: KhojUser = None): def initialize_content(user: KhojUser, regenerate: bool, search_type: Optional[SearchType] = None):
# Initialize Content from Config # Initialize Content from Config
if state.search_models: if state.search_models:
try: try:
logger.info("📬 Updating content index...") logger.info("📬 Updating content index...")
all_files = collect_files(user=user) all_files = collect_files(user=user)
status = configure_content( status = configure_content(
user,
all_files, all_files,
regenerate, regenerate,
search_type, search_type,
user=user,
) )
if not status: if not status:
raise RuntimeError("Failed to update content index") raise RuntimeError("Failed to update content index")
@ -338,9 +338,7 @@ def configure_middleware(app):
def update_content_index(): def update_content_index():
for user in get_all_users(): for user in get_all_users():
all_files = collect_files(user=user) all_files = collect_files(user=user)
success = configure_content(all_files, user=user) success = configure_content(user, all_files)
all_files = collect_files(user=None)
success = configure_content(all_files, user=None)
if not success: if not success:
raise RuntimeError("Failed to update content index") raise RuntimeError("Failed to update content index")
logger.info("📪 Content index updated via Scheduler") logger.info("📪 Content index updated via Scheduler")

View file

@ -8,13 +8,22 @@ import secrets
import sys import sys
from datetime import date, datetime, timedelta, timezone from datetime import date, datetime, timedelta, timezone
from enum import Enum from enum import Enum
from typing import Callable, Iterable, List, Optional, Type from functools import wraps
from typing import (
Any,
Callable,
Coroutine,
Iterable,
List,
Optional,
ParamSpec,
TypeVar,
)
import cron_descriptor import cron_descriptor
from apscheduler.job import Job from apscheduler.job import Job
from asgiref.sync import sync_to_async from asgiref.sync import sync_to_async
from django.contrib.sessions.backends.db import SessionStore from django.contrib.sessions.backends.db import SessionStore
from django.db import models
from django.db.models import Prefetch, Q from django.db.models import Prefetch, Q
from django.db.models.manager import BaseManager from django.db.models.manager import BaseManager
from django.db.utils import IntegrityError from django.db.utils import IntegrityError
@ -28,7 +37,6 @@ from khoj.database.models import (
ChatModelOptions, ChatModelOptions,
ClientApplication, ClientApplication,
Conversation, Conversation,
DataStore,
Entry, Entry,
FileObject, FileObject,
GithubConfig, GithubConfig,
@ -80,6 +88,45 @@ class SubscriptionState(Enum):
INVALID = "invalid" INVALID = "invalid"
P = ParamSpec("P")
T = TypeVar("T")
def require_valid_user(func: Callable[P, T]) -> Callable[P, T]:
@wraps(func)
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
# Extract user from args/kwargs
user = next((arg for arg in args if isinstance(arg, KhojUser)), None)
if not user:
user = next((val for val in kwargs.values() if isinstance(val, KhojUser)), None)
# Throw error if user is not found
if not user:
raise ValueError("Khoj user argument required but not provided.")
return func(*args, **kwargs)
return sync_wrapper
def arequire_valid_user(func: Callable[P, Coroutine[Any, Any, T]]) -> Callable[P, Coroutine[Any, Any, T]]:
@wraps(func)
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
# Extract user from args/kwargs
user = next((arg for arg in args if isinstance(arg, KhojUser)), None)
if not user:
user = next((v for v in kwargs.values() if isinstance(v, KhojUser)), None)
# Throw error if user is not found
if not user:
raise ValueError("Khoj user argument required but not provided.")
return await func(*args, **kwargs)
return async_wrapper
@arequire_valid_user
async def set_notion_config(token: str, user: KhojUser): async def set_notion_config(token: str, user: KhojUser):
notion_config = await NotionConfig.objects.filter(user=user).afirst() notion_config = await NotionConfig.objects.filter(user=user).afirst()
if not notion_config: if not notion_config:
@ -90,6 +137,7 @@ async def set_notion_config(token: str, user: KhojUser):
return notion_config return notion_config
@require_valid_user
def create_khoj_token(user: KhojUser, name=None): def create_khoj_token(user: KhojUser, name=None):
"Create Khoj API key for user" "Create Khoj API key for user"
token = f"kk-{secrets.token_urlsafe(32)}" token = f"kk-{secrets.token_urlsafe(32)}"
@ -97,6 +145,7 @@ def create_khoj_token(user: KhojUser, name=None):
return KhojApiUser.objects.create(token=token, user=user, name=name) return KhojApiUser.objects.create(token=token, user=user, name=name)
@arequire_valid_user
async def acreate_khoj_token(user: KhojUser, name=None): async def acreate_khoj_token(user: KhojUser, name=None):
"Create Khoj API key for user" "Create Khoj API key for user"
token = f"kk-{secrets.token_urlsafe(32)}" token = f"kk-{secrets.token_urlsafe(32)}"
@ -104,11 +153,13 @@ async def acreate_khoj_token(user: KhojUser, name=None):
return await KhojApiUser.objects.acreate(token=token, user=user, name=name) return await KhojApiUser.objects.acreate(token=token, user=user, name=name)
@require_valid_user
def get_khoj_tokens(user: KhojUser): def get_khoj_tokens(user: KhojUser):
"Get all Khoj API keys for user" "Get all Khoj API keys for user"
return list(KhojApiUser.objects.filter(user=user)) return list(KhojApiUser.objects.filter(user=user))
@arequire_valid_user
async def delete_khoj_token(user: KhojUser, token: str): async def delete_khoj_token(user: KhojUser, token: str):
"Delete Khoj API Key for user" "Delete Khoj API Key for user"
await KhojApiUser.objects.filter(token=token, user=user).adelete() await KhojApiUser.objects.filter(token=token, user=user).adelete()
@ -132,6 +183,7 @@ async def aget_or_create_user_by_phone_number(phone_number: str) -> tuple[KhojUs
return user, is_new return user, is_new
@arequire_valid_user
async def aset_user_phone_number(user: KhojUser, phone_number: str) -> KhojUser: async def aset_user_phone_number(user: KhojUser, phone_number: str) -> KhojUser:
if is_none_or_empty(phone_number): if is_none_or_empty(phone_number):
return None return None
@ -155,6 +207,7 @@ async def aset_user_phone_number(user: KhojUser, phone_number: str) -> KhojUser:
return user return user
@arequire_valid_user
async def aremove_phone_number(user: KhojUser) -> KhojUser: async def aremove_phone_number(user: KhojUser) -> KhojUser:
user.phone_number = None user.phone_number = None
user.verified_phone_number = False user.verified_phone_number = False
@ -192,6 +245,7 @@ async def aget_or_create_user_by_email(email: str) -> tuple[KhojUser, bool]:
return user, is_new return user, is_new
@arequire_valid_user
async def astart_trial_subscription(user: KhojUser) -> Subscription: async def astart_trial_subscription(user: KhojUser) -> Subscription:
subscription = await Subscription.objects.filter(user=user).afirst() subscription = await Subscription.objects.filter(user=user).afirst()
if not subscription: if not subscription:
@ -246,6 +300,7 @@ async def create_user_by_google_token(token: dict) -> KhojUser:
return user return user
@require_valid_user
def set_user_name(user: KhojUser, first_name: str, last_name: str) -> KhojUser: def set_user_name(user: KhojUser, first_name: str, last_name: str) -> KhojUser:
user.first_name = first_name user.first_name = first_name
user.last_name = last_name user.last_name = last_name
@ -253,6 +308,7 @@ def set_user_name(user: KhojUser, first_name: str, last_name: str) -> KhojUser:
return user return user
@require_valid_user
def get_user_name(user: KhojUser): def get_user_name(user: KhojUser):
full_name = user.get_full_name() full_name = user.get_full_name()
if not is_none_or_empty(full_name): if not is_none_or_empty(full_name):
@ -264,6 +320,7 @@ def get_user_name(user: KhojUser):
return None return None
@require_valid_user
def get_user_photo(user: KhojUser): def get_user_photo(user: KhojUser):
google_profile: GoogleUser = GoogleUser.objects.filter(user=user).first() google_profile: GoogleUser = GoogleUser.objects.filter(user=user).first()
if google_profile: if google_profile:
@ -327,6 +384,7 @@ def get_user_subscription_state(email: str) -> str:
return subscription_to_state(user_subscription) return subscription_to_state(user_subscription)
@arequire_valid_user
async def aget_user_subscription_state(user: KhojUser) -> str: async def aget_user_subscription_state(user: KhojUser) -> str:
"""Get subscription state of user """Get subscription state of user
Valid state transitions: trial -> subscribed <-> unsubscribed OR expired Valid state transitions: trial -> subscribed <-> unsubscribed OR expired
@ -335,6 +393,7 @@ async def aget_user_subscription_state(user: KhojUser) -> str:
return await sync_to_async(subscription_to_state)(user_subscription) return await sync_to_async(subscription_to_state)(user_subscription)
@arequire_valid_user
async def ais_user_subscribed(user: KhojUser) -> bool: async def ais_user_subscribed(user: KhojUser) -> bool:
""" """
Get whether the user is subscribed Get whether the user is subscribed
@ -351,6 +410,7 @@ async def ais_user_subscribed(user: KhojUser) -> bool:
return subscribed return subscribed
@require_valid_user
def is_user_subscribed(user: KhojUser) -> bool: def is_user_subscribed(user: KhojUser) -> bool:
""" """
Get whether the user is subscribed Get whether the user is subscribed
@ -416,11 +476,13 @@ def get_all_users() -> BaseManager[KhojUser]:
return KhojUser.objects.all() return KhojUser.objects.all()
@require_valid_user
def get_user_github_config(user: KhojUser): def get_user_github_config(user: KhojUser):
config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first() config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first()
return config return config
@require_valid_user
def get_user_notion_config(user: KhojUser): def get_user_notion_config(user: KhojUser):
config = NotionConfig.objects.filter(user=user).first() config = NotionConfig.objects.filter(user=user).first()
return config return config
@ -430,6 +492,7 @@ def delete_user_requests(window: timedelta = timedelta(days=1)):
return UserRequests.objects.filter(created_at__lte=datetime.now(tz=timezone.utc) - window).delete() return UserRequests.objects.filter(created_at__lte=datetime.now(tz=timezone.utc) - window).delete()
@arequire_valid_user
async def aget_user_name(user: KhojUser): async def aget_user_name(user: KhojUser):
full_name = user.get_full_name() full_name = user.get_full_name()
if not is_none_or_empty(full_name): if not is_none_or_empty(full_name):
@ -441,18 +504,7 @@ async def aget_user_name(user: KhojUser):
return None return None
async def set_text_content_config(user: KhojUser, object: Type[models.Model], updated_config): @arequire_valid_user
deduped_files = list(set(updated_config.input_files)) if updated_config.input_files else None
deduped_filters = list(set(updated_config.input_filter)) if updated_config.input_filter else None
await object.objects.filter(user=user).adelete()
await object.objects.acreate(
input_files=deduped_files,
input_filter=deduped_filters,
index_heading_entries=updated_config.index_heading_entries,
user=user,
)
async def set_user_github_config(user: KhojUser, pat_token: str, repos: list): async def set_user_github_config(user: KhojUser, pat_token: str, repos: list):
config = await GithubConfig.objects.filter(user=user).afirst() config = await GithubConfig.objects.filter(user=user).afirst()
@ -587,8 +639,11 @@ class AgentAdapters:
) )
@staticmethod @staticmethod
@arequire_valid_user
async def adelete_agent_by_slug(agent_slug: str, user: KhojUser): async def adelete_agent_by_slug(agent_slug: str, user: KhojUser):
agent = await AgentAdapters.aget_agent_by_slug(agent_slug, user) agent = await AgentAdapters.aget_agent_by_slug(agent_slug, user)
if agent.creator != user:
return False
async for entry in Entry.objects.filter(agent=agent).aiterator(): async for entry in Entry.objects.filter(agent=agent).aiterator():
await entry.adelete() await entry.adelete()
@ -712,6 +767,7 @@ class AgentAdapters:
return await Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).afirst() return await Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).afirst()
@staticmethod @staticmethod
@arequire_valid_user
async def aupdate_agent( async def aupdate_agent(
user: KhojUser, user: KhojUser,
name: str, name: str,
@ -787,19 +843,6 @@ class PublicConversationAdapters:
return f"/share/chat/{public_conversation.slug}/" return f"/share/chat/{public_conversation.slug}/"
class DataStoreAdapters:
@staticmethod
async def astore_data(data: dict, key: str, user: KhojUser, private: bool = True):
if await DataStore.objects.filter(key=key).aexists():
return key
await DataStore.objects.acreate(value=data, key=key, owner=user, private=private)
return key
@staticmethod
async def aretrieve_public_data(key: str):
return await DataStore.objects.filter(key=key, private=False).afirst()
class ConversationAdapters: class ConversationAdapters:
@staticmethod @staticmethod
def make_public_conversation_copy(conversation: Conversation): def make_public_conversation_copy(conversation: Conversation):
@ -812,6 +855,7 @@ class ConversationAdapters:
) )
@staticmethod @staticmethod
@require_valid_user
def get_conversation_by_user( def get_conversation_by_user(
user: KhojUser, client_application: ClientApplication = None, conversation_id: str = None user: KhojUser, client_application: ClientApplication = None, conversation_id: str = None
) -> Optional[Conversation]: ) -> Optional[Conversation]:
@ -830,6 +874,7 @@ class ConversationAdapters:
return conversation return conversation
@staticmethod @staticmethod
@require_valid_user
def get_conversation_sessions(user: KhojUser, client_application: ClientApplication = None): def get_conversation_sessions(user: KhojUser, client_application: ClientApplication = None):
return ( return (
Conversation.objects.filter(user=user, client=client_application) Conversation.objects.filter(user=user, client=client_application)
@ -838,6 +883,7 @@ class ConversationAdapters:
) )
@staticmethod @staticmethod
@arequire_valid_user
async def aset_conversation_title( async def aset_conversation_title(
user: KhojUser, client_application: ClientApplication, conversation_id: str, title: str user: KhojUser, client_application: ClientApplication, conversation_id: str, title: str
): ):
@ -855,6 +901,7 @@ class ConversationAdapters:
return Conversation.objects.filter(id=conversation_id).first() return Conversation.objects.filter(id=conversation_id).first()
@staticmethod @staticmethod
@arequire_valid_user
async def acreate_conversation_session( async def acreate_conversation_session(
user: KhojUser, client_application: ClientApplication = None, agent_slug: str = None, title: str = None user: KhojUser, client_application: ClientApplication = None, agent_slug: str = None, title: str = None
): ):
@ -871,6 +918,7 @@ class ConversationAdapters:
) )
@staticmethod @staticmethod
@require_valid_user
def create_conversation_session( def create_conversation_session(
user: KhojUser, client_application: ClientApplication = None, agent_slug: str = None, title: str = None user: KhojUser, client_application: ClientApplication = None, agent_slug: str = None, title: str = None
): ):
@ -883,6 +931,7 @@ class ConversationAdapters:
return Conversation.objects.create(user=user, client=client_application, agent=agent, title=title) return Conversation.objects.create(user=user, client=client_application, agent=agent, title=title)
@staticmethod @staticmethod
@arequire_valid_user
async def aget_conversation_by_user( async def aget_conversation_by_user(
user: KhojUser, user: KhojUser,
client_application: ClientApplication = None, client_application: ClientApplication = None,
@ -907,6 +956,7 @@ class ConversationAdapters:
) )
@staticmethod @staticmethod
@arequire_valid_user
async def adelete_conversation_by_user( async def adelete_conversation_by_user(
user: KhojUser, client_application: ClientApplication = None, conversation_id: str = None user: KhojUser, client_application: ClientApplication = None, conversation_id: str = None
): ):
@ -915,6 +965,7 @@ class ConversationAdapters:
return await Conversation.objects.filter(user=user, client=client_application).adelete() return await Conversation.objects.filter(user=user, client=client_application).adelete()
@staticmethod @staticmethod
@require_valid_user
def has_any_conversation_config(user: KhojUser): def has_any_conversation_config(user: KhojUser):
return ChatModelOptions.objects.filter(user=user).exists() return ChatModelOptions.objects.filter(user=user).exists()
@ -951,6 +1002,7 @@ class ConversationAdapters:
return OpenAIProcessorConversationConfig.objects.filter().exists() return OpenAIProcessorConversationConfig.objects.filter().exists()
@staticmethod @staticmethod
@arequire_valid_user
async def aset_user_conversation_processor(user: KhojUser, conversation_processor_config_id: int): async def aset_user_conversation_processor(user: KhojUser, conversation_processor_config_id: int):
config = await ChatModelOptions.objects.filter(id=conversation_processor_config_id).afirst() config = await ChatModelOptions.objects.filter(id=conversation_processor_config_id).afirst()
if not config: if not config:
@ -959,6 +1011,7 @@ class ConversationAdapters:
return new_config return new_config
@staticmethod @staticmethod
@arequire_valid_user
async def aset_user_voice_model(user: KhojUser, model_id: str): async def aset_user_voice_model(user: KhojUser, model_id: str):
config = await VoiceModelOption.objects.filter(model_id=model_id).afirst() config = await VoiceModelOption.objects.filter(model_id=model_id).afirst()
if not config: if not config:
@ -1143,6 +1196,7 @@ class ConversationAdapters:
return enabled_scrapers return enabled_scrapers
@staticmethod @staticmethod
@require_valid_user
def create_conversation_from_public_conversation( def create_conversation_from_public_conversation(
user: KhojUser, public_conversation: PublicConversation, client_app: ClientApplication user: KhojUser, public_conversation: PublicConversation, client_app: ClientApplication
): ):
@ -1159,6 +1213,7 @@ class ConversationAdapters:
) )
@staticmethod @staticmethod
@require_valid_user
def save_conversation( def save_conversation(
user: KhojUser, user: KhojUser,
conversation_log: dict, conversation_log: dict,
@ -1208,6 +1263,7 @@ class ConversationAdapters:
return await SpeechToTextModelOptions.objects.filter().afirst() return await SpeechToTextModelOptions.objects.filter().afirst()
@staticmethod @staticmethod
@arequire_valid_user
async def aget_conversation_starters(user: KhojUser, max_results=3): async def aget_conversation_starters(user: KhojUser, max_results=3):
all_questions = [] all_questions = []
if await ReflectiveQuestion.objects.filter(user=user).aexists(): if await ReflectiveQuestion.objects.filter(user=user).aexists():
@ -1337,6 +1393,7 @@ class ConversationAdapters:
return conversation.file_filters return conversation.file_filters
@staticmethod @staticmethod
@require_valid_user
def delete_message_by_turn_id(user: KhojUser, conversation_id: str, turn_id: str): def delete_message_by_turn_id(user: KhojUser, conversation_id: str, turn_id: str):
conversation = ConversationAdapters.get_conversation_by_user(user, conversation_id=conversation_id) conversation = ConversationAdapters.get_conversation_by_user(user, conversation_id=conversation_id)
if not conversation or not conversation.conversation_log or not conversation.conversation_log.get("chat"): if not conversation or not conversation.conversation_log or not conversation.conversation_log.get("chat"):
@ -1355,52 +1412,63 @@ class FileObjectAdapters:
file_object.save() file_object.save()
@staticmethod @staticmethod
@require_valid_user
def create_file_object(user: KhojUser, file_name: str, raw_text: str): def create_file_object(user: KhojUser, file_name: str, raw_text: str):
return FileObject.objects.create(user=user, file_name=file_name, raw_text=raw_text) return FileObject.objects.create(user=user, file_name=file_name, raw_text=raw_text)
@staticmethod @staticmethod
@require_valid_user
def get_file_object_by_name(user: KhojUser, file_name: str): def get_file_object_by_name(user: KhojUser, file_name: str):
return FileObject.objects.filter(user=user, file_name=file_name).first() return FileObject.objects.filter(user=user, file_name=file_name).first()
@staticmethod @staticmethod
@require_valid_user
def get_all_file_objects(user: KhojUser): def get_all_file_objects(user: KhojUser):
return FileObject.objects.filter(user=user).all() return FileObject.objects.filter(user=user).all()
@staticmethod @staticmethod
@require_valid_user
def delete_file_object_by_name(user: KhojUser, file_name: str): def delete_file_object_by_name(user: KhojUser, file_name: str):
return FileObject.objects.filter(user=user, file_name=file_name).delete() return FileObject.objects.filter(user=user, file_name=file_name).delete()
@staticmethod @staticmethod
@require_valid_user
def delete_all_file_objects(user: KhojUser): def delete_all_file_objects(user: KhojUser):
return FileObject.objects.filter(user=user).delete() return FileObject.objects.filter(user=user).delete()
@staticmethod @staticmethod
async def async_update_raw_text(file_object: FileObject, new_raw_text: str): async def aupdate_raw_text(file_object: FileObject, new_raw_text: str):
file_object.raw_text = new_raw_text file_object.raw_text = new_raw_text
await file_object.asave() await file_object.asave()
@staticmethod @staticmethod
async def async_create_file_object(user: KhojUser, file_name: str, raw_text: str): @arequire_valid_user
async def acreate_file_object(user: KhojUser, file_name: str, raw_text: str):
return await FileObject.objects.acreate(user=user, file_name=file_name, raw_text=raw_text) return await FileObject.objects.acreate(user=user, file_name=file_name, raw_text=raw_text)
@staticmethod @staticmethod
async def async_get_file_objects_by_name(user: KhojUser, file_name: str, agent: Agent = None): @arequire_valid_user
async def aget_file_objects_by_name(user: KhojUser, file_name: str, agent: Agent = None):
return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name=file_name, agent=agent)) return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name=file_name, agent=agent))
@staticmethod @staticmethod
async def async_get_file_objects_by_names(user: KhojUser, file_names: List[str]): @arequire_valid_user
async def aget_file_objects_by_names(user: KhojUser, file_names: List[str]):
return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name__in=file_names)) return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name__in=file_names))
@staticmethod @staticmethod
async def async_get_all_file_objects(user: KhojUser): @arequire_valid_user
async def aget_all_file_objects(user: KhojUser):
return await sync_to_async(list)(FileObject.objects.filter(user=user)) return await sync_to_async(list)(FileObject.objects.filter(user=user))
@staticmethod @staticmethod
async def async_delete_file_object_by_name(user: KhojUser, file_name: str): @arequire_valid_user
async def adelete_file_object_by_name(user: KhojUser, file_name: str):
return await FileObject.objects.filter(user=user, file_name=file_name).adelete() return await FileObject.objects.filter(user=user, file_name=file_name).adelete()
@staticmethod @staticmethod
async def async_delete_all_file_objects(user: KhojUser): @arequire_valid_user
async def adelete_all_file_objects(user: KhojUser):
return await FileObject.objects.filter(user=user).adelete() return await FileObject.objects.filter(user=user).adelete()
@ -1410,15 +1478,18 @@ class EntryAdapters:
date_filter = DateFilter() date_filter = DateFilter()
@staticmethod @staticmethod
@require_valid_user
def does_entry_exist(user: KhojUser, hashed_value: str) -> bool: def does_entry_exist(user: KhojUser, hashed_value: str) -> bool:
return Entry.objects.filter(user=user, hashed_value=hashed_value).exists() return Entry.objects.filter(user=user, hashed_value=hashed_value).exists()
@staticmethod @staticmethod
@require_valid_user
def delete_entry_by_file(user: KhojUser, file_path: str): def delete_entry_by_file(user: KhojUser, file_path: str):
deleted_count, _ = Entry.objects.filter(user=user, file_path=file_path).delete() deleted_count, _ = Entry.objects.filter(user=user, file_path=file_path).delete()
return deleted_count return deleted_count
@staticmethod @staticmethod
@require_valid_user
def get_filtered_entries(user: KhojUser, file_type: str = None, file_source: str = None): def get_filtered_entries(user: KhojUser, file_type: str = None, file_source: str = None):
queryset = Entry.objects.filter(user=user) queryset = Entry.objects.filter(user=user)
@ -1431,6 +1502,7 @@ class EntryAdapters:
return queryset return queryset
@staticmethod @staticmethod
@require_valid_user
def delete_all_entries(user: KhojUser, file_type: str = None, file_source: str = None, batch_size=1000): def delete_all_entries(user: KhojUser, file_type: str = None, file_source: str = None, batch_size=1000):
deleted_count = 0 deleted_count = 0
queryset = EntryAdapters.get_filtered_entries(user, file_type, file_source) queryset = EntryAdapters.get_filtered_entries(user, file_type, file_source)
@ -1442,6 +1514,7 @@ class EntryAdapters:
return deleted_count return deleted_count
@staticmethod @staticmethod
@arequire_valid_user
async def adelete_all_entries(user: KhojUser, file_type: str = None, file_source: str = None, batch_size=1000): async def adelete_all_entries(user: KhojUser, file_type: str = None, file_source: str = None, batch_size=1000):
deleted_count = 0 deleted_count = 0
queryset = EntryAdapters.get_filtered_entries(user, file_type, file_source) queryset = EntryAdapters.get_filtered_entries(user, file_type, file_source)
@ -1453,10 +1526,12 @@ class EntryAdapters:
return deleted_count return deleted_count
@staticmethod @staticmethod
@require_valid_user
def get_existing_entry_hashes_by_file(user: KhojUser, file_path: str): def get_existing_entry_hashes_by_file(user: KhojUser, file_path: str):
return Entry.objects.filter(user=user, file_path=file_path).values_list("hashed_value", flat=True) return Entry.objects.filter(user=user, file_path=file_path).values_list("hashed_value", flat=True)
@staticmethod @staticmethod
@require_valid_user
def delete_entry_by_hash(user: KhojUser, hashed_values: List[str]): def delete_entry_by_hash(user: KhojUser, hashed_values: List[str]):
Entry.objects.filter(user=user, hashed_value__in=hashed_values).delete() Entry.objects.filter(user=user, hashed_value__in=hashed_values).delete()
@ -1468,6 +1543,7 @@ class EntryAdapters:
) )
@staticmethod @staticmethod
@require_valid_user
def user_has_entries(user: KhojUser): def user_has_entries(user: KhojUser):
return Entry.objects.filter(user=user).exists() return Entry.objects.filter(user=user).exists()
@ -1476,6 +1552,7 @@ class EntryAdapters:
return Entry.objects.filter(agent=agent).exists() return Entry.objects.filter(agent=agent).exists()
@staticmethod @staticmethod
@arequire_valid_user
async def auser_has_entries(user: KhojUser): async def auser_has_entries(user: KhojUser):
return await Entry.objects.filter(user=user).aexists() return await Entry.objects.filter(user=user).aexists()
@ -1486,10 +1563,12 @@ class EntryAdapters:
return await Entry.objects.filter(agent=agent).aexists() return await Entry.objects.filter(agent=agent).aexists()
@staticmethod @staticmethod
@arequire_valid_user
async def adelete_entry_by_file(user: KhojUser, file_path: str): async def adelete_entry_by_file(user: KhojUser, file_path: str):
return await Entry.objects.filter(user=user, file_path=file_path).adelete() return await Entry.objects.filter(user=user, file_path=file_path).adelete()
@staticmethod @staticmethod
@arequire_valid_user
async def adelete_entries_by_filenames(user: KhojUser, filenames: List[str], batch_size=1000): async def adelete_entries_by_filenames(user: KhojUser, filenames: List[str], batch_size=1000):
deleted_count = 0 deleted_count = 0
for i in range(0, len(filenames), batch_size): for i in range(0, len(filenames), batch_size):
@ -1508,6 +1587,7 @@ class EntryAdapters:
) )
@staticmethod @staticmethod
@require_valid_user
def get_all_filenames_by_source(user: KhojUser, file_source: str): def get_all_filenames_by_source(user: KhojUser, file_source: str):
return ( return (
Entry.objects.filter(user=user, file_source=file_source) Entry.objects.filter(user=user, file_source=file_source)
@ -1516,6 +1596,7 @@ class EntryAdapters:
) )
@staticmethod @staticmethod
@require_valid_user
def get_size_of_indexed_data_in_mb(user: KhojUser): def get_size_of_indexed_data_in_mb(user: KhojUser):
entries = Entry.objects.filter(user=user).iterator() entries = Entry.objects.filter(user=user).iterator()
total_size = sum(sys.getsizeof(entry.compiled) for entry in entries) total_size = sum(sys.getsizeof(entry.compiled) for entry in entries)
@ -1536,6 +1617,9 @@ class EntryAdapters:
if agent != None: if agent != None:
owner_filter |= Q(agent=agent) owner_filter |= Q(agent=agent)
if owner_filter == Q():
return Entry.objects.none()
if len(word_filters) == 0 and len(file_filters) == 0 and len(date_filters) == 0: if len(word_filters) == 0 and len(file_filters) == 0 and len(date_filters) == 0:
return Entry.objects.filter(owner_filter) return Entry.objects.filter(owner_filter)
@ -1610,10 +1694,12 @@ class EntryAdapters:
return relevant_entries[:max_results] return relevant_entries[:max_results]
@staticmethod @staticmethod
@require_valid_user
def get_unique_file_types(user: KhojUser): def get_unique_file_types(user: KhojUser):
return Entry.objects.filter(user=user).values_list("file_type", flat=True).distinct() return Entry.objects.filter(user=user).values_list("file_type", flat=True).distinct()
@staticmethod @staticmethod
@require_valid_user
def get_unique_file_sources(user: KhojUser): def get_unique_file_sources(user: KhojUser):
return Entry.objects.filter(user=user).values_list("file_source", flat=True).distinct().all() return Entry.objects.filter(user=user).values_list("file_source", flat=True).distinct().all()

View file

@ -18,7 +18,7 @@ class DocxToEntries(TextToEntries):
super().__init__() super().__init__()
# Define Functions # Define Functions
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
# Extract required fields from config # Extract required fields from config
deletion_file_names = set([file for file in files if files[file] == b""]) deletion_file_names = set([file for file in files if files[file] == b""])
files_to_process = set(files) - deletion_file_names files_to_process = set(files) - deletion_file_names
@ -35,13 +35,13 @@ class DocxToEntries(TextToEntries):
# Identify, mark and merge any new entries with previous entries # Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger): with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings( num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
user,
current_entries, current_entries,
DbEntry.EntryType.DOCX, DbEntry.EntryType.DOCX,
DbEntry.EntrySource.COMPUTER, DbEntry.EntrySource.COMPUTER,
"compiled", "compiled",
logger, logger,
deletion_file_names, deletion_file_names,
user,
regenerate=regenerate, regenerate=regenerate,
file_to_text_map=file_to_text_map, file_to_text_map=file_to_text_map,
) )

View file

@ -48,7 +48,7 @@ class GithubToEntries(TextToEntries):
else: else:
return return
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
if self.config.pat_token is None or self.config.pat_token == "": if self.config.pat_token is None or self.config.pat_token == "":
logger.error(f"Github PAT token is not set. Skipping github content") logger.error(f"Github PAT token is not set. Skipping github content")
raise ValueError("Github PAT token is not set. Skipping github content") raise ValueError("Github PAT token is not set. Skipping github content")
@ -101,12 +101,12 @@ class GithubToEntries(TextToEntries):
# Identify, mark and merge any new entries with previous entries # Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger): with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings( num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
user,
current_entries, current_entries,
DbEntry.EntryType.GITHUB, DbEntry.EntryType.GITHUB,
DbEntry.EntrySource.GITHUB, DbEntry.EntrySource.GITHUB,
key="compiled", key="compiled",
logger=logger, logger=logger,
user=user,
) )
return num_new_embeddings, num_deleted_embeddings return num_new_embeddings, num_deleted_embeddings

View file

@ -18,7 +18,7 @@ class ImageToEntries(TextToEntries):
super().__init__() super().__init__()
# Define Functions # Define Functions
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
# Extract required fields from config # Extract required fields from config
deletion_file_names = set([file for file in files if files[file] == b""]) deletion_file_names = set([file for file in files if files[file] == b""])
files_to_process = set(files) - deletion_file_names files_to_process = set(files) - deletion_file_names
@ -35,13 +35,13 @@ class ImageToEntries(TextToEntries):
# Identify, mark and merge any new entries with previous entries # Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger): with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings( num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
user,
current_entries, current_entries,
DbEntry.EntryType.IMAGE, DbEntry.EntryType.IMAGE,
DbEntry.EntrySource.COMPUTER, DbEntry.EntrySource.COMPUTER,
"compiled", "compiled",
logger, logger,
deletion_file_names, deletion_file_names,
user,
regenerate=regenerate, regenerate=regenerate,
file_to_text_map=file_to_text_map, file_to_text_map=file_to_text_map,
) )

View file

@ -19,7 +19,7 @@ class MarkdownToEntries(TextToEntries):
super().__init__() super().__init__()
# Define Functions # Define Functions
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
# Extract required fields from config # Extract required fields from config
deletion_file_names = set([file for file in files if files[file] == ""]) deletion_file_names = set([file for file in files if files[file] == ""])
files_to_process = set(files) - deletion_file_names files_to_process = set(files) - deletion_file_names
@ -37,13 +37,13 @@ class MarkdownToEntries(TextToEntries):
# Identify, mark and merge any new entries with previous entries # Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger): with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings( num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
user,
current_entries, current_entries,
DbEntry.EntryType.MARKDOWN, DbEntry.EntryType.MARKDOWN,
DbEntry.EntrySource.COMPUTER, DbEntry.EntrySource.COMPUTER,
"compiled", "compiled",
logger, logger,
deletion_file_names, deletion_file_names,
user,
regenerate=regenerate, regenerate=regenerate,
file_to_text_map=file_to_text_map, file_to_text_map=file_to_text_map,
) )

View file

@ -79,7 +79,7 @@ class NotionToEntries(TextToEntries):
self.body_params = {"page_size": 100} self.body_params = {"page_size": 100}
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
current_entries = [] current_entries = []
# Get all pages # Get all pages
@ -248,12 +248,12 @@ class NotionToEntries(TextToEntries):
# Identify, mark and merge any new entries with previous entries # Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger): with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings( num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
user,
current_entries, current_entries,
DbEntry.EntryType.NOTION, DbEntry.EntryType.NOTION,
DbEntry.EntrySource.NOTION, DbEntry.EntrySource.NOTION,
key="compiled", key="compiled",
logger=logger, logger=logger,
user=user,
) )
return num_new_embeddings, num_deleted_embeddings return num_new_embeddings, num_deleted_embeddings

View file

@ -20,7 +20,7 @@ class OrgToEntries(TextToEntries):
super().__init__() super().__init__()
# Define Functions # Define Functions
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
deletion_file_names = set([file for file in files if files[file] == ""]) deletion_file_names = set([file for file in files if files[file] == ""])
files_to_process = set(files) - deletion_file_names files_to_process = set(files) - deletion_file_names
files = {file: files[file] for file in files_to_process} files = {file: files[file] for file in files_to_process}
@ -36,13 +36,13 @@ class OrgToEntries(TextToEntries):
# Identify, mark and merge any new entries with previous entries # Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger): with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings( num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
user,
current_entries, current_entries,
DbEntry.EntryType.ORG, DbEntry.EntryType.ORG,
DbEntry.EntrySource.COMPUTER, DbEntry.EntrySource.COMPUTER,
"compiled", "compiled",
logger, logger,
deletion_file_names, deletion_file_names,
user,
regenerate=regenerate, regenerate=regenerate,
file_to_text_map=file_to_text_map, file_to_text_map=file_to_text_map,
) )

View file

@ -19,7 +19,7 @@ class PdfToEntries(TextToEntries):
super().__init__() super().__init__()
# Define Functions # Define Functions
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
# Extract required fields from config # Extract required fields from config
deletion_file_names = set([file for file in files if files[file] == b""]) deletion_file_names = set([file for file in files if files[file] == b""])
files_to_process = set(files) - deletion_file_names files_to_process = set(files) - deletion_file_names
@ -36,13 +36,13 @@ class PdfToEntries(TextToEntries):
# Identify, mark and merge any new entries with previous entries # Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger): with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings( num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
user,
current_entries, current_entries,
DbEntry.EntryType.PDF, DbEntry.EntryType.PDF,
DbEntry.EntrySource.COMPUTER, DbEntry.EntrySource.COMPUTER,
"compiled", "compiled",
logger, logger,
deletion_file_names, deletion_file_names,
user,
regenerate=regenerate, regenerate=regenerate,
file_to_text_map=file_to_text_map, file_to_text_map=file_to_text_map,
) )

View file

@ -20,7 +20,7 @@ class PlaintextToEntries(TextToEntries):
super().__init__() super().__init__()
# Define Functions # Define Functions
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
deletion_file_names = set([file for file in files if files[file] == ""]) deletion_file_names = set([file for file in files if files[file] == ""])
files_to_process = set(files) - deletion_file_names files_to_process = set(files) - deletion_file_names
files = {file: files[file] for file in files_to_process} files = {file: files[file] for file in files_to_process}
@ -36,13 +36,13 @@ class PlaintextToEntries(TextToEntries):
# Identify, mark and merge any new entries with previous entries # Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger): with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings( num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
user,
current_entries, current_entries,
DbEntry.EntryType.PLAINTEXT, DbEntry.EntryType.PLAINTEXT,
DbEntry.EntrySource.COMPUTER, DbEntry.EntrySource.COMPUTER,
key="compiled", key="compiled",
logger=logger, logger=logger,
deletion_filenames=deletion_file_names, deletion_filenames=deletion_file_names,
user=user,
regenerate=regenerate, regenerate=regenerate,
file_to_text_map=file_to_text_map, file_to_text_map=file_to_text_map,
) )

View file

@ -31,7 +31,7 @@ class TextToEntries(ABC):
self.date_filter = DateFilter() self.date_filter = DateFilter()
@abstractmethod @abstractmethod
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
... ...
@staticmethod @staticmethod
@ -114,13 +114,13 @@ class TextToEntries(ABC):
def update_embeddings( def update_embeddings(
self, self,
user: KhojUser,
current_entries: List[Entry], current_entries: List[Entry],
file_type: str, file_type: str,
file_source: str, file_source: str,
key="compiled", key="compiled",
logger: logging.Logger = None, logger: logging.Logger = None,
deletion_filenames: Set[str] = None, deletion_filenames: Set[str] = None,
user: KhojUser = None,
regenerate: bool = False, regenerate: bool = False,
file_to_text_map: dict[str, str] = None, file_to_text_map: dict[str, str] = None,
): ):

View file

@ -212,7 +212,7 @@ def update(
logger.warning(error_msg) logger.warning(error_msg)
raise HTTPException(status_code=500, detail=error_msg) raise HTTPException(status_code=500, detail=error_msg)
try: try:
initialize_content(regenerate=force, search_type=t, user=user) initialize_content(user=user, regenerate=force, search_type=t)
except Exception as e: except Exception as e:
error_msg = f"🚨 Failed to update server via API: {e}" error_msg = f"🚨 Failed to update server via API: {e}"
logger.error(error_msg, exc_info=True) logger.error(error_msg, exc_info=True)

View file

@ -239,7 +239,7 @@ async def set_content_notion(
if updated_config.token: if updated_config.token:
# Trigger an async job to configure_content. Let it run without blocking the response. # Trigger an async job to configure_content. Let it run without blocking the response.
background_tasks.add_task(run_in_executor, configure_content, {}, False, SearchType.Notion, user) background_tasks.add_task(run_in_executor, configure_content, user, {}, False, SearchType.Notion)
update_telemetry_state( update_telemetry_state(
request=request, request=request,
@ -512,10 +512,10 @@ async def indexer(
success = await loop.run_in_executor( success = await loop.run_in_executor(
None, None,
configure_content, configure_content,
user,
indexer_input.model_dump(), indexer_input.model_dump(),
regenerate, regenerate,
t, t,
user,
) )
if not success: if not success:
raise RuntimeError(f"Failed to {method} {t} data sent by {client} client into content index") raise RuntimeError(f"Failed to {method} {t} data sent by {client} client into content index")

View file

@ -703,7 +703,7 @@ async def generate_summary_from_files(
if await EntryAdapters.aagent_has_entries(agent): if await EntryAdapters.aagent_has_entries(agent):
file_names = await EntryAdapters.aget_agent_entry_filepaths(agent) file_names = await EntryAdapters.aget_agent_entry_filepaths(agent)
if len(file_names) > 0: if len(file_names) > 0:
file_objects = await FileObjectAdapters.async_get_file_objects_by_name(None, file_names.pop(), agent) file_objects = await FileObjectAdapters.aget_file_objects_by_name(None, file_names.pop(), agent)
if (file_objects and len(file_objects) == 0 and not query_files) or (not file_objects and not query_files): if (file_objects and len(file_objects) == 0 and not query_files) or (not file_objects and not query_files):
response_log = "Sorry, I couldn't find anything to summarize." response_log = "Sorry, I couldn't find anything to summarize."
@ -1975,10 +1975,10 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
def configure_content( def configure_content(
user: KhojUser,
files: Optional[dict[str, dict[str, str]]], files: Optional[dict[str, dict[str, str]]],
regenerate: bool = False, regenerate: bool = False,
t: Optional[state.SearchType] = state.SearchType.All, t: Optional[state.SearchType] = state.SearchType.All,
user: KhojUser = None,
) -> bool: ) -> bool:
success = True success = True
if t == None: if t == None:

View file

@ -80,6 +80,6 @@ async def notion_auth_callback(request: Request, background_tasks: BackgroundTas
notion_redirect = str(request.app.url_path_for("config_page")) notion_redirect = str(request.app.url_path_for("config_page"))
# Trigger an async job to configure_content. Let it run without blocking the response. # Trigger an async job to configure_content. Let it run without blocking the response.
background_tasks.add_task(run_in_executor, configure_content, {}, False, SearchType.Notion, user) background_tasks.add_task(run_in_executor, configure_content, user, {}, False, SearchType.Notion)
return RedirectResponse(notion_redirect) return RedirectResponse(notion_redirect)

View file

@ -208,7 +208,7 @@ def setup(
text_to_entries: Type[TextToEntries], text_to_entries: Type[TextToEntries],
files: dict[str, str], files: dict[str, str],
regenerate: bool, regenerate: bool,
user: KhojUser = None, user: KhojUser,
config=None, config=None,
) -> Tuple[int, int]: ) -> Tuple[int, int]:
if config: if config:

View file

@ -8,6 +8,7 @@ from bs4 import BeautifulSoup
from magika import Magika from magika import Magika
from khoj.database.models import ( from khoj.database.models import (
KhojUser,
LocalMarkdownConfig, LocalMarkdownConfig,
LocalOrgConfig, LocalOrgConfig,
LocalPdfConfig, LocalPdfConfig,
@ -21,7 +22,7 @@ logger = logging.getLogger(__name__)
magika = Magika() magika = Magika()
def collect_files(search_type: Optional[SearchType] = SearchType.All, user=None) -> dict: def collect_files(user: KhojUser, search_type: Optional[SearchType] = SearchType.All) -> dict:
files: dict[str, dict] = {"docx": {}, "image": {}} files: dict[str, dict] = {"docx": {}, "image": {}}
if search_type == SearchType.All or search_type == SearchType.Org: if search_type == SearchType.All or search_type == SearchType.Org:

View file

@ -304,7 +304,7 @@ def chat_client_builder(search_config, user, index_content=True, require_auth=Fa
# Index Markdown Content for Search # Index Markdown Content for Search
all_files = fs_syncer.collect_files(user=user) all_files = fs_syncer.collect_files(user=user)
success = configure_content(all_files, user=user) configure_content(user, all_files)
# Initialize Processor from Config # Initialize Processor from Config
if os.getenv("OPENAI_API_KEY"): if os.getenv("OPENAI_API_KEY"):
@ -381,7 +381,7 @@ def client_offline_chat(search_config: SearchConfig, default_user2: KhojUser):
) )
all_files = fs_syncer.collect_files(user=default_user2) all_files = fs_syncer.collect_files(user=default_user2)
configure_content(all_files, user=default_user2) configure_content(default_user2, all_files)
# Initialize Processor from Config # Initialize Processor from Config
ChatModelOptionsFactory( ChatModelOptionsFactory(
@ -432,7 +432,7 @@ def pdf_configured_user1(default_user: KhojUser):
) )
# Index Markdown Content for Search # Index Markdown Content for Search
all_files = fs_syncer.collect_files(user=default_user) all_files = fs_syncer.collect_files(user=default_user)
success = configure_content(all_files, user=default_user) configure_content(default_user, all_files)
@pytest.fixture(scope="function") @pytest.fixture(scope="function")

View file

@ -253,11 +253,11 @@ def test_regenerate_with_github_fails_without_pat(client):
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db @pytest.mark.django_db
def test_get_configured_types_via_api(client, sample_org_data): def test_get_configured_types_via_api(client, sample_org_data, default_user3: KhojUser):
# Act # Act
text_search.setup(OrgToEntries, sample_org_data, regenerate=False) text_search.setup(OrgToEntries, sample_org_data, regenerate=False, user=default_user3)
enabled_types = EntryAdapters.get_unique_file_types(user=None).all().values_list("file_type", flat=True) enabled_types = EntryAdapters.get_unique_file_types(user=default_user3).all().values_list("file_type", flat=True)
# Assert # Assert
assert list(enabled_types) == ["org"] assert list(enabled_types) == ["org"]