Do not CRUD on entries, files & conversations in DB for null user

Increase defense-in-depth by reducing paths to create, read, update or
delete entries, files and conversations in DB when user is unset.
This commit is contained in:
Debanjum 2024-11-04 19:45:28 -08:00
parent 27fa39353e
commit ff5c10c221
19 changed files with 92 additions and 47 deletions

View file

@ -253,7 +253,7 @@ def configure_server(
logger.info(message)
if not init:
initialize_content(regenerate, search_type, user)
initialize_content(user, regenerate, search_type)
except Exception as e:
logger.error(f"Failed to load some search models: {e}", exc_info=True)
@ -263,17 +263,17 @@ def setup_default_agent(user: KhojUser):
AgentAdapters.create_default_agent(user)
def initialize_content(regenerate: bool, search_type: Optional[SearchType] = None, user: KhojUser = None):
def initialize_content(user: KhojUser, regenerate: bool, search_type: Optional[SearchType] = None):
# Initialize Content from Config
if state.search_models:
try:
logger.info("📬 Updating content index...")
all_files = collect_files(user=user)
status = configure_content(
user,
all_files,
regenerate,
search_type,
user=user,
)
if not status:
raise RuntimeError("Failed to update content index")
@ -338,9 +338,7 @@ def configure_middleware(app):
def update_content_index():
for user in get_all_users():
all_files = collect_files(user=user)
success = configure_content(all_files, user=user)
all_files = collect_files(user=None)
success = configure_content(all_files, user=None)
success = configure_content(user, all_files)
if not success:
raise RuntimeError("Failed to update content index")
logger.info("📪 Content index updated via Scheduler")

View file

@ -416,6 +416,11 @@ def get_all_users() -> BaseManager[KhojUser]:
return KhojUser.objects.all()
def check_valid_user(user: KhojUser | None):
if not user:
raise ValueError("User not found")
def get_user_github_config(user: KhojUser):
config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first()
return config
@ -815,6 +820,7 @@ class ConversationAdapters:
def get_conversation_by_user(
user: KhojUser, client_application: ClientApplication = None, conversation_id: str = None
) -> Optional[Conversation]:
check_valid_user(user)
if conversation_id:
conversation = (
Conversation.objects.filter(user=user, client=client_application, id=conversation_id)
@ -831,6 +837,7 @@ class ConversationAdapters:
@staticmethod
def get_conversation_sessions(user: KhojUser, client_application: ClientApplication = None):
check_valid_user(user)
return (
Conversation.objects.filter(user=user, client=client_application)
.prefetch_related("agent")
@ -841,6 +848,7 @@ class ConversationAdapters:
async def aset_conversation_title(
user: KhojUser, client_application: ClientApplication, conversation_id: str, title: str
):
check_valid_user(user)
conversation = await Conversation.objects.filter(
user=user, client=client_application, id=conversation_id
).afirst()
@ -858,6 +866,7 @@ class ConversationAdapters:
async def acreate_conversation_session(
user: KhojUser, client_application: ClientApplication = None, agent_slug: str = None, title: str = None
):
check_valid_user(user)
if agent_slug:
agent = await AgentAdapters.aget_readonly_agent_by_slug(agent_slug, user)
if agent is None:
@ -874,6 +883,7 @@ class ConversationAdapters:
def create_conversation_session(
user: KhojUser, client_application: ClientApplication = None, agent_slug: str = None, title: str = None
):
check_valid_user(user)
if agent_slug:
agent = AgentAdapters.aget_readonly_agent_by_slug(agent_slug, user)
if agent is None:
@ -890,6 +900,7 @@ class ConversationAdapters:
title: str = None,
create_new: bool = False,
) -> Optional[Conversation]:
check_valid_user(user)
if create_new:
return await ConversationAdapters.acreate_conversation_session(user, client_application)
@ -910,12 +921,14 @@ class ConversationAdapters:
async def adelete_conversation_by_user(
user: KhojUser, client_application: ClientApplication = None, conversation_id: str = None
):
check_valid_user(user)
if conversation_id:
return await Conversation.objects.filter(user=user, client=client_application, id=conversation_id).adelete()
return await Conversation.objects.filter(user=user, client=client_application).adelete()
@staticmethod
def has_any_conversation_config(user: KhojUser):
check_valid_user(user)
return ChatModelOptions.objects.filter(user=user).exists()
@staticmethod
@ -953,7 +966,7 @@ class ConversationAdapters:
@staticmethod
async def aset_user_conversation_processor(user: KhojUser, conversation_processor_config_id: int):
config = await ChatModelOptions.objects.filter(id=conversation_processor_config_id).afirst()
if not config:
if not config or user is None:
return None
new_config = await UserConversationConfig.objects.aupdate_or_create(user=user, defaults={"setting": config})
return new_config
@ -961,7 +974,7 @@ class ConversationAdapters:
@staticmethod
async def aset_user_voice_model(user: KhojUser, model_id: str):
config = await VoiceModelOption.objects.filter(model_id=model_id).afirst()
if not config:
if not config or user is None:
return None
new_config = await UserVoiceModelConfig.objects.aupdate_or_create(user=user, defaults={"setting": config})
return new_config
@ -1146,6 +1159,7 @@ class ConversationAdapters:
def create_conversation_from_public_conversation(
user: KhojUser, public_conversation: PublicConversation, client_app: ClientApplication
):
check_valid_user(user)
scrubbed_title = public_conversation.title if public_conversation.title else public_conversation.slug
if scrubbed_title:
scrubbed_title = scrubbed_title.replace("-", " ")
@ -1166,6 +1180,7 @@ class ConversationAdapters:
conversation_id: str = None,
user_message: str = None,
):
check_valid_user(user)
slug = user_message.strip()[:200] if user_message else None
if conversation_id:
conversation = Conversation.objects.filter(user=user, client=client_application, id=conversation_id).first()
@ -1209,6 +1224,7 @@ class ConversationAdapters:
@staticmethod
async def aget_conversation_starters(user: KhojUser, max_results=3):
check_valid_user(user)
all_questions = []
if await ReflectiveQuestion.objects.filter(user=user).aexists():
all_questions = await sync_to_async(ReflectiveQuestion.objects.filter(user=user).values_list)(
@ -1338,6 +1354,7 @@ class ConversationAdapters:
@staticmethod
def delete_message_by_turn_id(user: KhojUser, conversation_id: str, turn_id: str):
check_valid_user(user)
conversation = ConversationAdapters.get_conversation_by_user(user, conversation_id=conversation_id)
if not conversation or not conversation.conversation_log or not conversation.conversation_log.get("chat"):
return False
@ -1356,51 +1373,62 @@ class FileObjectAdapters:
@staticmethod
def create_file_object(user: KhojUser, file_name: str, raw_text: str):
check_valid_user(user)
return FileObject.objects.create(user=user, file_name=file_name, raw_text=raw_text)
@staticmethod
def get_file_object_by_name(user: KhojUser, file_name: str):
check_valid_user(user)
return FileObject.objects.filter(user=user, file_name=file_name).first()
@staticmethod
def get_all_file_objects(user: KhojUser):
check_valid_user(user)
return FileObject.objects.filter(user=user).all()
@staticmethod
def delete_file_object_by_name(user: KhojUser, file_name: str):
check_valid_user(user)
return FileObject.objects.filter(user=user, file_name=file_name).delete()
@staticmethod
def delete_all_file_objects(user: KhojUser):
check_valid_user(user)
return FileObject.objects.filter(user=user).delete()
@staticmethod
async def async_update_raw_text(file_object: FileObject, new_raw_text: str):
async def aupdate_raw_text(file_object: FileObject, new_raw_text: str):
file_object.raw_text = new_raw_text
await file_object.asave()
@staticmethod
async def async_create_file_object(user: KhojUser, file_name: str, raw_text: str):
async def acreate_file_object(user: KhojUser, file_name: str, raw_text: str):
check_valid_user(user)
return await FileObject.objects.acreate(user=user, file_name=file_name, raw_text=raw_text)
@staticmethod
async def async_get_file_objects_by_name(user: KhojUser, file_name: str, agent: Agent = None):
async def aget_file_objects_by_name(user: KhojUser, file_name: str, agent: Agent = None):
check_valid_user(user)
return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name=file_name, agent=agent))
@staticmethod
async def async_get_file_objects_by_names(user: KhojUser, file_names: List[str]):
async def aget_file_objects_by_names(user: KhojUser, file_names: List[str]):
check_valid_user(user)
return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name__in=file_names))
@staticmethod
async def async_get_all_file_objects(user: KhojUser):
async def aget_all_file_objects(user: KhojUser):
check_valid_user(user)
return await sync_to_async(list)(FileObject.objects.filter(user=user))
@staticmethod
async def async_delete_file_object_by_name(user: KhojUser, file_name: str):
async def adelete_file_object_by_name(user: KhojUser, file_name: str):
check_valid_user(user)
return await FileObject.objects.filter(user=user, file_name=file_name).adelete()
@staticmethod
async def async_delete_all_file_objects(user: KhojUser):
async def adelete_all_file_objects(user: KhojUser):
check_valid_user(user)
return await FileObject.objects.filter(user=user).adelete()
@ -1411,15 +1439,18 @@ class EntryAdapters:
@staticmethod
def does_entry_exist(user: KhojUser, hashed_value: str) -> bool:
check_valid_user(user)
return Entry.objects.filter(user=user, hashed_value=hashed_value).exists()
@staticmethod
def delete_entry_by_file(user: KhojUser, file_path: str):
check_valid_user(user)
deleted_count, _ = Entry.objects.filter(user=user, file_path=file_path).delete()
return deleted_count
@staticmethod
def get_filtered_entries(user: KhojUser, file_type: str = None, file_source: str = None):
check_valid_user(user)
queryset = Entry.objects.filter(user=user)
if file_type is not None:
@ -1432,6 +1463,7 @@ class EntryAdapters:
@staticmethod
def delete_all_entries(user: KhojUser, file_type: str = None, file_source: str = None, batch_size=1000):
check_valid_user(user)
deleted_count = 0
queryset = EntryAdapters.get_filtered_entries(user, file_type, file_source)
while queryset.exists():
@ -1443,6 +1475,7 @@ class EntryAdapters:
@staticmethod
async def adelete_all_entries(user: KhojUser, file_type: str = None, file_source: str = None, batch_size=1000):
check_valid_user(user)
deleted_count = 0
queryset = EntryAdapters.get_filtered_entries(user, file_type, file_source)
while await queryset.aexists():
@ -1454,10 +1487,12 @@ class EntryAdapters:
@staticmethod
def get_existing_entry_hashes_by_file(user: KhojUser, file_path: str):
check_valid_user(user)
return Entry.objects.filter(user=user, file_path=file_path).values_list("hashed_value", flat=True)
@staticmethod
def delete_entry_by_hash(user: KhojUser, hashed_values: List[str]):
check_valid_user(user)
Entry.objects.filter(user=user, hashed_value__in=hashed_values).delete()
@staticmethod
@ -1469,6 +1504,7 @@ class EntryAdapters:
@staticmethod
def user_has_entries(user: KhojUser):
check_valid_user(user)
return Entry.objects.filter(user=user).exists()
@staticmethod
@ -1477,6 +1513,7 @@ class EntryAdapters:
@staticmethod
async def auser_has_entries(user: KhojUser):
check_valid_user(user)
return await Entry.objects.filter(user=user).aexists()
@staticmethod
@ -1487,10 +1524,12 @@ class EntryAdapters:
@staticmethod
async def adelete_entry_by_file(user: KhojUser, file_path: str):
check_valid_user(user)
return await Entry.objects.filter(user=user, file_path=file_path).adelete()
@staticmethod
async def adelete_entries_by_filenames(user: KhojUser, filenames: List[str], batch_size=1000):
check_valid_user(user)
deleted_count = 0
for i in range(0, len(filenames), batch_size):
batch = filenames[i : i + batch_size]
@ -1509,6 +1548,7 @@ class EntryAdapters:
@staticmethod
def get_all_filenames_by_source(user: KhojUser, file_source: str):
check_valid_user(user)
return (
Entry.objects.filter(user=user, file_source=file_source)
.distinct("file_path")
@ -1517,6 +1557,7 @@ class EntryAdapters:
@staticmethod
def get_size_of_indexed_data_in_mb(user: KhojUser):
check_valid_user(user)
entries = Entry.objects.filter(user=user).iterator()
total_size = sum(sys.getsizeof(entry.compiled) for entry in entries)
return total_size / 1024 / 1024
@ -1536,6 +1577,9 @@ class EntryAdapters:
if agent != None:
owner_filter |= Q(agent=agent)
if owner_filter == Q():
return Entry.objects.none()
if len(word_filters) == 0 and len(file_filters) == 0 and len(date_filters) == 0:
return Entry.objects.filter(owner_filter)
@ -1611,10 +1655,12 @@ class EntryAdapters:
@staticmethod
def get_unique_file_types(user: KhojUser):
check_valid_user(user)
return Entry.objects.filter(user=user).values_list("file_type", flat=True).distinct()
@staticmethod
def get_unique_file_sources(user: KhojUser):
check_valid_user(user)
return Entry.objects.filter(user=user).values_list("file_source", flat=True).distinct().all()

View file

@ -18,7 +18,7 @@ class DocxToEntries(TextToEntries):
super().__init__()
# Define Functions
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
# Extract required fields from config
deletion_file_names = set([file for file in files if files[file] == b""])
files_to_process = set(files) - deletion_file_names
@ -35,13 +35,13 @@ class DocxToEntries(TextToEntries):
# Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
user,
current_entries,
DbEntry.EntryType.DOCX,
DbEntry.EntrySource.COMPUTER,
"compiled",
logger,
deletion_file_names,
user,
regenerate=regenerate,
file_to_text_map=file_to_text_map,
)

