Do not perform CRUD operations on entries, files and conversations in DB for null user

Increase defense-in-depth by reducing the paths that can create, read,
update, or delete entries, files, and conversations in the DB when the user is unset.
This commit is contained in:
Debanjum 2024-11-04 19:45:28 -08:00
parent 27fa39353e
commit ff5c10c221
19 changed files with 92 additions and 47 deletions

View file

@ -253,7 +253,7 @@ def configure_server(
logger.info(message) logger.info(message)
if not init: if not init:
initialize_content(regenerate, search_type, user) initialize_content(user, regenerate, search_type)
except Exception as e: except Exception as e:
logger.error(f"Failed to load some search models: {e}", exc_info=True) logger.error(f"Failed to load some search models: {e}", exc_info=True)
@ -263,17 +263,17 @@ def setup_default_agent(user: KhojUser):
AgentAdapters.create_default_agent(user) AgentAdapters.create_default_agent(user)
def initialize_content(regenerate: bool, search_type: Optional[SearchType] = None, user: KhojUser = None): def initialize_content(user: KhojUser, regenerate: bool, search_type: Optional[SearchType] = None):
# Initialize Content from Config # Initialize Content from Config
if state.search_models: if state.search_models:
try: try:
logger.info("📬 Updating content index...") logger.info("📬 Updating content index...")
all_files = collect_files(user=user) all_files = collect_files(user=user)
status = configure_content( status = configure_content(
user,
all_files, all_files,
regenerate, regenerate,
search_type, search_type,
user=user,
) )
if not status: if not status:
raise RuntimeError("Failed to update content index") raise RuntimeError("Failed to update content index")
@ -338,9 +338,7 @@ def configure_middleware(app):
def update_content_index(): def update_content_index():
for user in get_all_users(): for user in get_all_users():
all_files = collect_files(user=user) all_files = collect_files(user=user)
success = configure_content(all_files, user=user) success = configure_content(user, all_files)
all_files = collect_files(user=None)
success = configure_content(all_files, user=None)
if not success: if not success:
raise RuntimeError("Failed to update content index") raise RuntimeError("Failed to update content index")
logger.info("📪 Content index updated via Scheduler") logger.info("📪 Content index updated via Scheduler")

View file

@ -416,6 +416,11 @@ def get_all_users() -> BaseManager[KhojUser]:
return KhojUser.objects.all() return KhojUser.objects.all()
def check_valid_user(user: KhojUser | None):
if not user:
raise ValueError("User not found")
def get_user_github_config(user: KhojUser): def get_user_github_config(user: KhojUser):
config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first() config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first()
return config return config
@ -815,6 +820,7 @@ class ConversationAdapters:
def get_conversation_by_user( def get_conversation_by_user(
user: KhojUser, client_application: ClientApplication = None, conversation_id: str = None user: KhojUser, client_application: ClientApplication = None, conversation_id: str = None
) -> Optional[Conversation]: ) -> Optional[Conversation]:
check_valid_user(user)
if conversation_id: if conversation_id:
conversation = ( conversation = (
Conversation.objects.filter(user=user, client=client_application, id=conversation_id) Conversation.objects.filter(user=user, client=client_application, id=conversation_id)
@ -831,6 +837,7 @@ class ConversationAdapters:
@staticmethod @staticmethod
def get_conversation_sessions(user: KhojUser, client_application: ClientApplication = None): def get_conversation_sessions(user: KhojUser, client_application: ClientApplication = None):
check_valid_user(user)
return ( return (
Conversation.objects.filter(user=user, client=client_application) Conversation.objects.filter(user=user, client=client_application)
.prefetch_related("agent") .prefetch_related("agent")
@ -841,6 +848,7 @@ class ConversationAdapters:
async def aset_conversation_title( async def aset_conversation_title(
user: KhojUser, client_application: ClientApplication, conversation_id: str, title: str user: KhojUser, client_application: ClientApplication, conversation_id: str, title: str
): ):
check_valid_user(user)
conversation = await Conversation.objects.filter( conversation = await Conversation.objects.filter(
user=user, client=client_application, id=conversation_id user=user, client=client_application, id=conversation_id
).afirst() ).afirst()
@ -858,6 +866,7 @@ class ConversationAdapters:
async def acreate_conversation_session( async def acreate_conversation_session(
user: KhojUser, client_application: ClientApplication = None, agent_slug: str = None, title: str = None user: KhojUser, client_application: ClientApplication = None, agent_slug: str = None, title: str = None
): ):
check_valid_user(user)
if agent_slug: if agent_slug:
agent = await AgentAdapters.aget_readonly_agent_by_slug(agent_slug, user) agent = await AgentAdapters.aget_readonly_agent_by_slug(agent_slug, user)
if agent is None: if agent is None:
@ -874,6 +883,7 @@ class ConversationAdapters:
def create_conversation_session( def create_conversation_session(
user: KhojUser, client_application: ClientApplication = None, agent_slug: str = None, title: str = None user: KhojUser, client_application: ClientApplication = None, agent_slug: str = None, title: str = None
): ):
check_valid_user(user)
if agent_slug: if agent_slug:
agent = AgentAdapters.aget_readonly_agent_by_slug(agent_slug, user) agent = AgentAdapters.aget_readonly_agent_by_slug(agent_slug, user)
if agent is None: if agent is None:
@ -890,6 +900,7 @@ class ConversationAdapters:
title: str = None, title: str = None,
create_new: bool = False, create_new: bool = False,
) -> Optional[Conversation]: ) -> Optional[Conversation]:
check_valid_user(user)
if create_new: if create_new:
return await ConversationAdapters.acreate_conversation_session(user, client_application) return await ConversationAdapters.acreate_conversation_session(user, client_application)
@ -910,12 +921,14 @@ class ConversationAdapters:
async def adelete_conversation_by_user( async def adelete_conversation_by_user(
user: KhojUser, client_application: ClientApplication = None, conversation_id: str = None user: KhojUser, client_application: ClientApplication = None, conversation_id: str = None
): ):
check_valid_user(user)
if conversation_id: if conversation_id:
return await Conversation.objects.filter(user=user, client=client_application, id=conversation_id).adelete() return await Conversation.objects.filter(user=user, client=client_application, id=conversation_id).adelete()
return await Conversation.objects.filter(user=user, client=client_application).adelete() return await Conversation.objects.filter(user=user, client=client_application).adelete()
@staticmethod @staticmethod
def has_any_conversation_config(user: KhojUser): def has_any_conversation_config(user: KhojUser):
check_valid_user(user)
return ChatModelOptions.objects.filter(user=user).exists() return ChatModelOptions.objects.filter(user=user).exists()
@staticmethod @staticmethod
@ -953,7 +966,7 @@ class ConversationAdapters:
@staticmethod @staticmethod
async def aset_user_conversation_processor(user: KhojUser, conversation_processor_config_id: int): async def aset_user_conversation_processor(user: KhojUser, conversation_processor_config_id: int):
config = await ChatModelOptions.objects.filter(id=conversation_processor_config_id).afirst() config = await ChatModelOptions.objects.filter(id=conversation_processor_config_id).afirst()
if not config: if not config or user is None:
return None return None
new_config = await UserConversationConfig.objects.aupdate_or_create(user=user, defaults={"setting": config}) new_config = await UserConversationConfig.objects.aupdate_or_create(user=user, defaults={"setting": config})
return new_config return new_config
@ -961,7 +974,7 @@ class ConversationAdapters:
@staticmethod @staticmethod
async def aset_user_voice_model(user: KhojUser, model_id: str): async def aset_user_voice_model(user: KhojUser, model_id: str):
config = await VoiceModelOption.objects.filter(model_id=model_id).afirst() config = await VoiceModelOption.objects.filter(model_id=model_id).afirst()
if not config: if not config or user is None:
return None return None
new_config = await UserVoiceModelConfig.objects.aupdate_or_create(user=user, defaults={"setting": config}) new_config = await UserVoiceModelConfig.objects.aupdate_or_create(user=user, defaults={"setting": config})
return new_config return new_config
@ -1146,6 +1159,7 @@ class ConversationAdapters:
def create_conversation_from_public_conversation( def create_conversation_from_public_conversation(
user: KhojUser, public_conversation: PublicConversation, client_app: ClientApplication user: KhojUser, public_conversation: PublicConversation, client_app: ClientApplication
): ):
check_valid_user(user)
scrubbed_title = public_conversation.title if public_conversation.title else public_conversation.slug scrubbed_title = public_conversation.title if public_conversation.title else public_conversation.slug
if scrubbed_title: if scrubbed_title:
scrubbed_title = scrubbed_title.replace("-", " ") scrubbed_title = scrubbed_title.replace("-", " ")
@ -1166,6 +1180,7 @@ class ConversationAdapters:
conversation_id: str = None, conversation_id: str = None,
user_message: str = None, user_message: str = None,
): ):
check_valid_user(user)
slug = user_message.strip()[:200] if user_message else None slug = user_message.strip()[:200] if user_message else None
if conversation_id: if conversation_id:
conversation = Conversation.objects.filter(user=user, client=client_application, id=conversation_id).first() conversation = Conversation.objects.filter(user=user, client=client_application, id=conversation_id).first()
@ -1209,6 +1224,7 @@ class ConversationAdapters:
@staticmethod @staticmethod
async def aget_conversation_starters(user: KhojUser, max_results=3): async def aget_conversation_starters(user: KhojUser, max_results=3):
check_valid_user(user)
all_questions = [] all_questions = []
if await ReflectiveQuestion.objects.filter(user=user).aexists(): if await ReflectiveQuestion.objects.filter(user=user).aexists():
all_questions = await sync_to_async(ReflectiveQuestion.objects.filter(user=user).values_list)( all_questions = await sync_to_async(ReflectiveQuestion.objects.filter(user=user).values_list)(
@ -1338,6 +1354,7 @@ class ConversationAdapters:
@staticmethod @staticmethod
def delete_message_by_turn_id(user: KhojUser, conversation_id: str, turn_id: str): def delete_message_by_turn_id(user: KhojUser, conversation_id: str, turn_id: str):
check_valid_user(user)
conversation = ConversationAdapters.get_conversation_by_user(user, conversation_id=conversation_id) conversation = ConversationAdapters.get_conversation_by_user(user, conversation_id=conversation_id)
if not conversation or not conversation.conversation_log or not conversation.conversation_log.get("chat"): if not conversation or not conversation.conversation_log or not conversation.conversation_log.get("chat"):
return False return False
@ -1356,51 +1373,62 @@ class FileObjectAdapters:
@staticmethod @staticmethod
def create_file_object(user: KhojUser, file_name: str, raw_text: str): def create_file_object(user: KhojUser, file_name: str, raw_text: str):
check_valid_user(user)
return FileObject.objects.create(user=user, file_name=file_name, raw_text=raw_text) return FileObject.objects.create(user=user, file_name=file_name, raw_text=raw_text)
@staticmethod @staticmethod
def get_file_object_by_name(user: KhojUser, file_name: str): def get_file_object_by_name(user: KhojUser, file_name: str):
check_valid_user(user)
return FileObject.objects.filter(user=user, file_name=file_name).first() return FileObject.objects.filter(user=user, file_name=file_name).first()
@staticmethod @staticmethod
def get_all_file_objects(user: KhojUser): def get_all_file_objects(user: KhojUser):
check_valid_user(user)
return FileObject.objects.filter(user=user).all() return FileObject.objects.filter(user=user).all()
@staticmethod @staticmethod
def delete_file_object_by_name(user: KhojUser, file_name: str): def delete_file_object_by_name(user: KhojUser, file_name: str):
check_valid_user(user)
return FileObject.objects.filter(user=user, file_name=file_name).delete() return FileObject.objects.filter(user=user, file_name=file_name).delete()
@staticmethod @staticmethod
def delete_all_file_objects(user: KhojUser): def delete_all_file_objects(user: KhojUser):
check_valid_user(user)
return FileObject.objects.filter(user=user).delete() return FileObject.objects.filter(user=user).delete()
@staticmethod @staticmethod
async def async_update_raw_text(file_object: FileObject, new_raw_text: str): async def aupdate_raw_text(file_object: FileObject, new_raw_text: str):
file_object.raw_text = new_raw_text file_object.raw_text = new_raw_text
await file_object.asave() await file_object.asave()
@staticmethod @staticmethod
async def async_create_file_object(user: KhojUser, file_name: str, raw_text: str): async def acreate_file_object(user: KhojUser, file_name: str, raw_text: str):
check_valid_user(user)
return await FileObject.objects.acreate(user=user, file_name=file_name, raw_text=raw_text) return await FileObject.objects.acreate(user=user, file_name=file_name, raw_text=raw_text)
@staticmethod @staticmethod
async def async_get_file_objects_by_name(user: KhojUser, file_name: str, agent: Agent = None): async def aget_file_objects_by_name(user: KhojUser, file_name: str, agent: Agent = None):
check_valid_user(user)
return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name=file_name, agent=agent)) return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name=file_name, agent=agent))
@staticmethod @staticmethod
async def async_get_file_objects_by_names(user: KhojUser, file_names: List[str]): async def aget_file_objects_by_names(user: KhojUser, file_names: List[str]):
check_valid_user(user)
return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name__in=file_names)) return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name__in=file_names))
@staticmethod @staticmethod
async def async_get_all_file_objects(user: KhojUser): async def aget_all_file_objects(user: KhojUser):
check_valid_user(user)
return await sync_to_async(list)(FileObject.objects.filter(user=user)) return await sync_to_async(list)(FileObject.objects.filter(user=user))
@staticmethod @staticmethod
async def async_delete_file_object_by_name(user: KhojUser, file_name: str): async def adelete_file_object_by_name(user: KhojUser, file_name: str):
check_valid_user(user)
return await FileObject.objects.filter(user=user, file_name=file_name).adelete() return await FileObject.objects.filter(user=user, file_name=file_name).adelete()
@staticmethod @staticmethod
async def async_delete_all_file_objects(user: KhojUser): async def adelete_all_file_objects(user: KhojUser):
check_valid_user(user)
return await FileObject.objects.filter(user=user).adelete() return await FileObject.objects.filter(user=user).adelete()
@ -1411,15 +1439,18 @@ class EntryAdapters:
@staticmethod @staticmethod
def does_entry_exist(user: KhojUser, hashed_value: str) -> bool: def does_entry_exist(user: KhojUser, hashed_value: str) -> bool:
check_valid_user(user)
return Entry.objects.filter(user=user, hashed_value=hashed_value).exists() return Entry.objects.filter(user=user, hashed_value=hashed_value).exists()
@staticmethod @staticmethod
def delete_entry_by_file(user: KhojUser, file_path: str): def delete_entry_by_file(user: KhojUser, file_path: str):
check_valid_user(user)
deleted_count, _ = Entry.objects.filter(user=user, file_path=file_path).delete() deleted_count, _ = Entry.objects.filter(user=user, file_path=file_path).delete()
return deleted_count return deleted_count
@staticmethod @staticmethod
def get_filtered_entries(user: KhojUser, file_type: str = None, file_source: str = None): def get_filtered_entries(user: KhojUser, file_type: str = None, file_source: str = None):
check_valid_user(user)
queryset = Entry.objects.filter(user=user) queryset = Entry.objects.filter(user=user)
if file_type is not None: if file_type is not None:
@ -1432,6 +1463,7 @@ class EntryAdapters:
@staticmethod @staticmethod
def delete_all_entries(user: KhojUser, file_type: str = None, file_source: str = None, batch_size=1000): def delete_all_entries(user: KhojUser, file_type: str = None, file_source: str = None, batch_size=1000):
check_valid_user(user)
deleted_count = 0 deleted_count = 0
queryset = EntryAdapters.get_filtered_entries(user, file_type, file_source) queryset = EntryAdapters.get_filtered_entries(user, file_type, file_source)
while queryset.exists(): while queryset.exists():
@ -1443,6 +1475,7 @@ class EntryAdapters:
@staticmethod @staticmethod
async def adelete_all_entries(user: KhojUser, file_type: str = None, file_source: str = None, batch_size=1000): async def adelete_all_entries(user: KhojUser, file_type: str = None, file_source: str = None, batch_size=1000):
check_valid_user(user)
deleted_count = 0 deleted_count = 0
queryset = EntryAdapters.get_filtered_entries(user, file_type, file_source) queryset = EntryAdapters.get_filtered_entries(user, file_type, file_source)
while await queryset.aexists(): while await queryset.aexists():
@ -1454,10 +1487,12 @@ class EntryAdapters:
@staticmethod @staticmethod
def get_existing_entry_hashes_by_file(user: KhojUser, file_path: str): def get_existing_entry_hashes_by_file(user: KhojUser, file_path: str):
check_valid_user(user)
return Entry.objects.filter(user=user, file_path=file_path).values_list("hashed_value", flat=True) return Entry.objects.filter(user=user, file_path=file_path).values_list("hashed_value", flat=True)
@staticmethod @staticmethod
def delete_entry_by_hash(user: KhojUser, hashed_values: List[str]): def delete_entry_by_hash(user: KhojUser, hashed_values: List[str]):
check_valid_user(user)
Entry.objects.filter(user=user, hashed_value__in=hashed_values).delete() Entry.objects.filter(user=user, hashed_value__in=hashed_values).delete()
@staticmethod @staticmethod
@ -1469,6 +1504,7 @@ class EntryAdapters:
@staticmethod @staticmethod
def user_has_entries(user: KhojUser): def user_has_entries(user: KhojUser):
check_valid_user(user)
return Entry.objects.filter(user=user).exists() return Entry.objects.filter(user=user).exists()
@staticmethod @staticmethod
@ -1477,6 +1513,7 @@ class EntryAdapters:
@staticmethod @staticmethod
async def auser_has_entries(user: KhojUser): async def auser_has_entries(user: KhojUser):
check_valid_user(user)
return await Entry.objects.filter(user=user).aexists() return await Entry.objects.filter(user=user).aexists()
@staticmethod @staticmethod
@ -1487,10 +1524,12 @@ class EntryAdapters:
@staticmethod @staticmethod
async def adelete_entry_by_file(user: KhojUser, file_path: str): async def adelete_entry_by_file(user: KhojUser, file_path: str):
check_valid_user(user)
return await Entry.objects.filter(user=user, file_path=file_path).adelete() return await Entry.objects.filter(user=user, file_path=file_path).adelete()
@staticmethod @staticmethod
async def adelete_entries_by_filenames(user: KhojUser, filenames: List[str], batch_size=1000): async def adelete_entries_by_filenames(user: KhojUser, filenames: List[str], batch_size=1000):
check_valid_user(user)
deleted_count = 0 deleted_count = 0
for i in range(0, len(filenames), batch_size): for i in range(0, len(filenames), batch_size):
batch = filenames[i : i + batch_size] batch = filenames[i : i + batch_size]
@ -1509,6 +1548,7 @@ class EntryAdapters:
@staticmethod @staticmethod
def get_all_filenames_by_source(user: KhojUser, file_source: str): def get_all_filenames_by_source(user: KhojUser, file_source: str):
check_valid_user(user)
return ( return (
Entry.objects.filter(user=user, file_source=file_source) Entry.objects.filter(user=user, file_source=file_source)
.distinct("file_path") .distinct("file_path")
@ -1517,6 +1557,7 @@ class EntryAdapters:
@staticmethod @staticmethod
def get_size_of_indexed_data_in_mb(user: KhojUser): def get_size_of_indexed_data_in_mb(user: KhojUser):
check_valid_user(user)
entries = Entry.objects.filter(user=user).iterator() entries = Entry.objects.filter(user=user).iterator()
total_size = sum(sys.getsizeof(entry.compiled) for entry in entries) total_size = sum(sys.getsizeof(entry.compiled) for entry in entries)
return total_size / 1024 / 1024 return total_size / 1024 / 1024
@ -1536,6 +1577,9 @@ class EntryAdapters:
if agent != None: if agent != None:
owner_filter |= Q(agent=agent) owner_filter |= Q(agent=agent)
if owner_filter == Q():
return Entry.objects.none()
if len(word_filters) == 0 and len(file_filters) == 0 and len(date_filters) == 0: if len(word_filters) == 0 and len(file_filters) == 0 and len(date_filters) == 0:
return Entry.objects.filter(owner_filter) return Entry.objects.filter(owner_filter)
@ -1611,10 +1655,12 @@ class EntryAdapters:
@staticmethod @staticmethod
def get_unique_file_types(user: KhojUser): def get_unique_file_types(user: KhojUser):
check_valid_user(user)
return Entry.objects.filter(user=user).values_list("file_type", flat=True).distinct() return Entry.objects.filter(user=user).values_list("file_type", flat=True).distinct()
@staticmethod @staticmethod
def get_unique_file_sources(user: KhojUser): def get_unique_file_sources(user: KhojUser):
check_valid_user(user)
return Entry.objects.filter(user=user).values_list("file_source", flat=True).distinct().all() return Entry.objects.filter(user=user).values_list("file_source", flat=True).distinct().all()

View file

@ -18,7 +18,7 @@ class DocxToEntries(TextToEntries):
super().__init__() super().__init__()
# Define Functions # Define Functions
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
# Extract required fields from config # Extract required fields from config
deletion_file_names = set([file for file in files if files[file] == b""]) deletion_file_names = set([file for file in files if files[file] == b""])
files_to_process = set(files) - deletion_file_names files_to_process = set(files) - deletion_file_names
@ -35,13 +35,13 @@ class DocxToEntries(TextToEntries):
# Identify, mark and merge any new entries with previous entries # Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger): with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings( num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
user,
current_entries, current_entries,
DbEntry.EntryType.DOCX, DbEntry.EntryType.DOCX,
DbEntry.EntrySource.COMPUTER, DbEntry.EntrySource.COMPUTER,
"compiled", "compiled",
logger, logger,
deletion_file_names, deletion_file_names,
user,
regenerate=regenerate, regenerate=regenerate,
file_to_text_map=file_to_text_map, file_to_text_map=file_to_text_map,
) )

