From d4e5c957115f5f70474e93ec7bb40bf554b9daae Mon Sep 17 00:00:00 2001
From: Raghav Tirumale <62105787+MythicalCow@users.noreply.github.com>
Date: Tue, 18 Jun 2024 10:01:07 -0400
Subject: [PATCH] Add Ability to Summarize Documents (#800)
* Uses the entire file text and the summarizer model to generate the document summary.
* Uses the contents of the user's query to create a summary tailored to that query.
* Integrates with File Filters #788 for a better UX.
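Example usage, as a minimal sketch: it assumes a local Khoj server on http://localhost:42110, an already-authenticated requests session, and placeholder file name and conversation id.

    import requests

    session = requests.Session()  # assumes authentication cookies/headers are already configured
    base_url = "http://localhost:42110"

    # Attach a single indexed file to the conversation's file filters
    session.post(
        f"{base_url}/api/chat/conversation/file-filters",
        json={"filename": "notes/project_plan.md", "conversation_id": "1"},
    )

    # Ask for a summary of that file, tailored to the text after /summarize
    response = session.get(
        f"{base_url}/api/chat",
        params={"q": "/summarize what are the key decisions?", "conversation_id": "1", "stream": "true"},
    )
    print(response.text)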
---
src/khoj/database/adapters/__init__.py | 55 ++++-
.../database/migrations/0045_fileobject.py | 37 ++++
src/khoj/database/models/__init__.py | 7 +
src/khoj/interface/web/chat.html | 2 +-
.../content/markdown/markdown_to_entries.py | 9 +-
.../content/org_mode/org_to_entries.py | 19 +-
.../processor/content/pdf/pdf_to_entries.py | 22 +-
.../content/plaintext/plaintext_to_entries.py | 9 +-
src/khoj/processor/content/text_to_entries.py | 20 +-
src/khoj/processor/conversation/prompts.py | 21 ++
src/khoj/routers/api_chat.py | 90 ++++++++
src/khoj/routers/helpers.py | 25 +++
src/khoj/utils/helpers.py | 3 +
tests/conftest.py | 13 ++
tests/helpers.py | 4 +-
tests/test_markdown_to_entries.py | 54 +++--
tests/test_offline_chat_director.py | 206 +++++++++++++++++-
tests/test_openai_chat_director.py | 194 ++++++++++++++++-
tests/test_org_to_entries.py | 49 +++--
tests/test_pdf_to_entries.py | 12 +-
tests/test_plaintext_to_entries.py | 25 ++-
21 files changed, 791 insertions(+), 85 deletions(-)
create mode 100644 src/khoj/database/migrations/0045_fileobject.py
diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py
index e3d09b0a..007ba09a 100644
--- a/src/khoj/database/adapters/__init__.py
+++ b/src/khoj/database/adapters/__init__.py
@@ -28,6 +28,7 @@ from khoj.database.models import (
ClientApplication,
Conversation,
Entry,
+ FileObject,
GithubConfig,
GithubRepoConfig,
GoogleUser,
@@ -731,7 +732,7 @@ class ConversationAdapters:
if server_chat_settings is None or (
server_chat_settings.summarizer_model is None and server_chat_settings.default_model is None
):
- return await ChatModelOptions.objects.filter().afirst()
+ return await ChatModelOptions.objects.filter().prefetch_related("openai_config").afirst()
return server_chat_settings.summarizer_model or server_chat_settings.default_model
@staticmethod
@@ -846,6 +847,58 @@ class ConversationAdapters:
return await TextToImageModelConfig.objects.filter().afirst()
+class FileObjectAdapters:
+ @staticmethod
+ def update_raw_text(file_object: FileObject, new_raw_text: str):
+ file_object.raw_text = new_raw_text
+ file_object.save()
+
+ @staticmethod
+ def create_file_object(user: KhojUser, file_name: str, raw_text: str):
+ return FileObject.objects.create(user=user, file_name=file_name, raw_text=raw_text)
+
+ @staticmethod
+ def get_file_objects_by_name(user: KhojUser, file_name: str):
+ return FileObject.objects.filter(user=user, file_name=file_name).first()
+
+ @staticmethod
+ def get_all_file_objects(user: KhojUser):
+ return FileObject.objects.filter(user=user).all()
+
+ @staticmethod
+ def delete_file_object_by_name(user: KhojUser, file_name: str):
+ return FileObject.objects.filter(user=user, file_name=file_name).delete()
+
+ @staticmethod
+ def delete_all_file_objects(user: KhojUser):
+ return FileObject.objects.filter(user=user).delete()
+
+ @staticmethod
+ async def async_update_raw_text(file_object: FileObject, new_raw_text: str):
+ file_object.raw_text = new_raw_text
+ await file_object.asave()
+
+ @staticmethod
+ async def async_create_file_object(user: KhojUser, file_name: str, raw_text: str):
+ return await FileObject.objects.acreate(user=user, file_name=file_name, raw_text=raw_text)
+
+ @staticmethod
+ async def async_get_file_objects_by_name(user: KhojUser, file_name: str):
+ return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name=file_name))
+
+ @staticmethod
+ async def async_get_all_file_objects(user: KhojUser):
+ return await sync_to_async(list)(FileObject.objects.filter(user=user))
+
+ @staticmethod
+ async def async_delete_file_object_by_name(user: KhojUser, file_name: str):
+ return await FileObject.objects.filter(user=user, file_name=file_name).adelete()
+
+ @staticmethod
+ async def async_delete_all_file_objects(user: KhojUser):
+ return await FileObject.objects.filter(user=user).adelete()
+
+
class EntryAdapters:
word_filer = WordFilter()
file_filter = FileFilter()
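A minimal sketch of how the new async adapter methods might be used from async code; the user, file name, and raw text are placeholders.

    from khoj.database.adapters import FileObjectAdapters

    async def store_full_file_text(user, file_name: str, raw_text: str):
        # Update the stored raw text if this file already has a FileObject, else create one
        existing = await FileObjectAdapters.async_get_file_objects_by_name(user, file_name)
        if existing:
            await FileObjectAdapters.async_update_raw_text(existing[0], raw_text)
        else:
            await FileObjectAdapters.async_create_file_object(user, file_name, raw_text)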
diff --git a/src/khoj/database/migrations/0045_fileobject.py b/src/khoj/database/migrations/0045_fileobject.py
new file mode 100644
index 00000000..5c6c52ca
--- /dev/null
+++ b/src/khoj/database/migrations/0045_fileobject.py
@@ -0,0 +1,37 @@
+# Generated by Django 4.2.11 on 2024-06-14 06:13
+
+import django.db.models.deletion
+from django.conf import settings
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+ dependencies = [
+ ("database", "0044_conversation_file_filters"),
+ ]
+
+ operations = [
+ migrations.CreateModel(
+ name="FileObject",
+ fields=[
+ ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
+ ("created_at", models.DateTimeField(auto_now_add=True)),
+ ("updated_at", models.DateTimeField(auto_now=True)),
+ ("file_name", models.CharField(blank=True, default=None, max_length=400, null=True)),
+ ("raw_text", models.TextField()),
+ (
+ "user",
+ models.ForeignKey(
+ blank=True,
+ default=None,
+ null=True,
+ on_delete=django.db.models.deletion.CASCADE,
+ to=settings.AUTH_USER_MODEL,
+ ),
+ ),
+ ],
+ options={
+ "abstract": False,
+ },
+ ),
+ ]
diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py
index edd262bd..92415f59 100644
--- a/src/khoj/database/models/__init__.py
+++ b/src/khoj/database/models/__init__.py
@@ -326,6 +326,13 @@ class Entry(BaseModel):
corpus_id = models.UUIDField(default=uuid.uuid4, editable=False)
+class FileObject(BaseModel):
+ # Like Entry, but raw_text stores the full raw text of the file, which can be much larger
+ file_name = models.CharField(max_length=400, default=None, null=True, blank=True)
+ raw_text = models.TextField()
+ user = models.ForeignKey(KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True)
+
+
class EntryDates(BaseModel):
date = models.DateField()
entry = models.ForeignKey(Entry, on_delete=models.CASCADE, related_name="embeddings_dates")
diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html
index e31b99ac..639c0642 100644
--- a/src/khoj/interface/web/chat.html
+++ b/src/khoj/interface/web/chat.html
@@ -2145,7 +2145,7 @@ To get started, just start typing below. You can also type / to see a list of co
diff --git a/src/khoj/processor/content/markdown/markdown_to_entries.py b/src/khoj/processor/content/markdown/markdown_to_entries.py
index 73e5bf47..ae0bd822 100644
--- a/src/khoj/processor/content/markdown/markdown_to_entries.py
+++ b/src/khoj/processor/content/markdown/markdown_to_entries.py
@@ -33,7 +33,7 @@ class MarkdownToEntries(TextToEntries):
max_tokens = 256
# Extract Entries from specified Markdown files
with timer("Extract entries from specified Markdown files", logger):
- current_entries = MarkdownToEntries.extract_markdown_entries(files, max_tokens)
+ file_to_text_map, current_entries = MarkdownToEntries.extract_markdown_entries(files, max_tokens)
# Split entries by max tokens supported by model
with timer("Split entries by max token size supported by model", logger):
@@ -50,27 +50,30 @@ class MarkdownToEntries(TextToEntries):
deletion_file_names,
user,
regenerate=regenerate,
+ file_to_text_map=file_to_text_map,
)
return num_new_embeddings, num_deleted_embeddings
@staticmethod
- def extract_markdown_entries(markdown_files, max_tokens=256) -> List[Entry]:
+ def extract_markdown_entries(markdown_files, max_tokens=256) -> Tuple[Dict, List[Entry]]:
"Extract entries by heading from specified Markdown files"
entries: List[str] = []
entry_to_file_map: List[Tuple[str, str]] = []
+ file_to_text_map = dict()
for markdown_file in markdown_files:
try:
markdown_content = markdown_files[markdown_file]
entries, entry_to_file_map = MarkdownToEntries.process_single_markdown_file(
markdown_content, markdown_file, entries, entry_to_file_map, max_tokens
)
+ file_to_text_map[markdown_file] = markdown_content
except Exception as e:
logger.error(
f"Unable to process file: {markdown_file}. This file will not be indexed.\n{e}", exc_info=True
)
- return MarkdownToEntries.convert_markdown_entries_to_maps(entries, dict(entry_to_file_map))
+ return file_to_text_map, MarkdownToEntries.convert_markdown_entries_to_maps(entries, dict(entry_to_file_map))
@staticmethod
def process_single_markdown_file(
diff --git a/src/khoj/processor/content/org_mode/org_to_entries.py b/src/khoj/processor/content/org_mode/org_to_entries.py
index c91c2070..af85a6bd 100644
--- a/src/khoj/processor/content/org_mode/org_to_entries.py
+++ b/src/khoj/processor/content/org_mode/org_to_entries.py
@@ -33,7 +33,7 @@ class OrgToEntries(TextToEntries):
# Extract Entries from specified Org files
max_tokens = 256
with timer("Extract entries from specified Org files", logger):
- current_entries = self.extract_org_entries(files, max_tokens=max_tokens)
+ file_to_text_map, current_entries = self.extract_org_entries(files, max_tokens=max_tokens)
with timer("Split entries by max token size supported by model", logger):
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=max_tokens)
@@ -49,6 +49,7 @@ class OrgToEntries(TextToEntries):
deletion_file_names,
user,
regenerate=regenerate,
+ file_to_text_map=file_to_text_map,
)
return num_new_embeddings, num_deleted_embeddings
@@ -56,26 +57,32 @@ class OrgToEntries(TextToEntries):
@staticmethod
def extract_org_entries(
org_files: dict[str, str], index_heading_entries: bool = False, max_tokens=256
- ) -> List[Entry]:
+ ) -> Tuple[Dict, List[Entry]]:
"Extract entries from specified Org files"
- entries, entry_to_file_map = OrgToEntries.extract_org_nodes(org_files, max_tokens)
- return OrgToEntries.convert_org_nodes_to_entries(entries, entry_to_file_map, index_heading_entries)
+ file_to_text_map, entries, entry_to_file_map = OrgToEntries.extract_org_nodes(org_files, max_tokens)
+ return file_to_text_map, OrgToEntries.convert_org_nodes_to_entries(
+ entries, entry_to_file_map, index_heading_entries
+ )
@staticmethod
- def extract_org_nodes(org_files: dict[str, str], max_tokens) -> Tuple[List[List[Orgnode]], Dict[Orgnode, str]]:
+ def extract_org_nodes(
+ org_files: dict[str, str], max_tokens
+ ) -> Tuple[Dict, List[List[Orgnode]], Dict[Orgnode, str]]:
"Extract org nodes from specified org files"
entries: List[List[Orgnode]] = []
entry_to_file_map: List[Tuple[Orgnode, str]] = []
+ file_to_text_map = {}
for org_file in org_files:
try:
org_content = org_files[org_file]
entries, entry_to_file_map = OrgToEntries.process_single_org_file(
org_content, org_file, entries, entry_to_file_map, max_tokens
)
+ file_to_text_map[org_file] = org_content
except Exception as e:
logger.error(f"Unable to process file: {org_file}. Skipped indexing it.\nError; {e}", exc_info=True)
- return entries, dict(entry_to_file_map)
+ return file_to_text_map, entries, dict(entry_to_file_map)
@staticmethod
def process_single_org_file(
diff --git a/src/khoj/processor/content/pdf/pdf_to_entries.py b/src/khoj/processor/content/pdf/pdf_to_entries.py
index c59b305c..45ff7261 100644
--- a/src/khoj/processor/content/pdf/pdf_to_entries.py
+++ b/src/khoj/processor/content/pdf/pdf_to_entries.py
@@ -2,10 +2,12 @@ import base64
import logging
import os
from datetime import datetime
-from typing import List, Tuple
+from typing import Dict, List, Tuple
from langchain_community.document_loaders import PyMuPDFLoader
+# Uncomment to import FileObjectAdapters when debugging the file object database.
+# from khoj.database.adapters import FileObjectAdapters
from khoj.database.models import Entry as DbEntry
from khoj.database.models import KhojUser
from khoj.processor.content.text_to_entries import TextToEntries
@@ -33,7 +35,7 @@ class PdfToEntries(TextToEntries):
# Extract Entries from specified Pdf files
with timer("Extract entries from specified PDF files", logger):
- current_entries = PdfToEntries.extract_pdf_entries(files)
+ file_to_text_map, current_entries = PdfToEntries.extract_pdf_entries(files)
# Split entries by max tokens supported by model
with timer("Split entries by max token size supported by model", logger):
@@ -50,14 +52,15 @@ class PdfToEntries(TextToEntries):
deletion_file_names,
user,
regenerate=regenerate,
+ file_to_text_map=file_to_text_map,
)
return num_new_embeddings, num_deleted_embeddings
@staticmethod
- def extract_pdf_entries(pdf_files) -> List[Entry]:
+ def extract_pdf_entries(pdf_files) -> Tuple[Dict, List[Entry]]:
"""Extract entries by page from specified PDF files"""
-
+ file_to_text_map = dict()
entries: List[str] = []
entry_to_location_map: List[Tuple[str, str]] = []
for pdf_file in pdf_files:
@@ -73,9 +76,14 @@ class PdfToEntries(TextToEntries):
pdf_entries_per_file = [page.page_content for page in loader.load()]
except ImportError:
loader = PyMuPDFLoader(f"{tmp_file}")
- pdf_entries_per_file = [page.page_content for page in loader.load()]
- entry_to_location_map += zip(pdf_entries_per_file, [pdf_file] * len(pdf_entries_per_file))
+ pdf_entries_per_file = [
+ page.page_content for page in loader.load()
+ ] # Text content of each page in the PDF
+ entry_to_location_map += zip(
+ pdf_entries_per_file, [pdf_file] * len(pdf_entries_per_file)
+ ) # Map each page's content to its source PDF file
entries.extend(pdf_entries_per_file)
+ file_to_text_map[pdf_file] = pdf_entries_per_file
except Exception as e:
logger.warning(f"Unable to process file: {pdf_file}. This file will not be indexed.")
logger.warning(e, exc_info=True)
@@ -83,7 +91,7 @@ class PdfToEntries(TextToEntries):
if os.path.exists(f"{tmp_file}"):
os.remove(f"{tmp_file}")
- return PdfToEntries.convert_pdf_entries_to_maps(entries, dict(entry_to_location_map))
+ return file_to_text_map, PdfToEntries.convert_pdf_entries_to_maps(entries, dict(entry_to_location_map))
@staticmethod
def convert_pdf_entries_to_maps(parsed_entries: List[str], entry_to_file_map) -> List[Entry]:
diff --git a/src/khoj/processor/content/plaintext/plaintext_to_entries.py b/src/khoj/processor/content/plaintext/plaintext_to_entries.py
index c14bc359..2c994899 100644
--- a/src/khoj/processor/content/plaintext/plaintext_to_entries.py
+++ b/src/khoj/processor/content/plaintext/plaintext_to_entries.py
@@ -32,7 +32,7 @@ class PlaintextToEntries(TextToEntries):
# Extract Entries from specified plaintext files
with timer("Extract entries from specified Plaintext files", logger):
- current_entries = PlaintextToEntries.extract_plaintext_entries(files)
+ file_to_text_map, current_entries = PlaintextToEntries.extract_plaintext_entries(files)
# Split entries by max tokens supported by model
with timer("Split entries by max token size supported by model", logger):
@@ -49,6 +49,7 @@ class PlaintextToEntries(TextToEntries):
deletion_filenames=deletion_file_names,
user=user,
regenerate=regenerate,
+ file_to_text_map=file_to_text_map,
)
return num_new_embeddings, num_deleted_embeddings
@@ -63,21 +64,23 @@ class PlaintextToEntries(TextToEntries):
return soup.get_text(strip=True, separator="\n")
@staticmethod
- def extract_plaintext_entries(text_files: Dict[str, str]) -> List[Entry]:
+ def extract_plaintext_entries(text_files: Dict[str, str]) -> Tuple[Dict, List[Entry]]:
entries: List[str] = []
entry_to_file_map: List[Tuple[str, str]] = []
+ file_to_text_map = dict()
for text_file in text_files:
try:
text_content = text_files[text_file]
entries, entry_to_file_map = PlaintextToEntries.process_single_plaintext_file(
text_content, text_file, entries, entry_to_file_map
)
+ file_to_text_map[text_file] = text_content
except Exception as e:
logger.warning(f"Unable to read file: {text_file} as plaintext. Skipping file.")
logger.warning(e, exc_info=True)
# Extract Entries from specified plaintext files
- return PlaintextToEntries.convert_text_files_to_entries(entries, dict(entry_to_file_map))
+ return file_to_text_map, PlaintextToEntries.convert_text_files_to_entries(entries, dict(entry_to_file_map))
@staticmethod
def process_single_plaintext_file(
diff --git a/src/khoj/processor/content/text_to_entries.py b/src/khoj/processor/content/text_to_entries.py
index 361d0220..49331d6b 100644
--- a/src/khoj/processor/content/text_to_entries.py
+++ b/src/khoj/processor/content/text_to_entries.py
@@ -9,7 +9,11 @@ from typing import Any, Callable, List, Set, Tuple
from langchain.text_splitter import RecursiveCharacterTextSplitter
from tqdm import tqdm
-from khoj.database.adapters import EntryAdapters, get_user_search_model_or_default
+from khoj.database.adapters import (
+ EntryAdapters,
+ FileObjectAdapters,
+ get_user_search_model_or_default,
+)
from khoj.database.models import Entry as DbEntry
from khoj.database.models import EntryDates, KhojUser
from khoj.search_filter.date_filter import DateFilter
@@ -120,6 +124,7 @@ class TextToEntries(ABC):
deletion_filenames: Set[str] = None,
user: KhojUser = None,
regenerate: bool = False,
+ file_to_text_map: dict[str, List[str]] = None,
):
with timer("Constructed current entry hashes in", logger):
hashes_by_file = dict[str, set[str]]()
@@ -186,6 +191,18 @@ class TextToEntries(ABC):
logger.error(f"Error adding entries to database:\n{batch_indexing_error}\n---\n{e}", exc_info=True)
logger.debug(f"Added {len(added_entries)} {file_type} entries to database")
+ if file_to_text_map:
+ # Collect the file paths of the entries that were just added
+ filenames_to_update = [entry.file_path for entry in added_entries]
+ # Update raw_text on the existing file object for each file, or create a new file object if none exists
+ for file_name in filenames_to_update:
+ raw_text = " ".join(file_to_text_map[file_name])
+ file_object = FileObjectAdapters.get_file_objects_by_name(user, file_name)
+ if file_object:
+ FileObjectAdapters.update_raw_text(file_object, raw_text)
+ else:
+ FileObjectAdapters.create_file_object(user, file_name, raw_text)
+
new_dates = []
with timer("Indexed dates from added entries in", logger):
for added_entry in added_entries:
@@ -210,6 +227,7 @@ class TextToEntries(ABC):
for file_path in deletion_filenames:
deleted_count = EntryAdapters.delete_entry_by_file(user, file_path)
num_deleted_entries += deleted_count
+ FileObjectAdapters.delete_file_object_by_name(user, file_path)
return len(added_entries), num_deleted_entries
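For reference, the PDF processor passes a file_to_text_map keyed by file path with a list of per-page texts, which update_embeddings joins into a single raw_text before storing it on the FileObject; the Markdown, Org, and plaintext processors pass the full file content per file instead. A sketch with placeholder content:

    # Shape of file_to_text_map as produced by PdfToEntries.extract_pdf_entries
    file_to_text_map = {
        "documents/report.pdf": ["Page 1 text ...", "Page 2 text ..."],
    }
    # update_embeddings joins the per-page texts before storing them on the FileObject
    raw_text = " ".join(file_to_text_map["documents/report.pdf"])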
diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py
index d0f5356d..a1c7dff1 100644
--- a/src/khoj/processor/conversation/prompts.py
+++ b/src/khoj/processor/conversation/prompts.py
@@ -321,6 +321,27 @@ Collate only relevant information from the website to answer the target query.
""".strip()
)
+system_prompt_extract_relevant_summary = """As a professional analyst, create a comprehensive report of the most relevant information from the document in response to a user's query. The text provided is directly from within the document. The report you create should be multiple paragraphs, and it should represent the content of the document. Tell the user exactly what the document says in response to their query, while adhering to these guidelines:
+
+1. Answer the user's query as specifically as possible. Include many supporting details from the document.
+2. Craft a report that is detailed, thorough, in-depth, and complex, while maintaining clarity.
+3. Rely strictly on the provided text, without including external information.
+4. Format the report in multiple paragraphs with a clear structure.
+5. Be as specific as possible in your answer to the user's query.
+6. Reproduce as much of the provided text as possible, while maintaining readability.
+""".strip()
+
+extract_relevant_summary = PromptTemplate.from_template(
+ """
+Target Query: {query}
+
+Document Contents:
+{corpus}
+
+Collate only relevant information from the document to answer the target query.
+""".strip()
+)
+
pick_relevant_output_mode = PromptTemplate.from_template(
"""
You are Khoj, an excellent analyst for selecting the correct way to respond to a user's query. You have access to a limited set of modes for your response. You can only use one of these modes.
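A sketch of rendering the new template, as the summarize chat actor does in routers/helpers.py; the query and document text are placeholders.

    from khoj.processor.conversation import prompts

    rendered_prompt = prompts.extract_relevant_summary.format(
        query="What deadlines does the project plan mention?",
        corpus="Full raw text of the selected document ...",
    )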
diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py
index 1ddcd6f0..f88fba96 100644
--- a/src/khoj/routers/api_chat.py
+++ b/src/khoj/routers/api_chat.py
@@ -16,6 +16,7 @@ from websockets import ConnectionClosedOK
from khoj.database.adapters import (
ConversationAdapters,
EntryAdapters,
+ FileObjectAdapters,
PublicConversationAdapters,
aget_user_name,
)
@@ -42,6 +43,7 @@ from khoj.routers.helpers import (
aget_relevant_output_modes,
construct_automation_created_message,
create_automation,
+ extract_relevant_summary,
get_conversation_command,
is_query_empty,
is_ready_to_chat,
@@ -586,6 +588,51 @@ async def websocket_endpoint(
await conversation_command_rate_limiter.update_and_check_if_valid(websocket, cmd)
q = q.replace(f"/{cmd.value}", "").strip()
+ if ConversationCommand.Summarize in conversation_commands:
+ file_filters = conversation.file_filters
+ response_log = ""
+ if len(file_filters) == 0:
+ response_log = "No files selected for summarization. Please add files using the section on the left."
+ await send_complete_llm_response(response_log)
+ elif len(file_filters) > 1:
+ response_log = "Only one file can be selected for summarization."
+ await send_complete_llm_response(response_log)
+ else:
+ try:
+ file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
+ if len(file_object) == 0:
+ response_log = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again."
+ await send_complete_llm_response(response_log)
+ continue
+ contextual_data = " ".join([file.raw_text for file in file_object])
+ if not q:
+ q = "Create a general summary of the file"
+ await send_status_update(f"**🧑🏾💻 Constructing Summary Using:** {file_object[0].file_name}")
+ response = await extract_relevant_summary(q, contextual_data)
+ response_log = str(response)
+ await send_complete_llm_response(response_log)
+ except Exception as e:
+ response_log = "Error summarizing file."
+ logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)
+ await send_complete_llm_response(response_log)
+ await sync_to_async(save_to_conversation_log)(
+ q,
+ response_log,
+ user,
+ meta_log,
+ user_message_time,
+ intent_type="summarize",
+ client_application=websocket.user.client_app,
+ conversation_id=conversation_id,
+ )
+ update_telemetry_state(
+ request=websocket,
+ telemetry_type="api",
+ api="chat",
+ metadata={"conversation_command": conversation_commands[0].value},
+ )
+ continue
+
custom_filters = []
if conversation_commands == [ConversationCommand.Help]:
if not q:
@@ -828,6 +875,49 @@ async def chat(
_custom_filters.append("site:khoj.dev")
conversation_commands.append(ConversationCommand.Online)
+ conversation = await ConversationAdapters.aget_conversation_by_user(user, conversation_id=conversation_id)
+ conversation_id = conversation.id if conversation else None
+ if ConversationCommand.Summarize in conversation_commands:
+ file_filters = conversation.file_filters
+ llm_response = ""
+ if len(file_filters) == 0:
+ llm_response = "No files selected for summarization. Please add files using the section on the left."
+ elif len(file_filters) > 1:
+ llm_response = "Only one file can be selected for summarization."
+ else:
+ try:
+ file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
+ if len(file_object) == 0:
+ llm_response = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again."
+ return StreamingResponse(content=llm_response, media_type="text/event-stream", status_code=200)
+ contextual_data = " ".join([file.raw_text for file in file_object])
+ summarizeStr = "/" + ConversationCommand.Summarize
+ if q.strip() == summarizeStr:
+ q = "Create a general summary of the file"
+ response = await extract_relevant_summary(q, contextual_data)
+ llm_response = str(response)
+ except Exception as e:
+ logger.error(f"Error summarizing file for {user.email}: {e}")
+ llm_response = "Error summarizing file."
+ await sync_to_async(save_to_conversation_log)(
+ q,
+ llm_response,
+ user,
+ conversation.conversation_log,
+ user_message_time,
+ intent_type="summarize",
+ client_application=request.user.client_app,
+ conversation_id=conversation_id,
+ )
+ update_telemetry_state(
+ request=request,
+ telemetry_type="api",
+ api="chat",
+ metadata={"conversation_command": conversation_commands[0].value},
+ **common.__dict__,
+ )
+ return StreamingResponse(content=llm_response, media_type="text/event-stream", status_code=200)
+
conversation = await ConversationAdapters.aget_conversation_by_user(
user, request.user.client_app, conversation_id, title
)
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index febb16ea..90bfd8c6 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -200,6 +200,8 @@ def get_conversation_command(query: str, any_references: bool = False) -> Conver
return ConversationCommand.Image
elif query.startswith("/automated_task"):
return ConversationCommand.AutomatedTask
+ elif query.startswith("/summarize"):
+ return ConversationCommand.Summarize
# If no relevant notes found for the given query
elif not any_references:
return ConversationCommand.General
@@ -418,7 +420,30 @@ async def extract_relevant_info(q: str, corpus: str) -> Union[str, None]:
prompts.system_prompt_extract_relevant_information,
chat_model_option=summarizer_model,
)
+ return response.strip()
+
+async def extract_relevant_summary(q: str, corpus: str) -> Union[str, None]:
+ """
+ Extract a summary of the relevant information from the target corpus for the given query
+ """
+
+ if is_none_or_empty(corpus) or is_none_or_empty(q):
+ return None
+
+ extract_relevant_information = prompts.extract_relevant_summary.format(
+ query=q,
+ corpus=corpus.strip(),
+ )
+
+ summarizer_model: ChatModelOptions = await ConversationAdapters.aget_summarizer_conversation_config()
+
+ with timer("Chat actor: Extract relevant information from data", logger):
+ response = await send_message_to_model_wrapper(
+ extract_relevant_information,
+ prompts.system_prompt_extract_relevant_summary,
+ chat_model_option=summarizer_model,
+ )
return response.strip()
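A minimal sketch of calling the new chat actor directly; the query and document text are placeholders, and a summarizer or default chat model must be configured.

    from khoj.routers.helpers import extract_relevant_summary

    async def summarize_document(document_text: str) -> str:
        # Returns a query-tailored summary generated by the configured summarizer model
        return await extract_relevant_summary(
            q="What are the action items in this document?",
            corpus=document_text,
        )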
diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py
index fd44ba85..59327b0d 100644
--- a/src/khoj/utils/helpers.py
+++ b/src/khoj/utils/helpers.py
@@ -307,6 +307,7 @@ class ConversationCommand(str, Enum):
Text = "text"
Automation = "automation"
AutomatedTask = "automated_task"
+ Summarize = "summarize"
command_descriptions = {
@@ -318,6 +319,7 @@ command_descriptions = {
ConversationCommand.Image: "Generate images by describing your imagination in words.",
ConversationCommand.Automation: "Automatically run your query at a specified time or interval.",
ConversationCommand.Help: "Get help with how to use or setup Khoj from the documentation",
+ ConversationCommand.Summarize: "Create an appropriate summary using provided documents.",
}
tool_descriptions_for_llm = {
@@ -326,6 +328,7 @@ tool_descriptions_for_llm = {
ConversationCommand.Notes: "To search the user's personal knowledge base. Especially helpful if the question expects context from the user's notes or documents.",
ConversationCommand.Online: "To search for the latest, up-to-date information from the internet. Note: **Questions about Khoj should always use this data source**",
ConversationCommand.Webpage: "To use if the user has directly provided the webpage urls or you are certain of the webpage urls to read.",
+ ConversationCommand.Summarize: "To create a summary of the document provided by the user.",
}
mode_descriptions_for_llm = {
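A small sketch of how the slash command is resolved, matching the new branch added to get_conversation_command:

    from khoj.routers.helpers import get_conversation_command
    from khoj.utils.helpers import ConversationCommand

    cmd = get_conversation_command("/summarize tell me about the budget")
    assert cmd == ConversationCommand.Summarize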
diff --git a/tests/conftest.py b/tests/conftest.py
index 405b652f..83981e6d 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -19,6 +19,7 @@ from khoj.database.models import (
KhojUser,
LocalMarkdownConfig,
LocalOrgConfig,
+ LocalPdfConfig,
LocalPlaintextConfig,
)
from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
@@ -415,6 +416,18 @@ def org_config_with_only_new_file(new_org_file: Path, default_user: KhojUser):
return LocalOrgConfig.objects.filter(user=default_user).first()
+@pytest.fixture(scope="function")
+def pdf_configured_user1(default_user: KhojUser):
+ LocalPdfConfig.objects.create(
+ input_files=None,
+ input_filter=["tests/data/pdf/singlepage.pdf"],
+ user=default_user,
+ )
+ # Index PDF content for search
+ all_files = fs_syncer.collect_files(user=default_user)
+ success = configure_content(all_files, user=default_user)
+
+
@pytest.fixture(scope="function")
def sample_org_data():
return get_sample_data("org")
diff --git a/tests/helpers.py b/tests/helpers.py
index 009c8b55..7894ffa2 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -51,7 +51,9 @@ class ChatModelOptionsFactory(factory.django.DjangoModelFactory):
tokenizer = None
chat_model = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF"
model_type = "offline"
- openai_config = factory.SubFactory(OpenAIProcessorConversationConfigFactory)
+ openai_config = factory.LazyAttribute(
+ lambda obj: OpenAIProcessorConversationConfigFactory() if os.getenv("OPENAI_API_KEY") else None
+ )
class UserConversationProcessorConfigFactory(factory.django.DjangoModelFactory):
diff --git a/tests/test_markdown_to_entries.py b/tests/test_markdown_to_entries.py
index d63f026a..22f94ef5 100644
--- a/tests/test_markdown_to_entries.py
+++ b/tests/test_markdown_to_entries.py
@@ -23,13 +23,14 @@ def test_extract_markdown_with_no_headings(tmp_path):
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3)
# Assert
- assert len(entries) == 1
+ assert len(entries) == 2
+ assert len(entries[1]) == 1
# Ensure raw entry with no headings do not get heading prefix prepended
- assert not entries[0].raw.startswith("#")
+ assert not entries[1][0].raw.startswith("#")
# Ensure compiled entry has filename prepended as top level heading
- assert entries[0].compiled.startswith(expected_heading)
+ assert entries[1][0].compiled.startswith(expected_heading)
# Ensure compiled entry also includes the file name
- assert str(tmp_path) in entries[0].compiled
+ assert str(tmp_path) in entries[1][0].compiled
def test_extract_single_markdown_entry(tmp_path):
@@ -48,7 +49,8 @@ def test_extract_single_markdown_entry(tmp_path):
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3)
# Assert
- assert len(entries) == 1
+ assert len(entries) == 2
+ assert len(entries[1]) == 1
def test_extract_multiple_markdown_entries(tmp_path):
@@ -72,8 +74,9 @@ def test_extract_multiple_markdown_entries(tmp_path):
# Assert
assert len(entries) == 2
+ assert len(entries[1]) == 2
# Ensure entry compiled strings include the markdown files they originate from
- assert all([tmp_path.stem in entry.compiled for entry in entries])
+ assert all([tmp_path.stem in entry.compiled for entry in entries[1]])
def test_extract_entries_with_different_level_headings(tmp_path):
@@ -94,8 +97,9 @@ def test_extract_entries_with_different_level_headings(tmp_path):
# Assert
assert len(entries) == 2
- assert entries[0].raw == "# Heading 1\n## Sub-Heading 1.1", "Ensure entry includes heading ancestory"
- assert entries[1].raw == "# Heading 2\n"
+ assert len(entries[1]) == 2
+ assert entries[1][0].raw == "# Heading 1\n## Sub-Heading 1.1", "Ensure entry includes heading ancestory"
+ assert entries[1][1].raw == "# Heading 2\n"
def test_extract_entries_with_non_incremental_heading_levels(tmp_path):
@@ -116,10 +120,11 @@ def test_extract_entries_with_non_incremental_heading_levels(tmp_path):
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3)
# Assert
- assert len(entries) == 3
- assert entries[0].raw == "# Heading 1\n#### Sub-Heading 1.1", "Ensure entry includes heading ancestory"
- assert entries[1].raw == "# Heading 1\n## Sub-Heading 1.2", "Ensure entry includes heading ancestory"
- assert entries[2].raw == "# Heading 2\n"
+ assert len(entries) == 2
+ assert len(entries[1]) == 3
+ assert entries[1][0].raw == "# Heading 1\n#### Sub-Heading 1.1", "Ensure entry includes heading ancestory"
+ assert entries[1][1].raw == "# Heading 1\n## Sub-Heading 1.2", "Ensure entry includes heading ancestory"
+ assert entries[1][2].raw == "# Heading 2\n"
def test_extract_entries_with_text_before_headings(tmp_path):
@@ -141,10 +146,13 @@ body line 2
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3)
# Assert
- assert len(entries) == 3
- assert entries[0].raw == "\nText before headings"
- assert entries[1].raw == "# Heading 1\nbody line 1"
- assert entries[2].raw == "# Heading 1\n## Heading 2\nbody line 2\n", "Ensure raw entry includes heading ancestory"
+ assert len(entries) == 2
+ assert len(entries[1]) == 3
+ assert entries[1][0].raw == "\nText before headings"
+ assert entries[1][1].raw == "# Heading 1\nbody line 1"
+ assert (
+ entries[1][2].raw == "# Heading 1\n## Heading 2\nbody line 2\n"
+ ), "Ensure raw entry includes heading ancestory"
def test_parse_markdown_file_into_single_entry_if_small(tmp_path):
@@ -165,8 +173,9 @@ body line 1.1
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=12)
# Assert
- assert len(entries) == 1
- assert entries[0].raw == entry
+ assert len(entries) == 2
+ assert len(entries[1]) == 1
+ assert entries[1][0].raw == entry
def test_parse_markdown_entry_with_children_as_single_entry_if_small(tmp_path):
@@ -191,13 +200,14 @@ longer body line 2.1
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=12)
# Assert
- assert len(entries) == 3
+ assert len(entries) == 2
+ assert len(entries[1]) == 3
assert (
- entries[0].raw == "# Heading 1\nbody line 1\n## Subheading 1.1\nbody line 1.1"
+ entries[1][0].raw == "# Heading 1\nbody line 1\n## Subheading 1.1\nbody line 1.1"
), "First entry includes children headings"
- assert entries[1].raw == "# Heading 2\nbody line 2", "Second entry does not include children headings"
+ assert entries[1][1].raw == "# Heading 2\nbody line 2", "Second entry does not include children headings"
assert (
- entries[2].raw == "# Heading 2\n## Subheading 2.1\nlonger body line 2.1\n"
+ entries[1][2].raw == "# Heading 2\n## Subheading 2.1\nlonger body line 2.1\n"
), "Third entry is second entries child heading"
diff --git a/tests/test_offline_chat_director.py b/tests/test_offline_chat_director.py
index b5e546d0..1c4ad3ac 100644
--- a/tests/test_offline_chat_director.py
+++ b/tests/test_offline_chat_director.py
@@ -6,7 +6,7 @@ import pytest
from faker import Faker
from freezegun import freeze_time
-from khoj.database.models import Agent, KhojUser
+from khoj.database.models import Agent, Entry, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import message_to_log
from khoj.routers.helpers import aget_relevant_information_sources
@@ -305,6 +305,200 @@ def test_answer_not_known_using_notes_command(client_offline_chat, default_user2
assert response_message == prompts.no_notes_found.format()
+# ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db(transaction=True)
+@pytest.mark.chatquality
+def test_summarize_one_file(client_offline_chat, default_user2: KhojUser):
+ message_list = []
+ conversation = create_conversation(message_list, default_user2)
+ # post "Xi Li.markdown" file to the file filters
+ file_list = (
+ Entry.objects.filter(user=default_user2, file_source="computer")
+ .distinct("file_path")
+ .values_list("file_path", flat=True)
+ )
+ # Pick the file with "Birthday Gift for Xiu turning 4.markdown" in its name
+ summarization_file = ""
+ for file in file_list:
+ if "Birthday Gift for Xiu turning 4.markdown" in file:
+ summarization_file = file
+ break
+ assert summarization_file != ""
+
+ response = client_offline_chat.post(
+ "api/chat/conversation/file-filters",
+ json={"filename": summarization_file, "conversation_id": str(conversation.id)},
+ )
+ query = urllib.parse.quote("/summarize")
+ response = client_offline_chat.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
+ response_message = response.content.decode("utf-8")
+ # Assert
+ assert response_message != ""
+ assert response_message != "No files selected for summarization. Please add files using the section on the left."
+ assert response_message != "Only one file can be selected for summarization."
+
+
+@pytest.mark.django_db(transaction=True)
+@pytest.mark.chatquality
+def test_summarize_extra_text(client_offline_chat, default_user2: KhojUser):
+ message_list = []
+ conversation = create_conversation(message_list, default_user2)
+ # post "Xi Li.markdown" file to the file filters
+ file_list = (
+ Entry.objects.filter(user=default_user2, file_source="computer")
+ .distinct("file_path")
+ .values_list("file_path", flat=True)
+ )
+ # Pick the file with "Birthday Gift for Xiu turning 4.markdown" in its name
+ summarization_file = ""
+ for file in file_list:
+ if "Birthday Gift for Xiu turning 4.markdown" in file:
+ summarization_file = file
+ break
+ assert summarization_file != ""
+
+ response = client_offline_chat.post(
+ "api/chat/conversation/file-filters",
+ json={"filename": summarization_file, "conversation_id": str(conversation.id)},
+ )
+ query = urllib.parse.quote("/summarize tell me about Xiu")
+ response = client_offline_chat.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
+ response_message = response.content.decode("utf-8")
+ # Assert
+ assert response_message != ""
+ assert response_message != "No files selected for summarization. Please add files using the section on the left."
+ assert response_message != "Only one file can be selected for summarization."
+
+
+@pytest.mark.django_db(transaction=True)
+@pytest.mark.chatquality
+def test_summarize_multiple_files(client_offline_chat, default_user2: KhojUser):
+ message_list = []
+ conversation = create_conversation(message_list, default_user2)
+ # post "Xi Li.markdown" file to the file filters
+ file_list = (
+ Entry.objects.filter(user=default_user2, file_source="computer")
+ .distinct("file_path")
+ .values_list("file_path", flat=True)
+ )
+
+ response = client_offline_chat.post(
+ "api/chat/conversation/file-filters", json={"filename": file_list[0], "conversation_id": str(conversation.id)}
+ )
+ response = client_offline_chat.post(
+ "api/chat/conversation/file-filters", json={"filename": file_list[1], "conversation_id": str(conversation.id)}
+ )
+
+ query = urllib.parse.quote("/summarize")
+ response = client_offline_chat.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
+ response_message = response.content.decode("utf-8")
+
+ # Assert
+ assert response_message == "Only one file can be selected for summarization."
+
+
+@pytest.mark.django_db(transaction=True)
+@pytest.mark.chatquality
+def test_summarize_no_files(client_offline_chat, default_user2: KhojUser):
+ message_list = []
+ conversation = create_conversation(message_list, default_user2)
+ query = urllib.parse.quote("/summarize")
+ response = client_offline_chat.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
+ response_message = response.content.decode("utf-8")
+ # Assert
+ assert response_message == "No files selected for summarization. Please add files using the section on the left."
+
+
+@pytest.mark.django_db(transaction=True)
+@pytest.mark.chatquality
+def test_summarize_different_conversation(client_offline_chat, default_user2: KhojUser):
+ message_list = []
+ conversation1 = create_conversation(message_list, default_user2)
+ conversation2 = create_conversation(message_list, default_user2)
+
+ file_list = (
+ Entry.objects.filter(user=default_user2, file_source="computer")
+ .distinct("file_path")
+ .values_list("file_path", flat=True)
+ )
+ summarization_file = ""
+ for file in file_list:
+ if "Birthday Gift for Xiu turning 4.markdown" in file:
+ summarization_file = file
+ break
+ assert summarization_file != ""
+
+ # add file filter to conversation 1.
+ response = client_offline_chat.post(
+ "api/chat/conversation/file-filters",
+ json={"filename": summarization_file, "conversation_id": str(conversation1.id)},
+ )
+
+ query = urllib.parse.quote("/summarize")
+ response = client_offline_chat.get(f"/api/chat?q={query}&conversation_id={conversation2.id}&stream=true")
+ response_message = response.content.decode("utf-8")
+
+ # Assert
+ assert response_message == "No files selected for summarization. Please add files using the section on the left."
+
+ # now make sure that the file filter is still in conversation 1
+ response = client_offline_chat.get(f"/api/chat?q={query}&conversation_id={conversation1.id}&stream=true")
+ response_message = response.content.decode("utf-8")
+
+ # Assert
+ assert response_message != ""
+ assert response_message != "No files selected for summarization. Please add files using the section on the left."
+ assert response_message != "Only one file can be selected for summarization."
+
+
+@pytest.mark.django_db(transaction=True)
+@pytest.mark.chatquality
+def test_summarize_nonexistant_file(client_offline_chat, default_user2: KhojUser):
+ message_list = []
+ conversation = create_conversation(message_list, default_user2)
+ # post "imaginary.markdown" file to the file filters
+ response = client_offline_chat.post(
+ "api/chat/conversation/file-filters",
+ json={"filename": "imaginary.markdown", "conversation_id": str(conversation.id)},
+ )
+ query = urllib.parse.quote("/summarize")
+ response = client_offline_chat.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
+ response_message = response.content.decode("utf-8")
+ # Assert
+ assert response_message == "No files selected for summarization. Please add files using the section on the left."
+
+
+@pytest.mark.django_db(transaction=True)
+@pytest.mark.chatquality
+def test_summarize_diff_user_file(
+ client_offline_chat, default_user: KhojUser, pdf_configured_user1, default_user2: KhojUser
+):
+ message_list = []
+ conversation = create_conversation(message_list, default_user2)
+ # Get the pdf file called singlepage.pdf
+ file_list = (
+ Entry.objects.filter(user=default_user, file_source="computer")
+ .distinct("file_path")
+ .values_list("file_path", flat=True)
+ )
+ summarization_file = ""
+ for file in file_list:
+ if "singlepage.pdf" in file:
+ summarization_file = file
+ break
+ assert summarization_file != ""
+ # add singlepage.pdf to the file filters
+ response = client_offline_chat.post(
+ "api/chat/conversation/file-filters",
+ json={"filename": summarization_file, "conversation_id": str(conversation.id)},
+ )
+ query = urllib.parse.quote("/summarize")
+ response = client_offline_chat.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
+ response_message = response.content.decode("utf-8")
+ # Assert
+ assert response_message == "No files selected for summarization. Please add files using the section on the left."
+
+
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
@@ -522,7 +716,7 @@ async def test_get_correct_tools_online(client_offline_chat):
user_query = "What's the weather in Patagonia this week?"
# Act
- tools = await aget_relevant_information_sources(user_query, {})
+ tools = await aget_relevant_information_sources(user_query, {}, is_task=False)
# Assert
tools = [tool.value for tool in tools]
@@ -537,7 +731,7 @@ async def test_get_correct_tools_notes(client_offline_chat):
user_query = "Where did I go for my first battleship training?"
# Act
- tools = await aget_relevant_information_sources(user_query, {})
+ tools = await aget_relevant_information_sources(user_query, {}, is_task=False)
# Assert
tools = [tool.value for tool in tools]
@@ -552,7 +746,7 @@ async def test_get_correct_tools_online_or_general_and_notes(client_offline_chat
user_query = "What's the highest point in Patagonia and have I been there?"
# Act
- tools = await aget_relevant_information_sources(user_query, {})
+ tools = await aget_relevant_information_sources(user_query, {}, is_task=False)
# Assert
tools = [tool.value for tool in tools]
@@ -569,7 +763,7 @@ async def test_get_correct_tools_general(client_offline_chat):
user_query = "How many noble gases are there?"
# Act
- tools = await aget_relevant_information_sources(user_query, {})
+ tools = await aget_relevant_information_sources(user_query, {}, is_task=False)
# Assert
tools = [tool.value for tool in tools]
@@ -593,7 +787,7 @@ async def test_get_correct_tools_with_chat_history(client_offline_chat, default_
chat_history = create_conversation(chat_log, default_user2)
# Act
- tools = await aget_relevant_information_sources(user_query, chat_history)
+ tools = await aget_relevant_information_sources(user_query, chat_history, is_task=False)
# Assert
tools = [tool.value for tool in tools]
diff --git a/tests/test_openai_chat_director.py b/tests/test_openai_chat_director.py
index 58184ef9..b547f78e 100644
--- a/tests/test_openai_chat_director.py
+++ b/tests/test_openai_chat_director.py
@@ -5,7 +5,7 @@ from urllib.parse import quote
import pytest
from freezegun import freeze_time
-from khoj.database.models import Agent, KhojUser
+from khoj.database.models import Agent, Entry, KhojUser, LocalPdfConfig
from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import message_to_log
from khoj.routers.helpers import aget_relevant_information_sources
@@ -289,6 +289,198 @@ def test_answer_not_known_using_notes_command(chat_client_no_background, default
assert response_message == prompts.no_entries_found.format()
+# ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db(transaction=True)
+@pytest.mark.chatquality
+def test_summarize_one_file(chat_client, default_user2: KhojUser):
+ message_list = []
+ conversation = create_conversation(message_list, default_user2)
+ # post "Xi Li.markdown" file to the file filters
+ file_list = (
+ Entry.objects.filter(user=default_user2, file_source="computer")
+ .distinct("file_path")
+ .values_list("file_path", flat=True)
+ )
+ # Pick the file with "Birthday Gift for Xiu turning 4.markdown" in its name
+ summarization_file = ""
+ for file in file_list:
+ if "Birthday Gift for Xiu turning 4.markdown" in file:
+ summarization_file = file
+ break
+ assert summarization_file != ""
+
+ response = chat_client.post(
+ "api/chat/conversation/file-filters",
+ json={"filename": summarization_file, "conversation_id": str(conversation.id)},
+ )
+ query = urllib.parse.quote("/summarize")
+ response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
+ response_message = response.content.decode("utf-8")
+ # Assert
+ assert response_message != ""
+ assert response_message != "No files selected for summarization. Please add files using the section on the left."
+ assert response_message != "Only one file can be selected for summarization."
+
+
+@pytest.mark.django_db(transaction=True)
+@pytest.mark.chatquality
+def test_summarize_extra_text(chat_client, default_user2: KhojUser):
+ message_list = []
+ conversation = create_conversation(message_list, default_user2)
+ # post "Xi Li.markdown" file to the file filters
+ file_list = (
+ Entry.objects.filter(user=default_user2, file_source="computer")
+ .distinct("file_path")
+ .values_list("file_path", flat=True)
+ )
+ # Pick the file with "Birthday Gift for Xiu turning 4.markdown" in its name
+ summarization_file = ""
+ for file in file_list:
+ if "Birthday Gift for Xiu turning 4.markdown" in file:
+ summarization_file = file
+ break
+ assert summarization_file != ""
+
+ response = chat_client.post(
+ "api/chat/conversation/file-filters",
+ json={"filename": summarization_file, "conversation_id": str(conversation.id)},
+ )
+ query = urllib.parse.quote("/summarize tell me about Xiu")
+ response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
+ response_message = response.content.decode("utf-8")
+ # Assert
+ assert response_message != ""
+ assert response_message != "No files selected for summarization. Please add files using the section on the left."
+ assert response_message != "Only one file can be selected for summarization."
+
+
+@pytest.mark.django_db(transaction=True)
+@pytest.mark.chatquality
+def test_summarize_multiple_files(chat_client, default_user2: KhojUser):
+ message_list = []
+ conversation = create_conversation(message_list, default_user2)
+ # post "Xi Li.markdown" file to the file filters
+ file_list = (
+ Entry.objects.filter(user=default_user2, file_source="computer")
+ .distinct("file_path")
+ .values_list("file_path", flat=True)
+ )
+
+ response = chat_client.post(
+ "api/chat/conversation/file-filters", json={"filename": file_list[0], "conversation_id": str(conversation.id)}
+ )
+ response = chat_client.post(
+ "api/chat/conversation/file-filters", json={"filename": file_list[1], "conversation_id": str(conversation.id)}
+ )
+
+ query = urllib.parse.quote("/summarize")
+ response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
+ response_message = response.content.decode("utf-8")
+
+ # Assert
+ assert response_message == "Only one file can be selected for summarization."
+
+
+@pytest.mark.django_db(transaction=True)
+@pytest.mark.chatquality
+def test_summarize_no_files(chat_client, default_user2: KhojUser):
+ message_list = []
+ conversation = create_conversation(message_list, default_user2)
+ query = urllib.parse.quote("/summarize")
+ response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
+ response_message = response.content.decode("utf-8")
+ # Assert
+ assert response_message == "No files selected for summarization. Please add files using the section on the left."
+
+
+@pytest.mark.django_db(transaction=True)
+@pytest.mark.chatquality
+def test_summarize_different_conversation(chat_client, default_user2: KhojUser):
+ message_list = []
+ conversation1 = create_conversation(message_list, default_user2)
+ conversation2 = create_conversation(message_list, default_user2)
+
+ file_list = (
+ Entry.objects.filter(user=default_user2, file_source="computer")
+ .distinct("file_path")
+ .values_list("file_path", flat=True)
+ )
+ summarization_file = ""
+ for file in file_list:
+ if "Birthday Gift for Xiu turning 4.markdown" in file:
+ summarization_file = file
+ break
+ assert summarization_file != ""
+
+ # add file filter to conversation 1.
+ response = chat_client.post(
+ "api/chat/conversation/file-filters",
+ json={"filename": summarization_file, "conversation_id": str(conversation1.id)},
+ )
+
+ query = urllib.parse.quote("/summarize")
+ response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation2.id}&stream=true")
+ response_message = response.content.decode("utf-8")
+
+ # Assert
+ assert response_message == "No files selected for summarization. Please add files using the section on the left."
+
+ # now make sure that the file filter is still in conversation 1
+ response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation1.id}&stream=true")
+ response_message = response.content.decode("utf-8")
+
+ # Assert
+ assert response_message != ""
+ assert response_message != "No files selected for summarization. Please add files using the section on the left."
+ assert response_message != "Only one file can be selected for summarization."
+
+
+@pytest.mark.django_db(transaction=True)
+@pytest.mark.chatquality
+def test_summarize_nonexistant_file(chat_client, default_user2: KhojUser):
+ message_list = []
+ conversation = create_conversation(message_list, default_user2)
+ # post "imaginary.markdown" file to the file filters
+ response = chat_client.post(
+ "api/chat/conversation/file-filters",
+ json={"filename": "imaginary.markdown", "conversation_id": str(conversation.id)},
+ )
+ query = urllib.parse.quote("/summarize")
+ response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
+ response_message = response.content.decode("utf-8")
+ # Assert
+ assert response_message == "No files selected for summarization. Please add files using the section on the left."
+
+
+@pytest.mark.django_db(transaction=True)
+@pytest.mark.chatquality
+def test_summarize_diff_user_file(chat_client, default_user: KhojUser, pdf_configured_user1, default_user2: KhojUser):
+ message_list = []
+ conversation = create_conversation(message_list, default_user2)
+ # Get the pdf file called singlepage.pdf
+ file_list = (
+ Entry.objects.filter(user=default_user, file_source="computer")
+ .distinct("file_path")
+ .values_list("file_path", flat=True)
+ )
+ summarization_file = ""
+ for file in file_list:
+ if "singlepage.pdf" in file:
+ summarization_file = file
+ break
+ assert summarization_file != ""
+ # add singlepage.pdf to the file filters
+ response = chat_client.post(
+ "api/chat/conversation/file-filters",
+ json={"filename": summarization_file, "conversation_id": str(conversation.id)},
+ )
+ query = urllib.parse.quote("/summarize")
+ response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
+ response_message = response.content.decode("utf-8")
+ # Assert
+ assert response_message == "No files selected for summarization. Please add files using the section on the left."
+
+
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering time aware questions yet")
@pytest.mark.django_db(transaction=True)
diff --git a/tests/test_org_to_entries.py b/tests/test_org_to_entries.py
index 56be5fa5..e8940269 100644
--- a/tests/test_org_to_entries.py
+++ b/tests/test_org_to_entries.py
@@ -33,10 +33,12 @@ def test_configure_indexing_heading_only_entries(tmp_path):
# Assert
if index_heading_entries:
# Entry with empty body indexed when index_heading_entries set to True
- assert len(entries) == 1
+ assert len(entries) == 2
+ assert len(entries[1]) == 1
else:
# Entry with empty body ignored when index_heading_entries set to False
- assert is_none_or_empty(entries)
+ assert len(entries) == 2
+ assert is_none_or_empty(entries[1])
def test_entry_split_when_exceeds_max_tokens():
@@ -55,9 +57,9 @@ def test_entry_split_when_exceeds_max_tokens():
# Act
# Extract Entries from specified Org files
entries = OrgToEntries.extract_org_entries(org_files=data)
-
+ assert len(entries) == 2
# Split each entry from specified Org files by max tokens
- entries = TextToEntries.split_entries_by_max_tokens(entries, max_tokens=5)
+ entries = TextToEntries.split_entries_by_max_tokens(entries[1], max_tokens=5)
# Assert
assert len(entries) == 2
@@ -114,11 +116,12 @@ body line 1.1
# Act
# Extract Entries from specified Org files
extracted_entries = OrgToEntries.extract_org_entries(org_files=data, max_tokens=12)
- for entry in extracted_entries:
+ assert len(extracted_entries) == 2
+ for entry in extracted_entries[1]:
entry.raw = clean(entry.raw)
# Assert
- assert len(extracted_entries) == 1
+ assert len(extracted_entries[1]) == 1
assert entry.raw == expected_entry
@@ -165,10 +168,11 @@ longer body line 2.1
extracted_entries = OrgToEntries.extract_org_entries(org_files=data, max_tokens=12)
# Assert
- assert len(extracted_entries) == 3
- assert extracted_entries[0].compiled == first_expected_entry, "First entry includes children headings"
- assert extracted_entries[1].compiled == second_expected_entry, "Second entry does not include children headings"
- assert extracted_entries[2].compiled == third_expected_entry, "Third entry is second entries child heading"
+ assert len(extracted_entries) == 2
+ assert len(extracted_entries[1]) == 3
+ assert extracted_entries[1][0].compiled == first_expected_entry, "First entry includes children headings"
+ assert extracted_entries[1][1].compiled == second_expected_entry, "Second entry does not include children headings"
+ assert extracted_entries[1][2].compiled == third_expected_entry, "Third entry is second entries child heading"
def test_separate_sibling_org_entries_if_all_cannot_fit_in_token_limit(tmp_path):
@@ -226,10 +230,11 @@ body line 3.1
extracted_entries = OrgToEntries.extract_org_entries(org_files=data, max_tokens=30)
# Assert
- assert len(extracted_entries) == 3
- assert extracted_entries[0].compiled == first_expected_entry, "First entry includes children headings"
- assert extracted_entries[1].compiled == second_expected_entry, "Second entry includes children headings"
- assert extracted_entries[2].compiled == third_expected_entry, "Third entry includes children headings"
+ assert len(extracted_entries) == 2
+ assert len(extracted_entries[1]) == 3
+ assert extracted_entries[1][0].compiled == first_expected_entry, "First entry includes children headings"
+ assert extracted_entries[1][1].compiled == second_expected_entry, "Second entry includes children headings"
+ assert extracted_entries[1][2].compiled == third_expected_entry, "Third entry includes children headings"
def test_entry_with_body_to_entry(tmp_path):
@@ -251,7 +256,8 @@ def test_entry_with_body_to_entry(tmp_path):
entries = OrgToEntries.extract_org_entries(org_files=data, max_tokens=3)
# Assert
- assert len(entries) == 1
+ assert len(entries) == 2
+ assert len(entries[1]) == 1
def test_file_with_entry_after_intro_text_to_entry(tmp_path):
@@ -273,6 +279,7 @@ Intro text
# Assert
assert len(entries) == 2
+ assert len(entries[1]) == 2
def test_file_with_no_headings_to_entry(tmp_path):
@@ -291,7 +298,8 @@ def test_file_with_no_headings_to_entry(tmp_path):
entries = OrgToEntries.extract_org_entries(org_files=data, max_tokens=3)
# Assert
- assert len(entries) == 1
+ assert len(entries) == 2
+ assert len(entries[1]) == 1
def test_get_org_files(tmp_path):
@@ -349,13 +357,14 @@ def test_extract_entries_with_different_level_headings(tmp_path):
# Act
# Extract Entries from specified Org files
entries = OrgToEntries.extract_org_entries(org_files=data, index_heading_entries=True, max_tokens=3)
- for entry in entries:
+ assert len(entries) == 2
+ for entry in entries[1]:
entry.raw = clean(f"{entry.raw}")
# Assert
- assert len(entries) == 2
- assert entries[0].raw == "* Heading 1\n** Sub-Heading 1.1\n", "Ensure entry includes heading ancestory"
- assert entries[1].raw == "* Heading 2\n"
+ assert len(entries[1]) == 2
+ assert entries[1][0].raw == "* Heading 1\n** Sub-Heading 1.1\n", "Ensure entry includes heading ancestory"
+ assert entries[1][1].raw == "* Heading 2\n"
# Helper Functions
diff --git a/tests/test_pdf_to_entries.py b/tests/test_pdf_to_entries.py
index a8c6aa43..31ccb387 100644
--- a/tests/test_pdf_to_entries.py
+++ b/tests/test_pdf_to_entries.py
@@ -17,7 +17,8 @@ def test_single_page_pdf_to_jsonl():
entries = PdfToEntries.extract_pdf_entries(pdf_files=data)
# Assert
- assert len(entries) == 1
+ assert len(entries) == 2
+ assert len(entries[1]) == 1
def test_multi_page_pdf_to_jsonl():
@@ -31,7 +32,8 @@ def test_multi_page_pdf_to_jsonl():
entries = PdfToEntries.extract_pdf_entries(pdf_files=data)
# Assert
- assert len(entries) == 6
+ assert len(entries) == 2
+ assert len(entries[1]) == 6
def test_ocr_page_pdf_to_jsonl():
@@ -43,9 +45,9 @@ def test_ocr_page_pdf_to_jsonl():
data = {"tests/data/pdf/ocr_samples.pdf": pdf_bytes}
entries = PdfToEntries.extract_pdf_entries(pdf_files=data)
-
- assert len(entries) == 1
- assert "playing on a strip of marsh" in entries[0].raw
+ assert len(entries) == 2
+ assert len(entries[1]) == 1
+ assert "playing on a strip of marsh" in entries[1][0].raw
def test_get_pdf_files(tmp_path):
diff --git a/tests/test_plaintext_to_entries.py b/tests/test_plaintext_to_entries.py
index 972a0698..a085b2b5 100644
--- a/tests/test_plaintext_to_entries.py
+++ b/tests/test_plaintext_to_entries.py
@@ -25,15 +25,16 @@ def test_plaintext_file(tmp_path):
entries = PlaintextToEntries.extract_plaintext_entries(data)
# Convert each entry.file to absolute path to make them JSON serializable
- for entry in entries:
+ for entry in entries[1]:
entry.file = str(Path(entry.file).absolute())
# Assert
- assert len(entries) == 1
+ assert len(entries) == 2
+ assert len(entries[1]) == 1
# Ensure raw entry with no headings do not get heading prefix prepended
- assert not entries[0].raw.startswith("#")
+ assert not entries[1][0].raw.startswith("#")
# Ensure compiled entry has filename prepended as top level heading
- assert entries[0].compiled == f"{plaintextfile}\n{raw_entry}"
+ assert entries[1][0].compiled == f"{plaintextfile}\n{raw_entry}"
def test_get_plaintext_files(tmp_path):
@@ -94,8 +95,9 @@ def test_parse_html_plaintext_file(content_config, default_user: KhojUser):
entries = PlaintextToEntries.extract_plaintext_entries(extracted_plaintext_files)
# Assert
- assert len(entries) == 1
- assert "
" not in entries[0].raw
+ assert len(entries) == 2
+ assert len(entries[1]) == 1
+ assert "
" not in entries[1][0].raw
def test_large_plaintext_file_split_into_multiple_entries(tmp_path):
@@ -113,13 +115,20 @@ def test_large_plaintext_file_split_into_multiple_entries(tmp_path):
# Act
# Extract Entries from specified plaintext files
+ normal_entries = PlaintextToEntries.extract_plaintext_entries(normal_data)
+ large_entries = PlaintextToEntries.extract_plaintext_entries(large_data)
+
+ # assert
+ assert len(normal_entries) == 2
+ assert len(large_entries) == 2
+
normal_entries = PlaintextToEntries.split_entries_by_max_tokens(
- PlaintextToEntries.extract_plaintext_entries(normal_data),
+ normal_entries[1],
max_tokens=max_tokens,
raw_is_compiled=True,
)
large_entries = PlaintextToEntries.split_entries_by_max_tokens(
- PlaintextToEntries.extract_plaintext_entries(large_data), max_tokens=max_tokens, raw_is_compiled=True
+ large_entries[1], max_tokens=max_tokens, raw_is_compiled=True
)
# Assert