View file

@ -48,7 +48,7 @@ class GithubToEntries(TextToEntries):
else:
return
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
if self.config.pat_token is None or self.config.pat_token == "":
logger.error(f"Github PAT token is not set. Skipping github content")
raise ValueError("Github PAT token is not set. Skipping github content")
@ -101,12 +101,12 @@ class GithubToEntries(TextToEntries):
# Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
user,
current_entries,
DbEntry.EntryType.GITHUB,
DbEntry.EntrySource.GITHUB,
key="compiled",
logger=logger,
user=user,
)
return num_new_embeddings, num_deleted_embeddings

View file

@ -18,7 +18,7 @@ class ImageToEntries(TextToEntries):
super().__init__()
# Define Functions
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
# Extract required fields from config
deletion_file_names = set([file for file in files if files[file] == b""])
files_to_process = set(files) - deletion_file_names
@ -35,13 +35,13 @@ class ImageToEntries(TextToEntries):
# Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
user,
current_entries,
DbEntry.EntryType.IMAGE,
DbEntry.EntrySource.COMPUTER,
"compiled",
logger,
deletion_file_names,
user,
regenerate=regenerate,
file_to_text_map=file_to_text_map,
)

View file

@ -19,7 +19,7 @@ class MarkdownToEntries(TextToEntries):
super().__init__()
# Define Functions
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
# Extract required fields from config
deletion_file_names = set([file for file in files if files[file] == ""])
files_to_process = set(files) - deletion_file_names
@ -37,13 +37,13 @@ class MarkdownToEntries(TextToEntries):
# Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
user,
current_entries,
DbEntry.EntryType.MARKDOWN,
DbEntry.EntrySource.COMPUTER,
"compiled",
logger,
deletion_file_names,
user,
regenerate=regenerate,
file_to_text_map=file_to_text_map,
)