View file

@ -48,7 +48,7 @@ class GithubToEntries(TextToEntries):
else: else:
return return
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
if self.config.pat_token is None or self.config.pat_token == "": if self.config.pat_token is None or self.config.pat_token == "":
logger.error(f"Github PAT token is not set. Skipping github content") logger.error(f"Github PAT token is not set. Skipping github content")
raise ValueError("Github PAT token is not set. Skipping github content") raise ValueError("Github PAT token is not set. Skipping github content")
@ -101,12 +101,12 @@ class GithubToEntries(TextToEntries):
# Identify, mark and merge any new entries with previous entries # Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger): with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings( num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
user,
current_entries, current_entries,
DbEntry.EntryType.GITHUB, DbEntry.EntryType.GITHUB,
DbEntry.EntrySource.GITHUB, DbEntry.EntrySource.GITHUB,
key="compiled", key="compiled",
logger=logger, logger=logger,
user=user,
) )
return num_new_embeddings, num_deleted_embeddings return num_new_embeddings, num_deleted_embeddings

View file

@ -18,7 +18,7 @@ class ImageToEntries(TextToEntries):
super().__init__() super().__init__()
# Define Functions # Define Functions
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
# Extract required fields from config # Extract required fields from config
deletion_file_names = set([file for file in files if files[file] == b""]) deletion_file_names = set([file for file in files if files[file] == b""])
files_to_process = set(files) - deletion_file_names files_to_process = set(files) - deletion_file_names
@ -35,13 +35,13 @@ class ImageToEntries(TextToEntries):
# Identify, mark and merge any new entries with previous entries # Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger): with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings( num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
user,
current_entries, current_entries,
DbEntry.EntryType.IMAGE, DbEntry.EntryType.IMAGE,
DbEntry.EntrySource.COMPUTER, DbEntry.EntrySource.COMPUTER,
"compiled", "compiled",
logger, logger,
deletion_file_names, deletion_file_names,
user,
regenerate=regenerate, regenerate=regenerate,
file_to_text_map=file_to_text_map, file_to_text_map=file_to_text_map,
) )

