Add Ability to Summarize Documents (#800)

* Uses the entire file text and the summarizer model to generate the document summary.
* Uses the contents of the user's query to create a tailored summary.
* Integrates with File Filters (#788) for a better UX (a client-side usage sketch follows below).
Raghav Tirumale 2024-06-18 10:01:07 -04:00 committed by GitHub
parent 677d49d438
commit d4e5c95711
21 changed files with 791 additions and 85 deletions
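A minimal client-side sketch of the new `/summarize` flow, mirroring the tests added in this commit; the base URL, conversation id, and file name are illustrative assumptions, not part of the diff:

```python
import urllib.parse

import requests

BASE_URL = "http://localhost:42110"  # assumed local Khoj server

# 1. Attach exactly one file to the conversation's file filters (selecting more than one is rejected).
requests.post(
    f"{BASE_URL}/api/chat/conversation/file-filters",
    json={"filename": "Birthday Gift for Xiu turning 4.markdown", "conversation_id": "1"},
)

# 2. Ask for a summary; any text after /summarize tailors the summary to that query.
query = urllib.parse.quote("/summarize tell me about Xiu")
response = requests.get(f"{BASE_URL}/api/chat?q={query}&conversation_id=1&stream=true")
print(response.text)
```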

View file

@@ -28,6 +28,7 @@ from khoj.database.models import (
ClientApplication,
Conversation,
Entry,
FileObject,
GithubConfig,
GithubRepoConfig,
GoogleUser,
@@ -731,7 +732,7 @@ class ConversationAdapters:
if server_chat_settings is None or (
server_chat_settings.summarizer_model is None and server_chat_settings.default_model is None
):
return await ChatModelOptions.objects.filter().afirst()
return await ChatModelOptions.objects.filter().prefetch_related("openai_config").afirst()
return server_chat_settings.summarizer_model or server_chat_settings.default_model
@staticmethod
@@ -846,6 +847,58 @@ class ConversationAdapters:
return await TextToImageModelConfig.objects.filter().afirst()
class FileObjectAdapters:
@staticmethod
def update_raw_text(file_object: FileObject, new_raw_text: str):
file_object.raw_text = new_raw_text
file_object.save()
@staticmethod
def create_file_object(user: KhojUser, file_name: str, raw_text: str):
return FileObject.objects.create(user=user, file_name=file_name, raw_text=raw_text)
@staticmethod
def get_file_objects_by_name(user: KhojUser, file_name: str):
return FileObject.objects.filter(user=user, file_name=file_name).first()
@staticmethod
def get_all_file_objects(user: KhojUser):
return FileObject.objects.filter(user=user).all()
@staticmethod
def delete_file_object_by_name(user: KhojUser, file_name: str):
return FileObject.objects.filter(user=user, file_name=file_name).delete()
@staticmethod
def delete_all_file_objects(user: KhojUser):
return FileObject.objects.filter(user=user).delete()
@staticmethod
async def async_update_raw_text(file_object: FileObject, new_raw_text: str):
file_object.raw_text = new_raw_text
await file_object.asave()
@staticmethod
async def async_create_file_object(user: KhojUser, file_name: str, raw_text: str):
return await FileObject.objects.acreate(user=user, file_name=file_name, raw_text=raw_text)
@staticmethod
async def async_get_file_objects_by_name(user: KhojUser, file_name: str):
return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name=file_name))
@staticmethod
async def async_get_all_file_objects(user: KhojUser):
return await sync_to_async(list)(FileObject.objects.filter(user=user))
@staticmethod
async def async_delete_file_object_by_name(user: KhojUser, file_name: str):
return await FileObject.objects.filter(user=user, file_name=file_name).adelete()
@staticmethod
async def async_delete_all_file_objects(user: KhojUser):
return await FileObject.objects.filter(user=user).adelete()
class EntryAdapters:
word_filer = WordFilter()
file_filter = FileFilter()
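A short sketch of how the new adapters are meant to be used during indexing, based only on the methods above; `user` and `raw_text` are placeholders, and the same update-or-create pattern appears in `text_to_entries.py` further down:

```python
from khoj.database.adapters import FileObjectAdapters

# Illustrative: cache a file's full raw text, updating the record if one already exists.
# `user` is a KhojUser instance and `raw_text` the file's contents (both placeholders here).
file_object = FileObjectAdapters.get_file_objects_by_name(user, "notes/todo.org")
if file_object:
    FileObjectAdapters.update_raw_text(file_object, raw_text)
else:
    FileObjectAdapters.create_file_object(user, "notes/todo.org", raw_text)
```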

View file

@@ -0,0 +1,37 @@
# Generated by Django 4.2.11 on 2024-06-14 06:13
import django.db.models.deletion
from django.conf import settings
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("database", "0044_conversation_file_filters"),
]
operations = [
migrations.CreateModel(
name="FileObject",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("file_name", models.CharField(blank=True, default=None, max_length=400, null=True)),
("raw_text", models.TextField()),
(
"user",
models.ForeignKey(
blank=True,
default=None,
null=True,
on_delete=django.db.models.deletion.CASCADE,
to=settings.AUTH_USER_MODEL,
),
),
],
options={
"abstract": False,
},
),
]

View file

@@ -326,6 +326,13 @@ class Entry(BaseModel):
corpus_id = models.UUIDField(default=uuid.uuid4, editable=False)
class FileObject(BaseModel):
# Same as Entry but raw will be a much larger string
file_name = models.CharField(max_length=400, default=None, null=True, blank=True)
raw_text = models.TextField()
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True)
class EntryDates(BaseModel):
date = models.DateField()
entry = models.ForeignKey(Entry, on_delete=models.CASCADE, related_name="embeddings_dates")

View file

@@ -2145,7 +2145,7 @@ To get started, just start typing below. You can also type / to see a list of co
<div style="border-top: 1px solid black; ">
<div style="display: flex; align-items: center; justify-content: space-between; margin-bottom: 5px; margin-top: 5px;">
<p style="margin: 0;">Files</p>
<svg id="file-toggle-button" class="file-toggle-button" style="width:20px; height:20px; position: relative; top: 2px" viewBox="0 0 40 40" fill="#000000" xmlns="http://www.w3.org/2000/svg">
<svg class="file-toggle-button" style="width:20px; height:20px; position: relative; top: 2px" viewBox="0 0 40 40" fill="#000000" xmlns="http://www.w3.org/2000/svg">
<path d="M16 0c-8.836 0-16 7.163-16 16s7.163 16 16 16c8.837 0 16-7.163 16-16s-7.163-16-16-16zM16 30.032c-7.72 0-14-6.312-14-14.032s6.28-14 14-14 14 6.28 14 14-6.28 14.032-14 14.032zM23 15h-6v-6c0-0.552-0.448-1-1-1s-1 0.448-1 1v6h-6c-0.552 0-1 0.448-1 1s0.448 1 1 1h6v6c0 0.552 0.448 1 1 1s1-0.448 1-1v-6h6c0.552 0 1-0.448 1-1s-0.448-1-1-1z"></path>
</svg>
</div>

View file

@@ -33,7 +33,7 @@ class MarkdownToEntries(TextToEntries):
max_tokens = 256
# Extract Entries from specified Markdown files
with timer("Extract entries from specified Markdown files", logger):
current_entries = MarkdownToEntries.extract_markdown_entries(files, max_tokens)
file_to_text_map, current_entries = MarkdownToEntries.extract_markdown_entries(files, max_tokens)
# Split entries by max tokens supported by model
with timer("Split entries by max token size supported by model", logger):
@@ -50,27 +50,30 @@
deletion_file_names,
user,
regenerate=regenerate,
file_to_text_map=file_to_text_map,
)
return num_new_embeddings, num_deleted_embeddings
@staticmethod
def extract_markdown_entries(markdown_files, max_tokens=256) -> List[Entry]:
def extract_markdown_entries(markdown_files, max_tokens=256) -> Tuple[Dict, List[Entry]]:
"Extract entries by heading from specified Markdown files"
entries: List[str] = []
entry_to_file_map: List[Tuple[str, str]] = []
file_to_text_map = dict()
for markdown_file in markdown_files:
try:
markdown_content = markdown_files[markdown_file]
entries, entry_to_file_map = MarkdownToEntries.process_single_markdown_file(
markdown_content, markdown_file, entries, entry_to_file_map, max_tokens
)
file_to_text_map[markdown_file] = markdown_content
except Exception as e:
logger.error(
f"Unable to process file: {markdown_file}. This file will not be indexed.\n{e}", exc_info=True
)
return MarkdownToEntries.convert_markdown_entries_to_maps(entries, dict(entry_to_file_map))
return file_to_text_map, MarkdownToEntries.convert_markdown_entries_to_maps(entries, dict(entry_to_file_map))
@staticmethod
def process_single_markdown_file(

View file

@@ -33,7 +33,7 @@ class OrgToEntries(TextToEntries):
# Extract Entries from specified Org files
max_tokens = 256
with timer("Extract entries from specified Org files", logger):
current_entries = self.extract_org_entries(files, max_tokens=max_tokens)
file_to_text_map, current_entries = self.extract_org_entries(files, max_tokens=max_tokens)
with timer("Split entries by max token size supported by model", logger):
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=max_tokens)
@@ -49,6 +49,7 @@ class OrgToEntries(TextToEntries):
deletion_file_names,
user,
regenerate=regenerate,
file_to_text_map=file_to_text_map,
)
return num_new_embeddings, num_deleted_embeddings
@@ -56,26 +57,32 @@
@staticmethod
def extract_org_entries(
org_files: dict[str, str], index_heading_entries: bool = False, max_tokens=256
) -> List[Entry]:
) -> Tuple[Dict, List[Entry]]:
"Extract entries from specified Org files"
entries, entry_to_file_map = OrgToEntries.extract_org_nodes(org_files, max_tokens)
return OrgToEntries.convert_org_nodes_to_entries(entries, entry_to_file_map, index_heading_entries)
file_to_text_map, entries, entry_to_file_map = OrgToEntries.extract_org_nodes(org_files, max_tokens)
return file_to_text_map, OrgToEntries.convert_org_nodes_to_entries(
entries, entry_to_file_map, index_heading_entries
)
@staticmethod
def extract_org_nodes(org_files: dict[str, str], max_tokens) -> Tuple[List[List[Orgnode]], Dict[Orgnode, str]]:
def extract_org_nodes(
org_files: dict[str, str], max_tokens
) -> Tuple[Dict, List[List[Orgnode]], Dict[Orgnode, str]]:
"Extract org nodes from specified org files"
entries: List[List[Orgnode]] = []
entry_to_file_map: List[Tuple[Orgnode, str]] = []
file_to_text_map = {}
for org_file in org_files:
try:
org_content = org_files[org_file]
entries, entry_to_file_map = OrgToEntries.process_single_org_file(
org_content, org_file, entries, entry_to_file_map, max_tokens
)
file_to_text_map[org_file] = org_content
except Exception as e:
logger.error(f"Unable to process file: {org_file}. Skipped indexing it.\nError; {e}", exc_info=True)
return entries, dict(entry_to_file_map)
return file_to_text_map, entries, dict(entry_to_file_map)
@staticmethod
def process_single_org_file(
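The extractors now return the raw text of each processed file alongside the parsed entries. A small sketch of the new return shape, using an illustrative in-memory org file:

```python
from khoj.processor.content.org_mode.org_to_entries import OrgToEntries

org_files = {"tasks.org": "* Heading\nbody line"}  # illustrative input

# extract_org_entries now returns (file_to_text_map, entries) instead of just entries.
file_to_text_map, current_entries = OrgToEntries.extract_org_entries(org_files, max_tokens=256)
assert file_to_text_map["tasks.org"] == "* Heading\nbody line"
```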

View file

@@ -2,10 +2,12 @@ import base64
import logging
import os
from datetime import datetime
from typing import List, Tuple
from typing import Dict, List, Tuple
from langchain_community.document_loaders import PyMuPDFLoader
# importing FileObjectAdapter so that we can add new files and debug file object db.
# from khoj.database.adapters import FileObjectAdapters
from khoj.database.models import Entry as DbEntry
from khoj.database.models import KhojUser
from khoj.processor.content.text_to_entries import TextToEntries
@@ -33,7 +35,7 @@ class PdfToEntries(TextToEntries):
# Extract Entries from specified Pdf files
with timer("Extract entries from specified PDF files", logger):
current_entries = PdfToEntries.extract_pdf_entries(files)
file_to_text_map, current_entries = PdfToEntries.extract_pdf_entries(files)
# Split entries by max tokens supported by model
with timer("Split entries by max token size supported by model", logger):
@@ -50,14 +52,15 @@ class PdfToEntries(TextToEntries):
deletion_file_names,
user,
regenerate=regenerate,
file_to_text_map=file_to_text_map,
)
return num_new_embeddings, num_deleted_embeddings
@staticmethod
def extract_pdf_entries(pdf_files) -> List[Entry]:
def extract_pdf_entries(pdf_files) -> Tuple[Dict, List[Entry]]: # important function
"""Extract entries by page from specified PDF files"""
file_to_text_map = dict()
entries: List[str] = []
entry_to_location_map: List[Tuple[str, str]] = []
for pdf_file in pdf_files:
@@ -73,9 +76,14 @@ class PdfToEntries(TextToEntries):
pdf_entries_per_file = [page.page_content for page in loader.load()]
except ImportError:
loader = PyMuPDFLoader(f"{tmp_file}")
pdf_entries_per_file = [page.page_content for page in loader.load()]
entry_to_location_map += zip(pdf_entries_per_file, [pdf_file] * len(pdf_entries_per_file))
pdf_entries_per_file = [
page.page_content for page in loader.load()
] # page_content items list for a given pdf.
entry_to_location_map += zip(
pdf_entries_per_file, [pdf_file] * len(pdf_entries_per_file)
) # this is an indexed map of pdf_entries for the pdf.
entries.extend(pdf_entries_per_file)
file_to_text_map[pdf_file] = pdf_entries_per_file
except Exception as e:
logger.warning(f"Unable to process file: {pdf_file}. This file will not be indexed.")
logger.warning(e, exc_info=True)
@@ -83,7 +91,7 @@ class PdfToEntries(TextToEntries):
if os.path.exists(f"{tmp_file}"):
os.remove(f"{tmp_file}")
return PdfToEntries.convert_pdf_entries_to_maps(entries, dict(entry_to_location_map))
return file_to_text_map, PdfToEntries.convert_pdf_entries_to_maps(entries, dict(entry_to_location_map))
@staticmethod
def convert_pdf_entries_to_maps(parsed_entries: List[str], entry_to_file_map) -> List[Entry]:

View file

@@ -32,7 +32,7 @@ class PlaintextToEntries(TextToEntries):
# Extract Entries from specified plaintext files
with timer("Extract entries from specified Plaintext files", logger):
current_entries = PlaintextToEntries.extract_plaintext_entries(files)
file_to_text_map, current_entries = PlaintextToEntries.extract_plaintext_entries(files)
# Split entries by max tokens supported by model
with timer("Split entries by max token size supported by model", logger):
@@ -49,6 +49,7 @@ class PlaintextToEntries(TextToEntries):
deletion_filenames=deletion_file_names,
user=user,
regenerate=regenerate,
file_to_text_map=file_to_text_map,
)
return num_new_embeddings, num_deleted_embeddings
@@ -63,21 +64,23 @@
return soup.get_text(strip=True, separator="\n")
@staticmethod
def extract_plaintext_entries(text_files: Dict[str, str]) -> List[Entry]:
def extract_plaintext_entries(text_files: Dict[str, str]) -> Tuple[Dict, List[Entry]]:
entries: List[str] = []
entry_to_file_map: List[Tuple[str, str]] = []
file_to_text_map = dict()
for text_file in text_files:
try:
text_content = text_files[text_file]
entries, entry_to_file_map = PlaintextToEntries.process_single_plaintext_file(
text_content, text_file, entries, entry_to_file_map
)
file_to_text_map[text_file] = text_content
except Exception as e:
logger.warning(f"Unable to read file: {text_file} as plaintext. Skipping file.")
logger.warning(e, exc_info=True)
# Extract Entries from specified plaintext files
return PlaintextToEntries.convert_text_files_to_entries(entries, dict(entry_to_file_map))
return file_to_text_map, PlaintextToEntries.convert_text_files_to_entries(entries, dict(entry_to_file_map))
@staticmethod
def process_single_plaintext_file(

View file

@@ -9,7 +9,11 @@ from typing import Any, Callable, List, Set, Tuple
from langchain.text_splitter import RecursiveCharacterTextSplitter
from tqdm import tqdm
from khoj.database.adapters import EntryAdapters, get_user_search_model_or_default
from khoj.database.adapters import (
EntryAdapters,
FileObjectAdapters,
get_user_search_model_or_default,
)
from khoj.database.models import Entry as DbEntry
from khoj.database.models import EntryDates, KhojUser
from khoj.search_filter.date_filter import DateFilter
@@ -120,6 +124,7 @@ class TextToEntries(ABC):
deletion_filenames: Set[str] = None,
user: KhojUser = None,
regenerate: bool = False,
file_to_text_map: dict[str, List[str]] = None,
):
with timer("Constructed current entry hashes in", logger):
hashes_by_file = dict[str, set[str]]()
@@ -186,6 +191,18 @@
logger.error(f"Error adding entries to database:\n{batch_indexing_error}\n---\n{e}", exc_info=True)
logger.debug(f"Added {len(added_entries)} {file_type} entries to database")
if file_to_text_map:
# get the list of file_names using added_entries
filenames_to_update = [entry.file_path for entry in added_entries]
# for each file_name in filenames_to_update, try getting the file object and updating raw_text and if it fails create a new file object
for file_name in filenames_to_update:
raw_text = " ".join(file_to_text_map[file_name])
file_object = FileObjectAdapters.get_file_objects_by_name(user, file_name)
if file_object:
FileObjectAdapters.update_raw_text(file_object, raw_text)
else:
FileObjectAdapters.create_file_object(user, file_name, raw_text)
new_dates = []
with timer("Indexed dates from added entries in", logger):
for added_entry in added_entries:
@@ -210,6 +227,7 @@
for file_path in deletion_filenames:
deleted_count = EntryAdapters.delete_entry_by_file(user, file_path)
num_deleted_entries += deleted_count
FileObjectAdapters.delete_file_object_by_name(user, file_path)
return len(added_entries), num_deleted_entries

View file

@@ -321,6 +321,27 @@ Collate only relevant information from the website to answer the target query.
""".strip()
)
system_prompt_extract_relevant_summary = """As a professional analyst, create a comprehensive report of the most relevant information from the document in response to a user's query. The text provided is directly from within the document. The report you create should be multiple paragraphs, and it should represent the content of the document. Tell the user exactly what the document says in response to their query, while adhering to these guidelines:
1. Answer the user's query as specifically as possible. Include many supporting details from the document.
2. Craft a report that is detailed, thorough, in-depth, and complex, while maintaining clarity.
3. Rely strictly on the provided text, without including external information.
4. Format the report in multiple paragraphs with a clear structure.
5. Be as specific as possible in your answer to the user's query.
6. Reproduce as much of the provided text as possible, while maintaining readability.
""".strip()
extract_relevant_summary = PromptTemplate.from_template(
"""
Target Query: {query}
Document Contents:
{corpus}
Collate only relevant information from the document to answer the target query.
""".strip()
)
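A quick sketch of how this template is filled before being sent to the summarizer model; the query and corpus strings are placeholders, and the actual call site is `extract_relevant_summary` in `helpers.py` further down:

```python
from khoj.processor.conversation import prompts

prompt = prompts.extract_relevant_summary.format(
    query="What are the key action items?",  # placeholder query
    corpus="...full document text...",  # placeholder document contents
)
```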
pick_relevant_output_mode = PromptTemplate.from_template(
"""
You are Khoj, an excellent analyst for selecting the correct way to respond to a user's query. You have access to a limited set of modes for your response. You can only use one of these modes.

View file

@@ -16,6 +16,7 @@ from websockets import ConnectionClosedOK
from khoj.database.adapters import (
ConversationAdapters,
EntryAdapters,
FileObjectAdapters,
PublicConversationAdapters,
aget_user_name,
)
@@ -42,6 +43,7 @@ from khoj.routers.helpers import (
aget_relevant_output_modes,
construct_automation_created_message,
create_automation,
extract_relevant_summary,
get_conversation_command,
is_query_empty,
is_ready_to_chat,
@@ -586,6 +588,51 @@ async def websocket_endpoint(
await conversation_command_rate_limiter.update_and_check_if_valid(websocket, cmd)
q = q.replace(f"/{cmd.value}", "").strip()
if ConversationCommand.Summarize in conversation_commands:
file_filters = conversation.file_filters
response_log = ""
if len(file_filters) == 0:
response_log = "No files selected for summarization. Please add files using the section on the left."
await send_complete_llm_response(response_log)
elif len(file_filters) > 1:
response_log = "Only one file can be selected for summarization."
await send_complete_llm_response(response_log)
else:
try:
file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
if len(file_object) == 0:
response_log = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again."
await send_complete_llm_response(response_log)
continue
contextual_data = " ".join([file.raw_text for file in file_object])
if not q:
q = "Create a general summary of the file"
await send_status_update(f"**🧑🏾‍💻 Constructing Summary Using:** {file_object[0].file_name}")
response = await extract_relevant_summary(q, contextual_data)
response_log = str(response)
await send_complete_llm_response(response_log)
except Exception as e:
response_log = "Error summarizing file."
logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)
await send_complete_llm_response(response_log)
await sync_to_async(save_to_conversation_log)(
q,
response_log,
user,
meta_log,
user_message_time,
intent_type="summarize",
client_application=websocket.user.client_app,
conversation_id=conversation_id,
)
update_telemetry_state(
request=websocket,
telemetry_type="api",
api="chat",
metadata={"conversation_command": conversation_commands[0].value},
)
continue
custom_filters = []
if conversation_commands == [ConversationCommand.Help]:
if not q:
@@ -828,6 +875,49 @@ async def chat(
_custom_filters.append("site:khoj.dev")
conversation_commands.append(ConversationCommand.Online)
conversation = await ConversationAdapters.aget_conversation_by_user(user, conversation_id=conversation_id)
conversation_id = conversation.id if conversation else None
if ConversationCommand.Summarize in conversation_commands:
file_filters = conversation.file_filters
llm_response = ""
if len(file_filters) == 0:
llm_response = "No files selected for summarization. Please add files using the section on the left."
elif len(file_filters) > 1:
llm_response = "Only one file can be selected for summarization."
else:
try:
file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
if len(file_object) == 0:
llm_response = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again."
return StreamingResponse(content=llm_response, media_type="text/event-stream", status_code=200)
contextual_data = " ".join([file.raw_text for file in file_object])
summarizeStr = "/" + ConversationCommand.Summarize
if q.strip() == summarizeStr:
q = "Create a general summary of the file"
response = await extract_relevant_summary(q, contextual_data)
llm_response = str(response)
except Exception as e:
logger.error(f"Error summarizing file for {user.email}: {e}")
llm_response = "Error summarizing file."
await sync_to_async(save_to_conversation_log)(
q,
llm_response,
user,
conversation.conversation_log,
user_message_time,
intent_type="summarize",
client_application=request.user.client_app,
conversation_id=conversation_id,
)
update_telemetry_state(
request=request,
telemetry_type="api",
api="chat",
metadata={"conversation_command": conversation_commands[0].value},
**common.__dict__,
)
return StreamingResponse(content=llm_response, media_type="text/event-stream", status_code=200)
conversation = await ConversationAdapters.aget_conversation_by_user(
user, request.user.client_app, conversation_id, title
)

View file

@@ -200,6 +200,8 @@ def get_conversation_command(query: str, any_references: bool = False) -> Conver
return ConversationCommand.Image
elif query.startswith("/automated_task"):
return ConversationCommand.AutomatedTask
elif query.startswith("/summarize"):
return ConversationCommand.Summarize
# If no relevant notes found for the given query
elif not any_references:
return ConversationCommand.General
@@ -418,7 +420,30 @@ async def extract_relevant_info(q: str, corpus: str) -> Union[str, None]:
prompts.system_prompt_extract_relevant_information,
chat_model_option=summarizer_model,
)
return response.strip()
async def extract_relevant_summary(q: str, corpus: str) -> Union[str, None]:
"""
Extract relevant information for a given query from the target corpus
"""
if is_none_or_empty(corpus) or is_none_or_empty(q):
return None
extract_relevant_information = prompts.extract_relevant_summary.format(
query=q,
corpus=corpus.strip(),
)
summarizer_model: ChatModelOptions = await ConversationAdapters.aget_summarizer_conversation_config()
with timer("Chat actor: Extract relevant information from data", logger):
response = await send_message_to_model_wrapper(
extract_relevant_information,
prompts.system_prompt_extract_relevant_summary,
chat_model_option=summarizer_model,
)
return response.strip()
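A sketch of invoking the new chat actor directly; on the server it runs inside the existing FastAPI event loop, and it assumes a summarizer model is configured. The query and document text below are placeholders:

```python
import asyncio

from khoj.routers.helpers import extract_relevant_summary

async def summarize_document(raw_text: str):
    # Returns None if either the query or the document text is empty.
    return await extract_relevant_summary("What are the key action items?", raw_text)

summary = asyncio.run(summarize_document("...full document text..."))
```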

View file

@@ -307,6 +307,7 @@ class ConversationCommand(str, Enum):
Text = "text"
Automation = "automation"
AutomatedTask = "automated_task"
Summarize = "summarize"
command_descriptions = {
@@ -318,6 +319,7 @@ command_descriptions = {
ConversationCommand.Image: "Generate images by describing your imagination in words.",
ConversationCommand.Automation: "Automatically run your query at a specified time or interval.",
ConversationCommand.Help: "Get help with how to use or setup Khoj from the documentation",
ConversationCommand.Summarize: "Create an appropriate summary using provided documents.",
}
tool_descriptions_for_llm = {
@@ -326,6 +328,7 @@ tool_descriptions_for_llm = {
ConversationCommand.Notes: "To search the user's personal knowledge base. Especially helpful if the question expects context from the user's notes or documents.",
ConversationCommand.Online: "To search for the latest, up-to-date information from the internet. Note: **Questions about Khoj should always use this data source**",
ConversationCommand.Webpage: "To use if the user has directly provided the webpage urls or you are certain of the webpage urls to read.",
ConversationCommand.Summarize: "To create a summary of the document provided by the user.",
}
mode_descriptions_for_llm = {

View file

@@ -19,6 +19,7 @@ from khoj.database.models import (
KhojUser,
LocalMarkdownConfig,
LocalOrgConfig,
LocalPdfConfig,
LocalPlaintextConfig,
)
from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
@@ -415,6 +416,18 @@ def org_config_with_only_new_file(new_org_file: Path, default_user: KhojUser):
return LocalOrgConfig.objects.filter(user=default_user).first()
@pytest.fixture(scope="function")
def pdf_configured_user1(default_user: KhojUser):
LocalPdfConfig.objects.create(
input_files=None,
input_filter=["tests/data/pdf/singlepage.pdf"],
user=default_user,
)
# Index Markdown Content for Search
all_files = fs_syncer.collect_files(user=default_user)
success = configure_content(all_files, user=default_user)
@pytest.fixture(scope="function")
def sample_org_data():
return get_sample_data("org")

View file

@@ -51,7 +51,9 @@ class ChatModelOptionsFactory(factory.django.DjangoModelFactory):
tokenizer = None
chat_model = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF"
model_type = "offline"
openai_config = factory.SubFactory(OpenAIProcessorConversationConfigFactory)
openai_config = factory.LazyAttribute(
lambda obj: OpenAIProcessorConversationConfigFactory() if os.getenv("OPENAI_API_KEY") else None
)
class UserConversationProcessorConfigFactory(factory.django.DjangoModelFactory):

View file

@@ -23,13 +23,14 @@ def test_extract_markdown_with_no_headings(tmp_path):
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3)
# Assert
assert len(entries) == 1
assert len(entries) == 2
assert len(entries[1]) == 1
# Ensure raw entry with no headings do not get heading prefix prepended
assert not entries[0].raw.startswith("#")
assert not entries[1][0].raw.startswith("#")
# Ensure compiled entry has filename prepended as top level heading
assert entries[0].compiled.startswith(expected_heading)
assert entries[1][0].compiled.startswith(expected_heading)
# Ensure compiled entry also includes the file name
assert str(tmp_path) in entries[0].compiled
assert str(tmp_path) in entries[1][0].compiled
def test_extract_single_markdown_entry(tmp_path):
@@ -48,7 +49,8 @@ def test_extract_single_markdown_entry(tmp_path):
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3)
# Assert
assert len(entries) == 1
assert len(entries) == 2
assert len(entries[1]) == 1
def test_extract_multiple_markdown_entries(tmp_path):
@@ -72,8 +74,9 @@ def test_extract_multiple_markdown_entries(tmp_path):
# Assert
assert len(entries) == 2
assert len(entries[1]) == 2
# Ensure entry compiled strings include the markdown files they originate from
assert all([tmp_path.stem in entry.compiled for entry in entries])
assert all([tmp_path.stem in entry.compiled for entry in entries[1]])
def test_extract_entries_with_different_level_headings(tmp_path):
@@ -94,8 +97,9 @@ def test_extract_entries_with_different_level_headings(tmp_path):
# Assert
assert len(entries) == 2
assert entries[0].raw == "# Heading 1\n## Sub-Heading 1.1", "Ensure entry includes heading ancestory"
assert entries[1].raw == "# Heading 2\n"
assert len(entries[1]) == 2
assert entries[1][0].raw == "# Heading 1\n## Sub-Heading 1.1", "Ensure entry includes heading ancestory"
assert entries[1][1].raw == "# Heading 2\n"
def test_extract_entries_with_non_incremental_heading_levels(tmp_path):
@@ -116,10 +120,11 @@ def test_extract_entries_with_non_incremental_heading_levels(tmp_path):
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3)
# Assert
assert len(entries) == 3
assert entries[0].raw == "# Heading 1\n#### Sub-Heading 1.1", "Ensure entry includes heading ancestory"
assert entries[1].raw == "# Heading 1\n## Sub-Heading 1.2", "Ensure entry includes heading ancestory"
assert entries[2].raw == "# Heading 2\n"
assert len(entries) == 2
assert len(entries[1]) == 3
assert entries[1][0].raw == "# Heading 1\n#### Sub-Heading 1.1", "Ensure entry includes heading ancestory"
assert entries[1][1].raw == "# Heading 1\n## Sub-Heading 1.2", "Ensure entry includes heading ancestory"
assert entries[1][2].raw == "# Heading 2\n"
def test_extract_entries_with_text_before_headings(tmp_path):
@@ -141,10 +146,13 @@ body line 2
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3)
# Assert
assert len(entries) == 3
assert entries[0].raw == "\nText before headings"
assert entries[1].raw == "# Heading 1\nbody line 1"
assert entries[2].raw == "# Heading 1\n## Heading 2\nbody line 2\n", "Ensure raw entry includes heading ancestory"
assert len(entries) == 2
assert len(entries[1]) == 3
assert entries[1][0].raw == "\nText before headings"
assert entries[1][1].raw == "# Heading 1\nbody line 1"
assert (
entries[1][2].raw == "# Heading 1\n## Heading 2\nbody line 2\n"
), "Ensure raw entry includes heading ancestory"
def test_parse_markdown_file_into_single_entry_if_small(tmp_path):
@@ -165,8 +173,9 @@ body line 1.1
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=12)
# Assert
assert len(entries) == 1
assert entries[0].raw == entry
assert len(entries) == 2
assert len(entries[1]) == 1
assert entries[1][0].raw == entry
def test_parse_markdown_entry_with_children_as_single_entry_if_small(tmp_path):
@@ -191,13 +200,14 @@ longer body line 2.1
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=12)
# Assert
assert len(entries) == 3
assert len(entries) == 2
assert len(entries[1]) == 3
assert (
entries[0].raw == "# Heading 1\nbody line 1\n## Subheading 1.1\nbody line 1.1"
entries[1][0].raw == "# Heading 1\nbody line 1\n## Subheading 1.1\nbody line 1.1"
), "First entry includes children headings"
assert entries[1].raw == "# Heading 2\nbody line 2", "Second entry does not include children headings"
assert entries[1][1].raw == "# Heading 2\nbody line 2", "Second entry does not include children headings"
assert (
entries[2].raw == "# Heading 2\n## Subheading 2.1\nlonger body line 2.1\n"
entries[1][2].raw == "# Heading 2\n## Subheading 2.1\nlonger body line 2.1\n"
), "Third entry is second entries child heading"

View file

@@ -6,7 +6,7 @@ import pytest
from faker import Faker
from freezegun import freeze_time
from khoj.database.models import Agent, KhojUser
from khoj.database.models import Agent, Entry, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import message_to_log
from khoj.routers.helpers import aget_relevant_information_sources
@@ -305,6 +305,200 @@ def test_answer_not_known_using_notes_command(client_offline_chat, default_user2
assert response_message == prompts.no_notes_found.format()
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_summarize_one_file(client_offline_chat, default_user2: KhojUser):
message_list = []
conversation = create_conversation(message_list, default_user2)
# post "Xi Li.markdown" file to the file filters
file_list = (
Entry.objects.filter(user=default_user2, file_source="computer")
.distinct("file_path")
.values_list("file_path", flat=True)
)
# pick the file that has "Xi Li.markdown" in the name
summarization_file = ""
for file in file_list:
if "Birthday Gift for Xiu turning 4.markdown" in file:
summarization_file = file
break
assert summarization_file != ""
response = client_offline_chat.post(
"api/chat/conversation/file-filters",
json={"filename": summarization_file, "conversation_id": str(conversation.id)},
)
query = urllib.parse.quote("/summarize")
response = client_offline_chat.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
response_message = response.content.decode("utf-8")
# Assert
assert response_message != ""
assert response_message != "No files selected for summarization. Please add files using the section on the left."
assert response_message != "Only one file can be selected for summarization."
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_summarize_extra_text(client_offline_chat, default_user2: KhojUser):
message_list = []
conversation = create_conversation(message_list, default_user2)
# post "Xi Li.markdown" file to the file filters
file_list = (
Entry.objects.filter(user=default_user2, file_source="computer")
.distinct("file_path")
.values_list("file_path", flat=True)
)
# pick the file that has "Xi Li.markdown" in the name
summarization_file = ""
for file in file_list:
if "Birthday Gift for Xiu turning 4.markdown" in file:
summarization_file = file
break
assert summarization_file != ""
response = client_offline_chat.post(
"api/chat/conversation/file-filters",
json={"filename": summarization_file, "conversation_id": str(conversation.id)},
)
query = urllib.parse.quote("/summarize tell me about Xiu")
response = client_offline_chat.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
response_message = response.content.decode("utf-8")
# Assert
assert response_message != ""
assert response_message != "No files selected for summarization. Please add files using the section on the left."
assert response_message != "Only one file can be selected for summarization."
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_summarize_multiple_files(client_offline_chat, default_user2: KhojUser):
message_list = []
conversation = create_conversation(message_list, default_user2)
# post "Xi Li.markdown" file to the file filters
file_list = (
Entry.objects.filter(user=default_user2, file_source="computer")
.distinct("file_path")
.values_list("file_path", flat=True)
)
response = client_offline_chat.post(
"api/chat/conversation/file-filters", json={"filename": file_list[0], "conversation_id": str(conversation.id)}
)
response = client_offline_chat.post(
"api/chat/conversation/file-filters", json={"filename": file_list[1], "conversation_id": str(conversation.id)}
)
query = urllib.parse.quote("/summarize")
response = client_offline_chat.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
response_message = response.content.decode("utf-8")
# Assert
assert response_message == "Only one file can be selected for summarization."
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_summarize_no_files(client_offline_chat, default_user2: KhojUser):
message_list = []
conversation = create_conversation(message_list, default_user2)
query = urllib.parse.quote("/summarize")
response = client_offline_chat.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
response_message = response.content.decode("utf-8")
# Assert
assert response_message == "No files selected for summarization. Please add files using the section on the left."
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_summarize_different_conversation(client_offline_chat, default_user2: KhojUser):
message_list = []
conversation1 = create_conversation(message_list, default_user2)
conversation2 = create_conversation(message_list, default_user2)
file_list = (
Entry.objects.filter(user=default_user2, file_source="computer")
.distinct("file_path")
.values_list("file_path", flat=True)
)
summarization_file = ""
for file in file_list:
if "Birthday Gift for Xiu turning 4.markdown" in file:
summarization_file = file
break
assert summarization_file != ""
# add file filter to conversation 1.
response = client_offline_chat.post(
"api/chat/conversation/file-filters",
json={"filename": summarization_file, "conversation_id": str(conversation1.id)},
)
query = urllib.parse.quote("/summarize")
response = client_offline_chat.get(f"/api/chat?q={query}&conversation_id={conversation2.id}&stream=true")
response_message = response.content.decode("utf-8")
# Assert
assert response_message == "No files selected for summarization. Please add files using the section on the left."
# now make sure that the file filter is still in conversation 1
response = client_offline_chat.get(f"/api/chat?q={query}&conversation_id={conversation1.id}&stream=true")
response_message = response.content.decode("utf-8")
# Assert
assert response_message != ""
assert response_message != "No files selected for summarization. Please add files using the section on the left."
assert response_message != "Only one file can be selected for summarization."
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_summarize_nonexistant_file(client_offline_chat, default_user2: KhojUser):
message_list = []
conversation = create_conversation(message_list, default_user2)
# post "imaginary.markdown" file to the file filters
response = client_offline_chat.post(
"api/chat/conversation/file-filters",
json={"filename": "imaginary.markdown", "conversation_id": str(conversation.id)},
)
query = urllib.parse.quote("/summarize")
response = client_offline_chat.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
response_message = response.content.decode("utf-8")
# Assert
assert response_message == "No files selected for summarization. Please add files using the section on the left."
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_summarize_diff_user_file(
client_offline_chat, default_user: KhojUser, pdf_configured_user1, default_user2: KhojUser
):
message_list = []
conversation = create_conversation(message_list, default_user2)
# Get the pdf file called singlepage.pdf
file_list = (
Entry.objects.filter(user=default_user, file_source="computer")
.distinct("file_path")
.values_list("file_path", flat=True)
)
summarization_file = ""
for file in file_list:
if "singlepage.pdf" in file:
summarization_file = file
break
assert summarization_file != ""
# add singlepage.pdf to the file filters
response = client_offline_chat.post(
"api/chat/conversation/file-filters",
json={"filename": summarization_file, "conversation_id": str(conversation.id)},
)
query = urllib.parse.quote("/summarize")
response = client_offline_chat.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
response_message = response.content.decode("utf-8")
# Assert
assert response_message == "No files selected for summarization. Please add files using the section on the left."
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
@@ -522,7 +716,7 @@ async def test_get_correct_tools_online(client_offline_chat):
user_query = "What's the weather in Patagonia this week?"
# Act
tools = await aget_relevant_information_sources(user_query, {})
tools = await aget_relevant_information_sources(user_query, {}, is_task=False)
# Assert
tools = [tool.value for tool in tools]
@@ -537,7 +731,7 @@ async def test_get_correct_tools_notes(client_offline_chat):
user_query = "Where did I go for my first battleship training?"
# Act
tools = await aget_relevant_information_sources(user_query, {})
tools = await aget_relevant_information_sources(user_query, {}, is_task=False)
# Assert
tools = [tool.value for tool in tools]
@@ -552,7 +746,7 @@ async def test_get_correct_tools_online_or_general_and_notes(client_offline_chat
user_query = "What's the highest point in Patagonia and have I been there?"
# Act
tools = await aget_relevant_information_sources(user_query, {})
tools = await aget_relevant_information_sources(user_query, {}, is_task=False)
# Assert
tools = [tool.value for tool in tools]
@@ -569,7 +763,7 @@ async def test_get_correct_tools_general(client_offline_chat):
user_query = "How many noble gases are there?"
# Act
tools = await aget_relevant_information_sources(user_query, {})
tools = await aget_relevant_information_sources(user_query, {}, is_task=False)
# Assert
tools = [tool.value for tool in tools]
@@ -593,7 +787,7 @@ async def test_get_correct_tools_with_chat_history(client_offline_chat, default_
chat_history = create_conversation(chat_log, default_user2)
# Act
tools = await aget_relevant_information_sources(user_query, chat_history)
tools = await aget_relevant_information_sources(user_query, chat_history, is_task=False)
# Assert
tools = [tool.value for tool in tools]

View file

@@ -5,7 +5,7 @@ from urllib.parse import quote
import pytest
from freezegun import freeze_time
from khoj.database.models import Agent, KhojUser
from khoj.database.models import Agent, Entry, KhojUser, LocalPdfConfig
from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import message_to_log
from khoj.routers.helpers import aget_relevant_information_sources
@@ -289,6 +289,198 @@ def test_answer_not_known_using_notes_command(chat_client_no_background, default
assert response_message == prompts.no_entries_found.format()
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_summarize_one_file(chat_client, default_user2: KhojUser):
message_list = []
conversation = create_conversation(message_list, default_user2)
# post "Xi Li.markdown" file to the file filters
file_list = (
Entry.objects.filter(user=default_user2, file_source="computer")
.distinct("file_path")
.values_list("file_path", flat=True)
)
# pick the file that has "Xi Li.markdown" in the name
summarization_file = ""
for file in file_list:
if "Birthday Gift for Xiu turning 4.markdown" in file:
summarization_file = file
break
assert summarization_file != ""
response = chat_client.post(
"api/chat/conversation/file-filters",
json={"filename": summarization_file, "conversation_id": str(conversation.id)},
)
query = urllib.parse.quote("/summarize")
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
response_message = response.content.decode("utf-8")
# Assert
assert response_message != ""
assert response_message != "No files selected for summarization. Please add files using the section on the left."
assert response_message != "Only one file can be selected for summarization."
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_summarize_extra_text(chat_client, default_user2: KhojUser):
message_list = []
conversation = create_conversation(message_list, default_user2)
# post "Xi Li.markdown" file to the file filters
file_list = (
Entry.objects.filter(user=default_user2, file_source="computer")
.distinct("file_path")
.values_list("file_path", flat=True)
)
# pick the file that has "Xi Li.markdown" in the name
summarization_file = ""
for file in file_list:
if "Birthday Gift for Xiu turning 4.markdown" in file:
summarization_file = file
break
assert summarization_file != ""
response = chat_client.post(
"api/chat/conversation/file-filters",
json={"filename": summarization_file, "conversation_id": str(conversation.id)},
)
query = urllib.parse.quote("/summarize tell me about Xiu")
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
response_message = response.content.decode("utf-8")
# Assert
assert response_message != ""
assert response_message != "No files selected for summarization. Please add files using the section on the left."
assert response_message != "Only one file can be selected for summarization."
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_summarize_multiple_files(chat_client, default_user2: KhojUser):
message_list = []
conversation = create_conversation(message_list, default_user2)
# post "Xi Li.markdown" file to the file filters
file_list = (
Entry.objects.filter(user=default_user2, file_source="computer")
.distinct("file_path")
.values_list("file_path", flat=True)
)
response = chat_client.post(
"api/chat/conversation/file-filters", json={"filename": file_list[0], "conversation_id": str(conversation.id)}
)
response = chat_client.post(
"api/chat/conversation/file-filters", json={"filename": file_list[1], "conversation_id": str(conversation.id)}
)
query = urllib.parse.quote("/summarize")
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
response_message = response.content.decode("utf-8")
# Assert
assert response_message == "Only one file can be selected for summarization."
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_summarize_no_files(chat_client, default_user2: KhojUser):
message_list = []
conversation = create_conversation(message_list, default_user2)
query = urllib.parse.quote("/summarize")
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
response_message = response.content.decode("utf-8")
# Assert
assert response_message == "No files selected for summarization. Please add files using the section on the left."
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_summarize_different_conversation(chat_client, default_user2: KhojUser):
message_list = []
conversation1 = create_conversation(message_list, default_user2)
conversation2 = create_conversation(message_list, default_user2)
file_list = (
Entry.objects.filter(user=default_user2, file_source="computer")
.distinct("file_path")
.values_list("file_path", flat=True)
)
summarization_file = ""
for file in file_list:
if "Birthday Gift for Xiu turning 4.markdown" in file:
summarization_file = file
break
assert summarization_file != ""
# add file filter to conversation 1.
response = chat_client.post(
"api/chat/conversation/file-filters",
json={"filename": summarization_file, "conversation_id": str(conversation1.id)},
)
query = urllib.parse.quote("/summarize")
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation2.id}&stream=true")
response_message = response.content.decode("utf-8")
# Assert
assert response_message == "No files selected for summarization. Please add files using the section on the left."
# now make sure that the file filter is still in conversation 1
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation1.id}&stream=true")
response_message = response.content.decode("utf-8")
# Assert
assert response_message != ""
assert response_message != "No files selected for summarization. Please add files using the section on the left."
assert response_message != "Only one file can be selected for summarization."
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_summarize_nonexistant_file(chat_client, default_user2: KhojUser):
message_list = []
conversation = create_conversation(message_list, default_user2)
# post "imaginary.markdown" file to the file filters
response = chat_client.post(
"api/chat/conversation/file-filters",
json={"filename": "imaginary.markdown", "conversation_id": str(conversation.id)},
)
query = urllib.parse.quote("/summarize")
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
response_message = response.content.decode("utf-8")
# Assert
assert response_message == "No files selected for summarization. Please add files using the section on the left."
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_summarize_diff_user_file(chat_client, default_user: KhojUser, pdf_configured_user1, default_user2: KhojUser):
message_list = []
conversation = create_conversation(message_list, default_user2)
# Get the pdf file called singlepage.pdf
file_list = (
Entry.objects.filter(user=default_user, file_source="computer")
.distinct("file_path")
.values_list("file_path", flat=True)
)
summarization_file = ""
for file in file_list:
if "singlepage.pdf" in file:
summarization_file = file
break
assert summarization_file != ""
# add singlepage.pdf to the file filters
response = chat_client.post(
"api/chat/conversation/file-filters",
json={"filename": summarization_file, "conversation_id": str(conversation.id)},
)
query = urllib.parse.quote("/summarize")
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
response_message = response.content.decode("utf-8")
# Assert
assert response_message == "No files selected for summarization. Please add files using the section on the left."
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering time aware questions yet")
@pytest.mark.django_db(transaction=True)

View file

@@ -33,10 +33,12 @@ def test_configure_indexing_heading_only_entries(tmp_path):
# Assert
if index_heading_entries:
# Entry with empty body indexed when index_heading_entries set to True
assert len(entries) == 1
assert len(entries) == 2
assert len(entries[1]) == 1
else:
# Entry with empty body ignored when index_heading_entries set to False
assert is_none_or_empty(entries)
assert len(entries) == 2
assert is_none_or_empty(entries[1])
def test_entry_split_when_exceeds_max_tokens():
@@ -55,9 +57,9 @@ def test_entry_split_when_exceeds_max_tokens():
# Act
# Extract Entries from specified Org files
entries = OrgToEntries.extract_org_entries(org_files=data)
assert len(entries) == 2
# Split each entry from specified Org files by max tokens
entries = TextToEntries.split_entries_by_max_tokens(entries, max_tokens=5)
entries = TextToEntries.split_entries_by_max_tokens(entries[1], max_tokens=5)
# Assert
assert len(entries) == 2
@@ -114,11 +116,12 @@ body line 1.1
# Act
# Extract Entries from specified Org files
extracted_entries = OrgToEntries.extract_org_entries(org_files=data, max_tokens=12)
for entry in extracted_entries:
assert len(extracted_entries) == 2
for entry in extracted_entries[1]:
entry.raw = clean(entry.raw)
# Assert
assert len(extracted_entries) == 1
assert len(extracted_entries[1]) == 1
assert entry.raw == expected_entry
@@ -165,10 +168,11 @@ longer body line 2.1
extracted_entries = OrgToEntries.extract_org_entries(org_files=data, max_tokens=12)
# Assert
assert len(extracted_entries) == 3
assert extracted_entries[0].compiled == first_expected_entry, "First entry includes children headings"
assert extracted_entries[1].compiled == second_expected_entry, "Second entry does not include children headings"
assert extracted_entries[2].compiled == third_expected_entry, "Third entry is second entries child heading"
assert len(extracted_entries) == 2
assert len(extracted_entries[1]) == 3
assert extracted_entries[1][0].compiled == first_expected_entry, "First entry includes children headings"
assert extracted_entries[1][1].compiled == second_expected_entry, "Second entry does not include children headings"
assert extracted_entries[1][2].compiled == third_expected_entry, "Third entry is second entries child heading"
def test_separate_sibling_org_entries_if_all_cannot_fit_in_token_limit(tmp_path):
@@ -226,10 +230,11 @@ body line 3.1
extracted_entries = OrgToEntries.extract_org_entries(org_files=data, max_tokens=30)
# Assert
assert len(extracted_entries) == 3
assert extracted_entries[0].compiled == first_expected_entry, "First entry includes children headings"
assert extracted_entries[1].compiled == second_expected_entry, "Second entry includes children headings"
assert extracted_entries[2].compiled == third_expected_entry, "Third entry includes children headings"
assert len(extracted_entries) == 2
assert len(extracted_entries[1]) == 3
assert extracted_entries[1][0].compiled == first_expected_entry, "First entry includes children headings"
assert extracted_entries[1][1].compiled == second_expected_entry, "Second entry includes children headings"
assert extracted_entries[1][2].compiled == third_expected_entry, "Third entry includes children headings"
def test_entry_with_body_to_entry(tmp_path):
@@ -251,7 +256,8 @@ def test_entry_with_body_to_entry(tmp_path):
entries = OrgToEntries.extract_org_entries(org_files=data, max_tokens=3)
# Assert
assert len(entries) == 1
assert len(entries) == 2
assert len(entries[1]) == 1
def test_file_with_entry_after_intro_text_to_entry(tmp_path):
@@ -273,6 +279,7 @@ Intro text
# Assert
assert len(entries) == 2
assert len(entries[1]) == 2
def test_file_with_no_headings_to_entry(tmp_path):
@@ -291,7 +298,8 @@ def test_file_with_no_headings_to_entry(tmp_path):
entries = OrgToEntries.extract_org_entries(org_files=data, max_tokens=3)
# Assert
assert len(entries) == 1
assert len(entries) == 2
assert len(entries[1]) == 1
def test_get_org_files(tmp_path):
@@ -349,13 +357,14 @@ def test_extract_entries_with_different_level_headings(tmp_path):
# Act
# Extract Entries from specified Org files
entries = OrgToEntries.extract_org_entries(org_files=data, index_heading_entries=True, max_tokens=3)
for entry in entries:
assert len(entries) == 2
for entry in entries[1]:
entry.raw = clean(f"{entry.raw}")
# Assert
assert len(entries) == 2
assert entries[0].raw == "* Heading 1\n** Sub-Heading 1.1\n", "Ensure entry includes heading ancestory"
assert entries[1].raw == "* Heading 2\n"
assert len(entries[1]) == 2
assert entries[1][0].raw == "* Heading 1\n** Sub-Heading 1.1\n", "Ensure entry includes heading ancestory"
assert entries[1][1].raw == "* Heading 2\n"
# Helper Functions

View file

@@ -17,7 +17,8 @@ def test_single_page_pdf_to_jsonl():
entries = PdfToEntries.extract_pdf_entries(pdf_files=data)
# Assert
assert len(entries) == 1
assert len(entries) == 2
assert len(entries[1]) == 1
def test_multi_page_pdf_to_jsonl():
@@ -31,7 +32,8 @@ def test_multi_page_pdf_to_jsonl():
entries = PdfToEntries.extract_pdf_entries(pdf_files=data)
# Assert
assert len(entries) == 6
assert len(entries) == 2
assert len(entries[1]) == 6
def test_ocr_page_pdf_to_jsonl():
@@ -43,9 +45,9 @@ def test_ocr_page_pdf_to_jsonl():
data = {"tests/data/pdf/ocr_samples.pdf": pdf_bytes}
entries = PdfToEntries.extract_pdf_entries(pdf_files=data)
assert len(entries) == 1
assert "playing on a strip of marsh" in entries[0].raw
assert len(entries) == 2
assert len(entries[1]) == 1
assert "playing on a strip of marsh" in entries[1][0].raw
def test_get_pdf_files(tmp_path):

View file

@@ -25,15 +25,16 @@ def test_plaintext_file(tmp_path):
entries = PlaintextToEntries.extract_plaintext_entries(data)
# Convert each entry.file to absolute path to make them JSON serializable
for entry in entries:
for entry in entries[1]:
entry.file = str(Path(entry.file).absolute())
# Assert
assert len(entries) == 1
assert len(entries) == 2
assert len(entries[1]) == 1
# Ensure raw entry with no headings do not get heading prefix prepended
assert not entries[0].raw.startswith("#")
assert not entries[1][0].raw.startswith("#")
# Ensure compiled entry has filename prepended as top level heading
assert entries[0].compiled == f"{plaintextfile}\n{raw_entry}"
assert entries[1][0].compiled == f"{plaintextfile}\n{raw_entry}"
def test_get_plaintext_files(tmp_path):
@@ -94,8 +95,9 @@ def test_parse_html_plaintext_file(content_config, default_user: KhojUser):
entries = PlaintextToEntries.extract_plaintext_entries(extracted_plaintext_files)
# Assert
assert len(entries) == 1
assert "<div>" not in entries[0].raw
assert len(entries) == 2
assert len(entries[1]) == 1
assert "<div>" not in entries[1][0].raw
def test_large_plaintext_file_split_into_multiple_entries(tmp_path):
@@ -113,13 +115,20 @@ def test_large_plaintext_file_split_into_multiple_entries(tmp_path):
# Act
# Extract Entries from specified plaintext files
normal_entries = PlaintextToEntries.extract_plaintext_entries(normal_data)
large_entries = PlaintextToEntries.extract_plaintext_entries(large_data)
# assert
assert len(normal_entries) == 2
assert len(large_entries) == 2
normal_entries = PlaintextToEntries.split_entries_by_max_tokens(
PlaintextToEntries.extract_plaintext_entries(normal_data),
normal_entries[1],
max_tokens=max_tokens,
raw_is_compiled=True,
)
large_entries = PlaintextToEntries.split_entries_by_max_tokens(
PlaintextToEntries.extract_plaintext_entries(large_data), max_tokens=max_tokens, raw_is_compiled=True
large_entries[1], max_tokens=max_tokens, raw_is_compiled=True
)
# Assert