View file

@ -79,7 +79,7 @@ class NotionToEntries(TextToEntries):
self.body_params = {"page_size": 100}
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
current_entries = []
# Get all pages
@ -248,12 +248,12 @@ class NotionToEntries(TextToEntries):
# Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
user,
current_entries,
DbEntry.EntryType.NOTION,
DbEntry.EntrySource.NOTION,
key="compiled",
logger=logger,
user=user,
)
return num_new_embeddings, num_deleted_embeddings

View file

@ -20,7 +20,7 @@ class OrgToEntries(TextToEntries):
super().__init__()
# Define Functions
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
deletion_file_names = set([file for file in files if files[file] == ""])
files_to_process = set(files) - deletion_file_names
files = {file: files[file] for file in files_to_process}
@ -36,13 +36,13 @@ class OrgToEntries(TextToEntries):
# Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
user,
current_entries,
DbEntry.EntryType.ORG,
DbEntry.EntrySource.COMPUTER,
"compiled",
logger,
deletion_file_names,
user,
regenerate=regenerate,
file_to_text_map=file_to_text_map,
)

View file

@ -19,7 +19,7 @@ class PdfToEntries(TextToEntries):
super().__init__()
# Define Functions
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
# Extract required fields from config
deletion_file_names = set([file for file in files if files[file] == b""])
files_to_process = set(files) - deletion_file_names
@ -36,13 +36,13 @@ class PdfToEntries(TextToEntries):
# Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
user,
current_entries,
DbEntry.EntryType.PDF,
DbEntry.EntrySource.COMPUTER,
"compiled",
logger,
deletion_file_names,
user,
regenerate=regenerate,
file_to_text_map=file_to_text_map,
)