View file

@ -19,7 +19,7 @@ class MarkdownToEntries(TextToEntries):
super().__init__() super().__init__()
# Define Functions # Define Functions
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
# Extract required fields from config # Extract required fields from config
deletion_file_names = set([file for file in files if files[file] == ""]) deletion_file_names = set([file for file in files if files[file] == ""])
files_to_process = set(files) - deletion_file_names files_to_process = set(files) - deletion_file_names
@ -37,13 +37,13 @@ class MarkdownToEntries(TextToEntries):
# Identify, mark and merge any new entries with previous entries # Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger): with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings( num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
user,
current_entries, current_entries,
DbEntry.EntryType.MARKDOWN, DbEntry.EntryType.MARKDOWN,
DbEntry.EntrySource.COMPUTER, DbEntry.EntrySource.COMPUTER,
"compiled", "compiled",
logger, logger,
deletion_file_names, deletion_file_names,
user,
regenerate=regenerate, regenerate=regenerate,
file_to_text_map=file_to_text_map, file_to_text_map=file_to_text_map,
) )

View file

@ -79,7 +79,7 @@ class NotionToEntries(TextToEntries):
self.body_params = {"page_size": 100} self.body_params = {"page_size": 100}
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
current_entries = [] current_entries = []
# Get all pages # Get all pages
@ -248,12 +248,12 @@ class NotionToEntries(TextToEntries):
# Identify, mark and merge any new entries with previous entries # Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger): with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings( num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
user,
current_entries, current_entries,
DbEntry.EntryType.NOTION, DbEntry.EntryType.NOTION,
DbEntry.EntrySource.NOTION, DbEntry.EntrySource.NOTION,
key="compiled", key="compiled",
logger=logger, logger=logger,
user=user,
) )
return num_new_embeddings, num_deleted_embeddings return num_new_embeddings, num_deleted_embeddings

View file

@ -20,7 +20,7 @@ class OrgToEntries(TextToEntries):
super().__init__() super().__init__()
# Define Functions # Define Functions
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
deletion_file_names = set([file for file in files if files[file] == ""]) deletion_file_names = set([file for file in files if files[file] == ""])
files_to_process = set(files) - deletion_file_names files_to_process = set(files) - deletion_file_names
files = {file: files[file] for file in files_to_process} files = {file: files[file] for file in files_to_process}
@ -36,13 +36,13 @@ class OrgToEntries(TextToEntries):
# Identify, mark and merge any new entries with previous entries # Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger): with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings( num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
user,
current_entries, current_entries,
DbEntry.EntryType.ORG, DbEntry.EntryType.ORG,
DbEntry.EntrySource.COMPUTER, DbEntry.EntrySource.COMPUTER,
"compiled", "compiled",
logger, logger,
deletion_file_names, deletion_file_names,
user,
regenerate=regenerate, regenerate=regenerate,
file_to_text_map=file_to_text_map, file_to_text_map=file_to_text_map,
) )

View file

@ -19,7 +19,7 @@ class PdfToEntries(TextToEntries):
super().__init__() super().__init__()
# Define Functions # Define Functions
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
# Extract required fields from config # Extract required fields from config
deletion_file_names = set([file for file in files if files[file] == b""]) deletion_file_names = set([file for file in files if files[file] == b""])
files_to_process = set(files) - deletion_file_names files_to_process = set(files) - deletion_file_names
@ -36,13 +36,13 @@ class PdfToEntries(TextToEntries):
# Identify, mark and merge any new entries with previous entries # Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger): with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings( num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
user,
current_entries, current_entries,
DbEntry.EntryType.PDF, DbEntry.EntryType.PDF,
DbEntry.EntrySource.COMPUTER, DbEntry.EntrySource.COMPUTER,
"compiled", "compiled",
logger, logger,
deletion_file_names, deletion_file_names,
user,
regenerate=regenerate, regenerate=regenerate,
file_to_text_map=file_to_text_map, file_to_text_map=file_to_text_map,
) )