View file

@ -20,7 +20,7 @@ class PlaintextToEntries(TextToEntries):
super().__init__()
# Define Functions
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
deletion_file_names = set([file for file in files if files[file] == ""])
files_to_process = set(files) - deletion_file_names
files = {file: files[file] for file in files_to_process}
@ -36,13 +36,13 @@ class PlaintextToEntries(TextToEntries):
# Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
user,
current_entries,
DbEntry.EntryType.PLAINTEXT,
DbEntry.EntrySource.COMPUTER,
key="compiled",
logger=logger,
deletion_filenames=deletion_file_names,
user=user,
regenerate=regenerate,
file_to_text_map=file_to_text_map,
)

View file

@ -31,7 +31,7 @@ class TextToEntries(ABC):
self.date_filter = DateFilter()
@abstractmethod
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
...
@staticmethod
@ -114,13 +114,13 @@ class TextToEntries(ABC):
def update_embeddings(
self,
user: KhojUser,
current_entries: List[Entry],
file_type: str,
file_source: str,
key="compiled",
logger: logging.Logger = None,
deletion_filenames: Set[str] = None,
user: KhojUser = None,
regenerate: bool = False,
file_to_text_map: dict[str, str] = None,
):

View file

@ -212,7 +212,7 @@ def update(
logger.warning(error_msg)
raise HTTPException(status_code=500, detail=error_msg)
try:
initialize_content(regenerate=force, search_type=t, user=user)
initialize_content(user=user, regenerate=force, search_type=t)
except Exception as e:
error_msg = f"🚨 Failed to update server via API: {e}"
logger.error(error_msg, exc_info=True)

View file

@ -239,7 +239,7 @@ async def set_content_notion(
if updated_config.token:
# Trigger an async job to configure_content. Let it run without blocking the response.
background_tasks.add_task(run_in_executor, configure_content, {}, False, SearchType.Notion, user)
background_tasks.add_task(run_in_executor, configure_content, user, {}, False, SearchType.Notion)
update_telemetry_state(
request=request,
@ -512,10 +512,10 @@ async def indexer(
success = await loop.run_in_executor(
None,
configure_content,
user,
indexer_input.model_dump(),
regenerate,
t,
user,
)
if not success:
raise RuntimeError(f"Failed to {method} {t} data sent by {client} client into content index")

View file

@ -703,7 +703,7 @@ async def generate_summary_from_files(
if await EntryAdapters.aagent_has_entries(agent):
file_names = await EntryAdapters.aget_agent_entry_filepaths(agent)
if len(file_names) > 0:
file_objects = await FileObjectAdapters.async_get_file_objects_by_name(None, file_names.pop(), agent)
file_objects = await FileObjectAdapters.aget_file_objects_by_name(None, file_names.pop(), agent)
if (file_objects and len(file_objects) == 0 and not query_files) or (not file_objects and not query_files):
response_log = "Sorry, I couldn't find anything to summarize."
@ -1975,10 +1975,10 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
def configure_content(
user: KhojUser,
files: Optional[dict[str, dict[str, str]]],
regenerate: bool = False,
t: Optional[state.SearchType] = state.SearchType.All,
user: KhojUser = None,
) -> bool:
success = True
if t == None:

View file

@ -80,6 +80,6 @@ async def notion_auth_callback(request: Request, background_tasks: BackgroundTas
notion_redirect = str(request.app.url_path_for("config_page"))
# Trigger an async job to configure_content. Let it run without blocking the response.
background_tasks.add_task(run_in_executor, configure_content, {}, False, SearchType.Notion, user)
background_tasks.add_task(run_in_executor, configure_content, user, {}, False, SearchType.Notion)
return RedirectResponse(notion_redirect)

View file

@ -208,7 +208,7 @@ def setup(
text_to_entries: Type[TextToEntries],
files: dict[str, str],
regenerate: bool,
user: KhojUser = None,
user: KhojUser,
config=None,
) -> Tuple[int, int]:
if config:

View file

@ -8,6 +8,7 @@ from bs4 import BeautifulSoup
from magika import Magika
from khoj.database.models import (
KhojUser,
LocalMarkdownConfig,
LocalOrgConfig,
LocalPdfConfig,
@ -21,7 +22,7 @@ logger = logging.getLogger(__name__)
magika = Magika()
def collect_files(search_type: Optional[SearchType] = SearchType.All, user=None) -> dict:
def collect_files(user: KhojUser, search_type: Optional[SearchType] = SearchType.All) -> dict:
files: dict[str, dict] = {"docx": {}, "image": {}}
if search_type == SearchType.All or search_type == SearchType.Org:

View file

@ -304,7 +304,7 @@ def chat_client_builder(search_config, user, index_content=True, require_auth=Fa
# Index Markdown Content for Search
all_files = fs_syncer.collect_files(user=user)
success = configure_content(all_files, user=user)
configure_content(user, all_files)
# Initialize Processor from Config
if os.getenv("OPENAI_API_KEY"):
@ -381,7 +381,7 @@ def client_offline_chat(search_config: SearchConfig, default_user2: KhojUser):
)
all_files = fs_syncer.collect_files(user=default_user2)
configure_content(all_files, user=default_user2)
configure_content(default_user2, all_files)
# Initialize Processor from Config
ChatModelOptionsFactory(
@ -432,7 +432,7 @@ def pdf_configured_user1(default_user: KhojUser):
)
# Index Markdown Content for Search
all_files = fs_syncer.collect_files(user=default_user)
success = configure_content(all_files, user=default_user)
configure_content(default_user, all_files)
@pytest.fixture(scope="function")

View file

@ -253,11 +253,11 @@ def test_regenerate_with_github_fails_without_pat(client):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
def test_get_configured_types_via_api(client, sample_org_data):
def test_get_configured_types_via_api(client, sample_org_data, default_user3: KhojUser):
# Act
text_search.setup(OrgToEntries, sample_org_data, regenerate=False)
text_search.setup(OrgToEntries, sample_org_data, regenerate=False, user=default_user3)
enabled_types = EntryAdapters.get_unique_file_types(user=None).all().values_list("file_type", flat=True)
enabled_types = EntryAdapters.get_unique_file_types(user=default_user3).all().values_list("file_type", flat=True)
# Assert
assert list(enabled_types) == ["org"]