View file

@ -20,7 +20,7 @@ class PlaintextToEntries(TextToEntries):
super().__init__() super().__init__()
# Define Functions # Define Functions
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
deletion_file_names = set([file for file in files if files[file] == ""]) deletion_file_names = set([file for file in files if files[file] == ""])
files_to_process = set(files) - deletion_file_names files_to_process = set(files) - deletion_file_names
files = {file: files[file] for file in files_to_process} files = {file: files[file] for file in files_to_process}
@ -36,13 +36,13 @@ class PlaintextToEntries(TextToEntries):
# Identify, mark and merge any new entries with previous entries # Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger): with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings( num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
user,
current_entries, current_entries,
DbEntry.EntryType.PLAINTEXT, DbEntry.EntryType.PLAINTEXT,
DbEntry.EntrySource.COMPUTER, DbEntry.EntrySource.COMPUTER,
key="compiled", key="compiled",
logger=logger, logger=logger,
deletion_filenames=deletion_file_names, deletion_filenames=deletion_file_names,
user=user,
regenerate=regenerate, regenerate=regenerate,
file_to_text_map=file_to_text_map, file_to_text_map=file_to_text_map,
) )

View file

@ -31,7 +31,7 @@ class TextToEntries(ABC):
self.date_filter = DateFilter() self.date_filter = DateFilter()
@abstractmethod @abstractmethod
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
... ...
@staticmethod @staticmethod
@ -114,13 +114,13 @@ class TextToEntries(ABC):
def update_embeddings( def update_embeddings(
self, self,
user: KhojUser,
current_entries: List[Entry], current_entries: List[Entry],
file_type: str, file_type: str,
file_source: str, file_source: str,
key="compiled", key="compiled",
logger: logging.Logger = None, logger: logging.Logger = None,
deletion_filenames: Set[str] = None, deletion_filenames: Set[str] = None,
user: KhojUser = None,
regenerate: bool = False, regenerate: bool = False,
file_to_text_map: dict[str, str] = None, file_to_text_map: dict[str, str] = None,
): ):

View file

@ -212,7 +212,7 @@ def update(
logger.warning(error_msg) logger.warning(error_msg)
raise HTTPException(status_code=500, detail=error_msg) raise HTTPException(status_code=500, detail=error_msg)
try: try:
initialize_content(regenerate=force, search_type=t, user=user) initialize_content(user=user, regenerate=force, search_type=t)
except Exception as e: except Exception as e:
error_msg = f"🚨 Failed to update server via API: {e}" error_msg = f"🚨 Failed to update server via API: {e}"
logger.error(error_msg, exc_info=True) logger.error(error_msg, exc_info=True)

View file

@ -239,7 +239,7 @@ async def set_content_notion(
if updated_config.token: if updated_config.token:
# Trigger an async job to configure_content. Let it run without blocking the response. # Trigger an async job to configure_content. Let it run without blocking the response.
background_tasks.add_task(run_in_executor, configure_content, {}, False, SearchType.Notion, user) background_tasks.add_task(run_in_executor, configure_content, user, {}, False, SearchType.Notion)
update_telemetry_state( update_telemetry_state(
request=request, request=request,
@ -512,10 +512,10 @@ async def indexer(
success = await loop.run_in_executor( success = await loop.run_in_executor(
None, None,
configure_content, configure_content,
user,
indexer_input.model_dump(), indexer_input.model_dump(),
regenerate, regenerate,
t, t,
user,
) )
if not success: if not success:
raise RuntimeError(f"Failed to {method} {t} data sent by {client} client into content index") raise RuntimeError(f"Failed to {method} {t} data sent by {client} client into content index")

View file

@ -703,7 +703,7 @@ async def generate_summary_from_files(
if await EntryAdapters.aagent_has_entries(agent): if await EntryAdapters.aagent_has_entries(agent):
file_names = await EntryAdapters.aget_agent_entry_filepaths(agent) file_names = await EntryAdapters.aget_agent_entry_filepaths(agent)
if len(file_names) > 0: if len(file_names) > 0:
file_objects = await FileObjectAdapters.async_get_file_objects_by_name(None, file_names.pop(), agent) file_objects = await FileObjectAdapters.aget_file_objects_by_name(None, file_names.pop(), agent)
if (file_objects and len(file_objects) == 0 and not query_files) or (not file_objects and not query_files): if (file_objects and len(file_objects) == 0 and not query_files) or (not file_objects and not query_files):
response_log = "Sorry, I couldn't find anything to summarize." response_log = "Sorry, I couldn't find anything to summarize."
@ -1975,10 +1975,10 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
def configure_content( def configure_content(
user: KhojUser,
files: Optional[dict[str, dict[str, str]]], files: Optional[dict[str, dict[str, str]]],
regenerate: bool = False, regenerate: bool = False,
t: Optional[state.SearchType] = state.SearchType.All, t: Optional[state.SearchType] = state.SearchType.All,
user: KhojUser = None,
) -> bool: ) -> bool:
success = True success = True
if t == None: if t == None:

View file

@ -80,6 +80,6 @@ async def notion_auth_callback(request: Request, background_tasks: BackgroundTas
notion_redirect = str(request.app.url_path_for("config_page")) notion_redirect = str(request.app.url_path_for("config_page"))
# Trigger an async job to configure_content. Let it run without blocking the response. # Trigger an async job to configure_content. Let it run without blocking the response.
background_tasks.add_task(run_in_executor, configure_content, {}, False, SearchType.Notion, user) background_tasks.add_task(run_in_executor, configure_content, user, {}, False, SearchType.Notion)
return RedirectResponse(notion_redirect) return RedirectResponse(notion_redirect)

View file

@ -208,7 +208,7 @@ def setup(
text_to_entries: Type[TextToEntries], text_to_entries: Type[TextToEntries],
files: dict[str, str], files: dict[str, str],
regenerate: bool, regenerate: bool,
user: KhojUser = None, user: KhojUser,
config=None, config=None,
) -> Tuple[int, int]: ) -> Tuple[int, int]:
if config: if config:

View file

@ -8,6 +8,7 @@ from bs4 import BeautifulSoup
from magika import Magika from magika import Magika
from khoj.database.models import ( from khoj.database.models import (
KhojUser,
LocalMarkdownConfig, LocalMarkdownConfig,
LocalOrgConfig, LocalOrgConfig,
LocalPdfConfig, LocalPdfConfig,
@ -21,7 +22,7 @@ logger = logging.getLogger(__name__)
magika = Magika() magika = Magika()
def collect_files(search_type: Optional[SearchType] = SearchType.All, user=None) -> dict: def collect_files(user: KhojUser, search_type: Optional[SearchType] = SearchType.All) -> dict:
files: dict[str, dict] = {"docx": {}, "image": {}} files: dict[str, dict] = {"docx": {}, "image": {}}
if search_type == SearchType.All or search_type == SearchType.Org: if search_type == SearchType.All or search_type == SearchType.Org:

View file

@ -304,7 +304,7 @@ def chat_client_builder(search_config, user, index_content=True, require_auth=Fa
# Index Markdown Content for Search # Index Markdown Content for Search
all_files = fs_syncer.collect_files(user=user) all_files = fs_syncer.collect_files(user=user)
success = configure_content(all_files, user=user) configure_content(user, all_files)
# Initialize Processor from Config # Initialize Processor from Config
if os.getenv("OPENAI_API_KEY"): if os.getenv("OPENAI_API_KEY"):
@ -381,7 +381,7 @@ def client_offline_chat(search_config: SearchConfig, default_user2: KhojUser):
) )
all_files = fs_syncer.collect_files(user=default_user2) all_files = fs_syncer.collect_files(user=default_user2)
configure_content(all_files, user=default_user2) configure_content(default_user2, all_files)
# Initialize Processor from Config # Initialize Processor from Config
ChatModelOptionsFactory( ChatModelOptionsFactory(
@ -432,7 +432,7 @@ def pdf_configured_user1(default_user: KhojUser):
) )
# Index Markdown Content for Search # Index Markdown Content for Search
all_files = fs_syncer.collect_files(user=default_user) all_files = fs_syncer.collect_files(user=default_user)
success = configure_content(all_files, user=default_user) configure_content(default_user, all_files)
@pytest.fixture(scope="function") @pytest.fixture(scope="function")

View file

@ -253,11 +253,11 @@ def test_regenerate_with_github_fails_without_pat(client):
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db @pytest.mark.django_db
def test_get_configured_types_via_api(client, sample_org_data): def test_get_configured_types_via_api(client, sample_org_data, default_user3: KhojUser):
# Act # Act
text_search.setup(OrgToEntries, sample_org_data, regenerate=False) text_search.setup(OrgToEntries, sample_org_data, regenerate=False, user=default_user3)
enabled_types = EntryAdapters.get_unique_file_types(user=None).all().values_list("file_type", flat=True) enabled_types = EntryAdapters.get_unique_file_types(user=default_user3).all().values_list("file_type", flat=True)
# Assert # Assert
assert list(enabled_types) == ["org"] assert list(enabled_types) == ["org"]