mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 15:38:55 +01:00
Add Ability to Summarize Documents (#800)
* Uses entire file text and summarizer model to generate document summary. * Uses the contents of the user's query to create a tailored summary. * Integrates with File Filters #788 for a better UX.
This commit is contained in:
parent
677d49d438
commit
d4e5c95711
21 changed files with 791 additions and 85 deletions
|
@ -28,6 +28,7 @@ from khoj.database.models import (
|
|||
ClientApplication,
|
||||
Conversation,
|
||||
Entry,
|
||||
FileObject,
|
||||
GithubConfig,
|
||||
GithubRepoConfig,
|
||||
GoogleUser,
|
||||
|
@ -731,7 +732,7 @@ class ConversationAdapters:
|
|||
if server_chat_settings is None or (
|
||||
server_chat_settings.summarizer_model is None and server_chat_settings.default_model is None
|
||||
):
|
||||
return await ChatModelOptions.objects.filter().afirst()
|
||||
return await ChatModelOptions.objects.filter().prefetch_related("openai_config").afirst()
|
||||
return server_chat_settings.summarizer_model or server_chat_settings.default_model
|
||||
|
||||
@staticmethod
|
||||
|
@ -846,6 +847,58 @@ class ConversationAdapters:
|
|||
return await TextToImageModelConfig.objects.filter().afirst()
|
||||
|
||||
|
||||
class FileObjectAdapters:
|
||||
@staticmethod
|
||||
def update_raw_text(file_object: FileObject, new_raw_text: str):
|
||||
file_object.raw_text = new_raw_text
|
||||
file_object.save()
|
||||
|
||||
@staticmethod
|
||||
def create_file_object(user: KhojUser, file_name: str, raw_text: str):
|
||||
return FileObject.objects.create(user=user, file_name=file_name, raw_text=raw_text)
|
||||
|
||||
@staticmethod
|
||||
def get_file_objects_by_name(user: KhojUser, file_name: str):
|
||||
return FileObject.objects.filter(user=user, file_name=file_name).first()
|
||||
|
||||
@staticmethod
|
||||
def get_all_file_objects(user: KhojUser):
|
||||
return FileObject.objects.filter(user=user).all()
|
||||
|
||||
@staticmethod
|
||||
def delete_file_object_by_name(user: KhojUser, file_name: str):
|
||||
return FileObject.objects.filter(user=user, file_name=file_name).delete()
|
||||
|
||||
@staticmethod
|
||||
def delete_all_file_objects(user: KhojUser):
|
||||
return FileObject.objects.filter(user=user).delete()
|
||||
|
||||
@staticmethod
|
||||
async def async_update_raw_text(file_object: FileObject, new_raw_text: str):
|
||||
file_object.raw_text = new_raw_text
|
||||
await file_object.asave()
|
||||
|
||||
@staticmethod
|
||||
async def async_create_file_object(user: KhojUser, file_name: str, raw_text: str):
|
||||
return await FileObject.objects.acreate(user=user, file_name=file_name, raw_text=raw_text)
|
||||
|
||||
@staticmethod
|
||||
async def async_get_file_objects_by_name(user: KhojUser, file_name: str):
|
||||
return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name=file_name))
|
||||
|
||||
@staticmethod
|
||||
async def async_get_all_file_objects(user: KhojUser):
|
||||
return await sync_to_async(list)(FileObject.objects.filter(user=user))
|
||||
|
||||
@staticmethod
|
||||
async def async_delete_file_object_by_name(user: KhojUser, file_name: str):
|
||||
return await FileObject.objects.filter(user=user, file_name=file_name).adelete()
|
||||
|
||||
@staticmethod
|
||||
async def async_delete_all_file_objects(user: KhojUser):
|
||||
return await FileObject.objects.filter(user=user).adelete()
|
||||
|
||||
|
||||
class EntryAdapters:
|
||||
word_filer = WordFilter()
|
||||
file_filter = FileFilter()
|
||||
|
|
37
src/khoj/database/migrations/0045_fileobject.py
Normal file
37
src/khoj/database/migrations/0045_fileobject.py
Normal file
|
@ -0,0 +1,37 @@
|
|||
# Generated by Django 4.2.11 on 2024-06-14 06:13
|
||||
|
||||
import django.db.models.deletion
|
||||
from django.conf import settings
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0044_conversation_file_filters"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name="FileObject",
|
||||
fields=[
|
||||
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||
("updated_at", models.DateTimeField(auto_now=True)),
|
||||
("file_name", models.CharField(blank=True, default=None, max_length=400, null=True)),
|
||||
("raw_text", models.TextField()),
|
||||
(
|
||||
"user",
|
||||
models.ForeignKey(
|
||||
blank=True,
|
||||
default=None,
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
to=settings.AUTH_USER_MODEL,
|
||||
),
|
||||
),
|
||||
],
|
||||
options={
|
||||
"abstract": False,
|
||||
},
|
||||
),
|
||||
]
|
|
@ -326,6 +326,13 @@ class Entry(BaseModel):
|
|||
corpus_id = models.UUIDField(default=uuid.uuid4, editable=False)
|
||||
|
||||
|
||||
class FileObject(BaseModel):
|
||||
# Same as Entry but raw will be a much larger string
|
||||
file_name = models.CharField(max_length=400, default=None, null=True, blank=True)
|
||||
raw_text = models.TextField()
|
||||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
||||
|
||||
|
||||
class EntryDates(BaseModel):
|
||||
date = models.DateField()
|
||||
entry = models.ForeignKey(Entry, on_delete=models.CASCADE, related_name="embeddings_dates")
|
||||
|
|
|
@ -2145,7 +2145,7 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||
<div style="border-top: 1px solid black; ">
|
||||
<div style="display: flex; align-items: center; justify-content: space-between; margin-bottom: 5px; margin-top: 5px;">
|
||||
<p style="margin: 0;">Files</p>
|
||||
<svg id="file-toggle-button" class="file-toggle-button" style="width:20px; height:20px; position: relative; top: 2px" viewBox="0 0 40 40" fill="#000000" xmlns="http://www.w3.org/2000/svg">
|
||||
<svg class="file-toggle-button" style="width:20px; height:20px; position: relative; top: 2px" viewBox="0 0 40 40" fill="#000000" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M16 0c-8.836 0-16 7.163-16 16s7.163 16 16 16c8.837 0 16-7.163 16-16s-7.163-16-16-16zM16 30.032c-7.72 0-14-6.312-14-14.032s6.28-14 14-14 14 6.28 14 14-6.28 14.032-14 14.032zM23 15h-6v-6c0-0.552-0.448-1-1-1s-1 0.448-1 1v6h-6c-0.552 0-1 0.448-1 1s0.448 1 1 1h6v6c0 0.552 0.448 1 1 1s1-0.448 1-1v-6h6c0.552 0 1-0.448 1-1s-0.448-1-1-1z"></path>
|
||||
</svg>
|
||||
</div>
|
||||
|
|
|
@ -33,7 +33,7 @@ class MarkdownToEntries(TextToEntries):
|
|||
max_tokens = 256
|
||||
# Extract Entries from specified Markdown files
|
||||
with timer("Extract entries from specified Markdown files", logger):
|
||||
current_entries = MarkdownToEntries.extract_markdown_entries(files, max_tokens)
|
||||
file_to_text_map, current_entries = MarkdownToEntries.extract_markdown_entries(files, max_tokens)
|
||||
|
||||
# Split entries by max tokens supported by model
|
||||
with timer("Split entries by max token size supported by model", logger):
|
||||
|
@ -50,27 +50,30 @@ class MarkdownToEntries(TextToEntries):
|
|||
deletion_file_names,
|
||||
user,
|
||||
regenerate=regenerate,
|
||||
file_to_text_map=file_to_text_map,
|
||||
)
|
||||
|
||||
return num_new_embeddings, num_deleted_embeddings
|
||||
|
||||
@staticmethod
|
||||
def extract_markdown_entries(markdown_files, max_tokens=256) -> List[Entry]:
|
||||
def extract_markdown_entries(markdown_files, max_tokens=256) -> Tuple[Dict, List[Entry]]:
|
||||
"Extract entries by heading from specified Markdown files"
|
||||
entries: List[str] = []
|
||||
entry_to_file_map: List[Tuple[str, str]] = []
|
||||
file_to_text_map = dict()
|
||||
for markdown_file in markdown_files:
|
||||
try:
|
||||
markdown_content = markdown_files[markdown_file]
|
||||
entries, entry_to_file_map = MarkdownToEntries.process_single_markdown_file(
|
||||
markdown_content, markdown_file, entries, entry_to_file_map, max_tokens
|
||||
)
|
||||
file_to_text_map[markdown_file] = markdown_content
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Unable to process file: {markdown_file}. This file will not be indexed.\n{e}", exc_info=True
|
||||
)
|
||||
|
||||
return MarkdownToEntries.convert_markdown_entries_to_maps(entries, dict(entry_to_file_map))
|
||||
return file_to_text_map, MarkdownToEntries.convert_markdown_entries_to_maps(entries, dict(entry_to_file_map))
|
||||
|
||||
@staticmethod
|
||||
def process_single_markdown_file(
|
||||
|
|
|
@ -33,7 +33,7 @@ class OrgToEntries(TextToEntries):
|
|||
# Extract Entries from specified Org files
|
||||
max_tokens = 256
|
||||
with timer("Extract entries from specified Org files", logger):
|
||||
current_entries = self.extract_org_entries(files, max_tokens=max_tokens)
|
||||
file_to_text_map, current_entries = self.extract_org_entries(files, max_tokens=max_tokens)
|
||||
|
||||
with timer("Split entries by max token size supported by model", logger):
|
||||
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=max_tokens)
|
||||
|
@ -49,6 +49,7 @@ class OrgToEntries(TextToEntries):
|
|||
deletion_file_names,
|
||||
user,
|
||||
regenerate=regenerate,
|
||||
file_to_text_map=file_to_text_map,
|
||||
)
|
||||
|
||||
return num_new_embeddings, num_deleted_embeddings
|
||||
|
@ -56,26 +57,32 @@ class OrgToEntries(TextToEntries):
|
|||
@staticmethod
|
||||
def extract_org_entries(
|
||||
org_files: dict[str, str], index_heading_entries: bool = False, max_tokens=256
|
||||
) -> List[Entry]:
|
||||
) -> Tuple[Dict, List[Entry]]:
|
||||
"Extract entries from specified Org files"
|
||||
entries, entry_to_file_map = OrgToEntries.extract_org_nodes(org_files, max_tokens)
|
||||
return OrgToEntries.convert_org_nodes_to_entries(entries, entry_to_file_map, index_heading_entries)
|
||||
file_to_text_map, entries, entry_to_file_map = OrgToEntries.extract_org_nodes(org_files, max_tokens)
|
||||
return file_to_text_map, OrgToEntries.convert_org_nodes_to_entries(
|
||||
entries, entry_to_file_map, index_heading_entries
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def extract_org_nodes(org_files: dict[str, str], max_tokens) -> Tuple[List[List[Orgnode]], Dict[Orgnode, str]]:
|
||||
def extract_org_nodes(
|
||||
org_files: dict[str, str], max_tokens
|
||||
) -> Tuple[Dict, List[List[Orgnode]], Dict[Orgnode, str]]:
|
||||
"Extract org nodes from specified org files"
|
||||
entries: List[List[Orgnode]] = []
|
||||
entry_to_file_map: List[Tuple[Orgnode, str]] = []
|
||||
file_to_text_map = {}
|
||||
for org_file in org_files:
|
||||
try:
|
||||
org_content = org_files[org_file]
|
||||
entries, entry_to_file_map = OrgToEntries.process_single_org_file(
|
||||
org_content, org_file, entries, entry_to_file_map, max_tokens
|
||||
)
|
||||
file_to_text_map[org_file] = org_content
|
||||
except Exception as e:
|
||||
logger.error(f"Unable to process file: {org_file}. Skipped indexing it.\nError; {e}", exc_info=True)
|
||||
|
||||
return entries, dict(entry_to_file_map)
|
||||
return file_to_text_map, entries, dict(entry_to_file_map)
|
||||
|
||||
@staticmethod
|
||||
def process_single_org_file(
|
||||
|
|
|
@ -2,10 +2,12 @@ import base64
|
|||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import List, Tuple
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from langchain_community.document_loaders import PyMuPDFLoader
|
||||
|
||||
# importing FileObjectAdapter so that we can add new files and debug file object db.
|
||||
# from khoj.database.adapters import FileObjectAdapters
|
||||
from khoj.database.models import Entry as DbEntry
|
||||
from khoj.database.models import KhojUser
|
||||
from khoj.processor.content.text_to_entries import TextToEntries
|
||||
|
@ -33,7 +35,7 @@ class PdfToEntries(TextToEntries):
|
|||
|
||||
# Extract Entries from specified Pdf files
|
||||
with timer("Extract entries from specified PDF files", logger):
|
||||
current_entries = PdfToEntries.extract_pdf_entries(files)
|
||||
file_to_text_map, current_entries = PdfToEntries.extract_pdf_entries(files)
|
||||
|
||||
# Split entries by max tokens supported by model
|
||||
with timer("Split entries by max token size supported by model", logger):
|
||||
|
@ -50,14 +52,15 @@ class PdfToEntries(TextToEntries):
|
|||
deletion_file_names,
|
||||
user,
|
||||
regenerate=regenerate,
|
||||
file_to_text_map=file_to_text_map,
|
||||
)
|
||||
|
||||
return num_new_embeddings, num_deleted_embeddings
|
||||
|
||||
@staticmethod
|
||||
def extract_pdf_entries(pdf_files) -> List[Entry]:
|
||||
def extract_pdf_entries(pdf_files) -> Tuple[Dict, List[Entry]]: # important function
|
||||
"""Extract entries by page from specified PDF files"""
|
||||
|
||||
file_to_text_map = dict()
|
||||
entries: List[str] = []
|
||||
entry_to_location_map: List[Tuple[str, str]] = []
|
||||
for pdf_file in pdf_files:
|
||||
|
@ -73,9 +76,14 @@ class PdfToEntries(TextToEntries):
|
|||
pdf_entries_per_file = [page.page_content for page in loader.load()]
|
||||
except ImportError:
|
||||
loader = PyMuPDFLoader(f"{tmp_file}")
|
||||
pdf_entries_per_file = [page.page_content for page in loader.load()]
|
||||
entry_to_location_map += zip(pdf_entries_per_file, [pdf_file] * len(pdf_entries_per_file))
|
||||
pdf_entries_per_file = [
|
||||
page.page_content for page in loader.load()
|
||||
] # page_content items list for a given pdf.
|
||||
entry_to_location_map += zip(
|
||||
pdf_entries_per_file, [pdf_file] * len(pdf_entries_per_file)
|
||||
) # this is an indexed map of pdf_entries for the pdf.
|
||||
entries.extend(pdf_entries_per_file)
|
||||
file_to_text_map[pdf_file] = pdf_entries_per_file
|
||||
except Exception as e:
|
||||
logger.warning(f"Unable to process file: {pdf_file}. This file will not be indexed.")
|
||||
logger.warning(e, exc_info=True)
|
||||
|
@ -83,7 +91,7 @@ class PdfToEntries(TextToEntries):
|
|||
if os.path.exists(f"{tmp_file}"):
|
||||
os.remove(f"{tmp_file}")
|
||||
|
||||
return PdfToEntries.convert_pdf_entries_to_maps(entries, dict(entry_to_location_map))
|
||||
return file_to_text_map, PdfToEntries.convert_pdf_entries_to_maps(entries, dict(entry_to_location_map))
|
||||
|
||||
@staticmethod
|
||||
def convert_pdf_entries_to_maps(parsed_entries: List[str], entry_to_file_map) -> List[Entry]:
|
||||
|
|
|
@ -32,7 +32,7 @@ class PlaintextToEntries(TextToEntries):
|
|||
|
||||
# Extract Entries from specified plaintext files
|
||||
with timer("Extract entries from specified Plaintext files", logger):
|
||||
current_entries = PlaintextToEntries.extract_plaintext_entries(files)
|
||||
file_to_text_map, current_entries = PlaintextToEntries.extract_plaintext_entries(files)
|
||||
|
||||
# Split entries by max tokens supported by model
|
||||
with timer("Split entries by max token size supported by model", logger):
|
||||
|
@ -49,6 +49,7 @@ class PlaintextToEntries(TextToEntries):
|
|||
deletion_filenames=deletion_file_names,
|
||||
user=user,
|
||||
regenerate=regenerate,
|
||||
file_to_text_map=file_to_text_map,
|
||||
)
|
||||
|
||||
return num_new_embeddings, num_deleted_embeddings
|
||||
|
@ -63,21 +64,23 @@ class PlaintextToEntries(TextToEntries):
|
|||
return soup.get_text(strip=True, separator="\n")
|
||||
|
||||
@staticmethod
|
||||
def extract_plaintext_entries(text_files: Dict[str, str]) -> List[Entry]:
|
||||
def extract_plaintext_entries(text_files: Dict[str, str]) -> Tuple[Dict, List[Entry]]:
|
||||
entries: List[str] = []
|
||||
entry_to_file_map: List[Tuple[str, str]] = []
|
||||
file_to_text_map = dict()
|
||||
for text_file in text_files:
|
||||
try:
|
||||
text_content = text_files[text_file]
|
||||
entries, entry_to_file_map = PlaintextToEntries.process_single_plaintext_file(
|
||||
text_content, text_file, entries, entry_to_file_map
|
||||
)
|
||||
file_to_text_map[text_file] = text_content
|
||||
except Exception as e:
|
||||
logger.warning(f"Unable to read file: {text_file} as plaintext. Skipping file.")
|
||||
logger.warning(e, exc_info=True)
|
||||
|
||||
# Extract Entries from specified plaintext files
|
||||
return PlaintextToEntries.convert_text_files_to_entries(entries, dict(entry_to_file_map))
|
||||
return file_to_text_map, PlaintextToEntries.convert_text_files_to_entries(entries, dict(entry_to_file_map))
|
||||
|
||||
@staticmethod
|
||||
def process_single_plaintext_file(
|
||||
|
|
|
@ -9,7 +9,11 @@ from typing import Any, Callable, List, Set, Tuple
|
|||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from tqdm import tqdm
|
||||
|
||||
from khoj.database.adapters import EntryAdapters, get_user_search_model_or_default
|
||||
from khoj.database.adapters import (
|
||||
EntryAdapters,
|
||||
FileObjectAdapters,
|
||||
get_user_search_model_or_default,
|
||||
)
|
||||
from khoj.database.models import Entry as DbEntry
|
||||
from khoj.database.models import EntryDates, KhojUser
|
||||
from khoj.search_filter.date_filter import DateFilter
|
||||
|
@ -120,6 +124,7 @@ class TextToEntries(ABC):
|
|||
deletion_filenames: Set[str] = None,
|
||||
user: KhojUser = None,
|
||||
regenerate: bool = False,
|
||||
file_to_text_map: dict[str, List[str]] = None,
|
||||
):
|
||||
with timer("Constructed current entry hashes in", logger):
|
||||
hashes_by_file = dict[str, set[str]]()
|
||||
|
@ -186,6 +191,18 @@ class TextToEntries(ABC):
|
|||
logger.error(f"Error adding entries to database:\n{batch_indexing_error}\n---\n{e}", exc_info=True)
|
||||
logger.debug(f"Added {len(added_entries)} {file_type} entries to database")
|
||||
|
||||
if file_to_text_map:
|
||||
# get the list of file_names using added_entries
|
||||
filenames_to_update = [entry.file_path for entry in added_entries]
|
||||
# for each file_name in filenames_to_update, try getting the file object and updating raw_text and if it fails create a new file object
|
||||
for file_name in filenames_to_update:
|
||||
raw_text = " ".join(file_to_text_map[file_name])
|
||||
file_object = FileObjectAdapters.get_file_objects_by_name(user, file_name)
|
||||
if file_object:
|
||||
FileObjectAdapters.update_raw_text(file_object, raw_text)
|
||||
else:
|
||||
FileObjectAdapters.create_file_object(user, file_name, raw_text)
|
||||
|
||||
new_dates = []
|
||||
with timer("Indexed dates from added entries in", logger):
|
||||
for added_entry in added_entries:
|
||||
|
@ -210,6 +227,7 @@ class TextToEntries(ABC):
|
|||
for file_path in deletion_filenames:
|
||||
deleted_count = EntryAdapters.delete_entry_by_file(user, file_path)
|
||||
num_deleted_entries += deleted_count
|
||||
FileObjectAdapters.delete_file_object_by_name(user, file_path)
|
||||
|
||||
return len(added_entries), num_deleted_entries
|
||||
|
||||
|
|
|
@ -321,6 +321,27 @@ Collate only relevant information from the website to answer the target query.
|
|||
""".strip()
|
||||
)
|
||||
|
||||
system_prompt_extract_relevant_summary = """As a professional analyst, create a comprehensive report of the most relevant information from the document in response to a user's query. The text provided is directly from within the document. The report you create should be multiple paragraphs, and it should represent the content of the document. Tell the user exactly what the document says in response to their query, while adhering to these guidelines:
|
||||
|
||||
1. Answer the user's query as specifically as possible. Include many supporting details from the document.
|
||||
2. Craft a report that is detailed, thorough, in-depth, and complex, while maintaining clarity.
|
||||
3. Rely strictly on the provided text, without including external information.
|
||||
4. Format the report in multiple paragraphs with a clear structure.
|
||||
5. Be as specific as possible in your answer to the user's query.
|
||||
6. Reproduce as much of the provided text as possible, while maintaining readability.
|
||||
""".strip()
|
||||
|
||||
extract_relevant_summary = PromptTemplate.from_template(
|
||||
"""
|
||||
Target Query: {query}
|
||||
|
||||
Document Contents:
|
||||
{corpus}
|
||||
|
||||
Collate only relevant information from the document to answer the target query.
|
||||
""".strip()
|
||||
)
|
||||
|
||||
pick_relevant_output_mode = PromptTemplate.from_template(
|
||||
"""
|
||||
You are Khoj, an excellent analyst for selecting the correct way to respond to a user's query. You have access to a limited set of modes for your response. You can only use one of these modes.
|
||||
|
|
|
@ -16,6 +16,7 @@ from websockets import ConnectionClosedOK
|
|||
from khoj.database.adapters import (
|
||||
ConversationAdapters,
|
||||
EntryAdapters,
|
||||
FileObjectAdapters,
|
||||
PublicConversationAdapters,
|
||||
aget_user_name,
|
||||
)
|
||||
|
@ -42,6 +43,7 @@ from khoj.routers.helpers import (
|
|||
aget_relevant_output_modes,
|
||||
construct_automation_created_message,
|
||||
create_automation,
|
||||
extract_relevant_summary,
|
||||
get_conversation_command,
|
||||
is_query_empty,
|
||||
is_ready_to_chat,
|
||||
|
@ -586,6 +588,51 @@ async def websocket_endpoint(
|
|||
await conversation_command_rate_limiter.update_and_check_if_valid(websocket, cmd)
|
||||
q = q.replace(f"/{cmd.value}", "").strip()
|
||||
|
||||
if ConversationCommand.Summarize in conversation_commands:
|
||||
file_filters = conversation.file_filters
|
||||
response_log = ""
|
||||
if len(file_filters) == 0:
|
||||
response_log = "No files selected for summarization. Please add files using the section on the left."
|
||||
await send_complete_llm_response(response_log)
|
||||
elif len(file_filters) > 1:
|
||||
response_log = "Only one file can be selected for summarization."
|
||||
await send_complete_llm_response(response_log)
|
||||
else:
|
||||
try:
|
||||
file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
|
||||
if len(file_object) == 0:
|
||||
response_log = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again."
|
||||
await send_complete_llm_response(response_log)
|
||||
continue
|
||||
contextual_data = " ".join([file.raw_text for file in file_object])
|
||||
if not q:
|
||||
q = "Create a general summary of the file"
|
||||
await send_status_update(f"**🧑🏾💻 Constructing Summary Using:** {file_object[0].file_name}")
|
||||
response = await extract_relevant_summary(q, contextual_data)
|
||||
response_log = str(response)
|
||||
await send_complete_llm_response(response_log)
|
||||
except Exception as e:
|
||||
response_log = "Error summarizing file."
|
||||
logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)
|
||||
await send_complete_llm_response(response_log)
|
||||
await sync_to_async(save_to_conversation_log)(
|
||||
q,
|
||||
response_log,
|
||||
user,
|
||||
meta_log,
|
||||
user_message_time,
|
||||
intent_type="summarize",
|
||||
client_application=websocket.user.client_app,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
update_telemetry_state(
|
||||
request=websocket,
|
||||
telemetry_type="api",
|
||||
api="chat",
|
||||
metadata={"conversation_command": conversation_commands[0].value},
|
||||
)
|
||||
continue
|
||||
|
||||
custom_filters = []
|
||||
if conversation_commands == [ConversationCommand.Help]:
|
||||
if not q:
|
||||
|
@ -828,6 +875,49 @@ async def chat(
|
|||
_custom_filters.append("site:khoj.dev")
|
||||
conversation_commands.append(ConversationCommand.Online)
|
||||
|
||||
conversation = await ConversationAdapters.aget_conversation_by_user(user, conversation_id=conversation_id)
|
||||
conversation_id = conversation.id if conversation else None
|
||||
if ConversationCommand.Summarize in conversation_commands:
|
||||
file_filters = conversation.file_filters
|
||||
llm_response = ""
|
||||
if len(file_filters) == 0:
|
||||
llm_response = "No files selected for summarization. Please add files using the section on the left."
|
||||
elif len(file_filters) > 1:
|
||||
llm_response = "Only one file can be selected for summarization."
|
||||
else:
|
||||
try:
|
||||
file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
|
||||
if len(file_object) == 0:
|
||||
llm_response = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again."
|
||||
return StreamingResponse(content=llm_response, media_type="text/event-stream", status_code=200)
|
||||
contextual_data = " ".join([file.raw_text for file in file_object])
|
||||
summarizeStr = "/" + ConversationCommand.Summarize
|
||||
if q.strip() == summarizeStr:
|
||||
q = "Create a general summary of the file"
|
||||
response = await extract_relevant_summary(q, contextual_data)
|
||||
llm_response = str(response)
|
||||
except Exception as e:
|
||||
logger.error(f"Error summarizing file for {user.email}: {e}")
|
||||
llm_response = "Error summarizing file."
|
||||
await sync_to_async(save_to_conversation_log)(
|
||||
q,
|
||||
llm_response,
|
||||
user,
|
||||
conversation.conversation_log,
|
||||
user_message_time,
|
||||
intent_type="summarize",
|
||||
client_application=request.user.client_app,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
telemetry_type="api",
|
||||
api="chat",
|
||||
metadata={"conversation_command": conversation_commands[0].value},
|
||||
**common.__dict__,
|
||||
)
|
||||
return StreamingResponse(content=llm_response, media_type="text/event-stream", status_code=200)
|
||||
|
||||
conversation = await ConversationAdapters.aget_conversation_by_user(
|
||||
user, request.user.client_app, conversation_id, title
|
||||
)
|
||||
|
|
|
@ -200,6 +200,8 @@ def get_conversation_command(query: str, any_references: bool = False) -> Conver
|
|||
return ConversationCommand.Image
|
||||
elif query.startswith("/automated_task"):
|
||||
return ConversationCommand.AutomatedTask
|
||||
elif query.startswith("/summarize"):
|
||||
return ConversationCommand.Summarize
|
||||
# If no relevant notes found for the given query
|
||||
elif not any_references:
|
||||
return ConversationCommand.General
|
||||
|
@ -418,7 +420,30 @@ async def extract_relevant_info(q: str, corpus: str) -> Union[str, None]:
|
|||
prompts.system_prompt_extract_relevant_information,
|
||||
chat_model_option=summarizer_model,
|
||||
)
|
||||
return response.strip()
|
||||
|
||||
|
||||
async def extract_relevant_summary(q: str, corpus: str) -> Union[str, None]:
|
||||
"""
|
||||
Extract relevant information for a given query from the target corpus
|
||||
"""
|
||||
|
||||
if is_none_or_empty(corpus) or is_none_or_empty(q):
|
||||
return None
|
||||
|
||||
extract_relevant_information = prompts.extract_relevant_summary.format(
|
||||
query=q,
|
||||
corpus=corpus.strip(),
|
||||
)
|
||||
|
||||
summarizer_model: ChatModelOptions = await ConversationAdapters.aget_summarizer_conversation_config()
|
||||
|
||||
with timer("Chat actor: Extract relevant information from data", logger):
|
||||
response = await send_message_to_model_wrapper(
|
||||
extract_relevant_information,
|
||||
prompts.system_prompt_extract_relevant_summary,
|
||||
chat_model_option=summarizer_model,
|
||||
)
|
||||
return response.strip()
|
||||
|
||||
|
||||
|
|
|
@ -307,6 +307,7 @@ class ConversationCommand(str, Enum):
|
|||
Text = "text"
|
||||
Automation = "automation"
|
||||
AutomatedTask = "automated_task"
|
||||
Summarize = "summarize"
|
||||
|
||||
|
||||
command_descriptions = {
|
||||
|
@ -318,6 +319,7 @@ command_descriptions = {
|
|||
ConversationCommand.Image: "Generate images by describing your imagination in words.",
|
||||
ConversationCommand.Automation: "Automatically run your query at a specified time or interval.",
|
||||
ConversationCommand.Help: "Get help with how to use or setup Khoj from the documentation",
|
||||
ConversationCommand.Summarize: "Create an appropriate summary using provided documents.",
|
||||
}
|
||||
|
||||
tool_descriptions_for_llm = {
|
||||
|
@ -326,6 +328,7 @@ tool_descriptions_for_llm = {
|
|||
ConversationCommand.Notes: "To search the user's personal knowledge base. Especially helpful if the question expects context from the user's notes or documents.",
|
||||
ConversationCommand.Online: "To search for the latest, up-to-date information from the internet. Note: **Questions about Khoj should always use this data source**",
|
||||
ConversationCommand.Webpage: "To use if the user has directly provided the webpage urls or you are certain of the webpage urls to read.",
|
||||
ConversationCommand.Summarize: "To create a summary of the document provided by the user.",
|
||||
}
|
||||
|
||||
mode_descriptions_for_llm = {
|
||||
|
|
|
@ -19,6 +19,7 @@ from khoj.database.models import (
|
|||
KhojUser,
|
||||
LocalMarkdownConfig,
|
||||
LocalOrgConfig,
|
||||
LocalPdfConfig,
|
||||
LocalPlaintextConfig,
|
||||
)
|
||||
from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
|
||||
|
@ -415,6 +416,18 @@ def org_config_with_only_new_file(new_org_file: Path, default_user: KhojUser):
|
|||
return LocalOrgConfig.objects.filter(user=default_user).first()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def pdf_configured_user1(default_user: KhojUser):
|
||||
LocalPdfConfig.objects.create(
|
||||
input_files=None,
|
||||
input_filter=["tests/data/pdf/singlepage.pdf"],
|
||||
user=default_user,
|
||||
)
|
||||
# Index Markdown Content for Search
|
||||
all_files = fs_syncer.collect_files(user=default_user)
|
||||
success = configure_content(all_files, user=default_user)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def sample_org_data():
|
||||
return get_sample_data("org")
|
||||
|
|
|
@ -51,7 +51,9 @@ class ChatModelOptionsFactory(factory.django.DjangoModelFactory):
|
|||
tokenizer = None
|
||||
chat_model = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF"
|
||||
model_type = "offline"
|
||||
openai_config = factory.SubFactory(OpenAIProcessorConversationConfigFactory)
|
||||
openai_config = factory.LazyAttribute(
|
||||
lambda obj: OpenAIProcessorConversationConfigFactory() if os.getenv("OPENAI_API_KEY") else None
|
||||
)
|
||||
|
||||
|
||||
class UserConversationProcessorConfigFactory(factory.django.DjangoModelFactory):
|
||||
|
|
|
@ -23,13 +23,14 @@ def test_extract_markdown_with_no_headings(tmp_path):
|
|||
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3)
|
||||
|
||||
# Assert
|
||||
assert len(entries) == 1
|
||||
assert len(entries) == 2
|
||||
assert len(entries[1]) == 1
|
||||
# Ensure raw entry with no headings do not get heading prefix prepended
|
||||
assert not entries[0].raw.startswith("#")
|
||||
assert not entries[1][0].raw.startswith("#")
|
||||
# Ensure compiled entry has filename prepended as top level heading
|
||||
assert entries[0].compiled.startswith(expected_heading)
|
||||
assert entries[1][0].compiled.startswith(expected_heading)
|
||||
# Ensure compiled entry also includes the file name
|
||||
assert str(tmp_path) in entries[0].compiled
|
||||
assert str(tmp_path) in entries[1][0].compiled
|
||||
|
||||
|
||||
def test_extract_single_markdown_entry(tmp_path):
|
||||
|
@ -48,7 +49,8 @@ def test_extract_single_markdown_entry(tmp_path):
|
|||
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3)
|
||||
|
||||
# Assert
|
||||
assert len(entries) == 1
|
||||
assert len(entries) == 2
|
||||
assert len(entries[1]) == 1
|
||||
|
||||
|
||||
def test_extract_multiple_markdown_entries(tmp_path):
|
||||
|
@ -72,8 +74,9 @@ def test_extract_multiple_markdown_entries(tmp_path):
|
|||
|
||||
# Assert
|
||||
assert len(entries) == 2
|
||||
assert len(entries[1]) == 2
|
||||
# Ensure entry compiled strings include the markdown files they originate from
|
||||
assert all([tmp_path.stem in entry.compiled for entry in entries])
|
||||
assert all([tmp_path.stem in entry.compiled for entry in entries[1]])
|
||||
|
||||
|
||||
def test_extract_entries_with_different_level_headings(tmp_path):
|
||||
|
@ -94,8 +97,9 @@ def test_extract_entries_with_different_level_headings(tmp_path):
|
|||
|
||||
# Assert
|
||||
assert len(entries) == 2
|
||||
assert entries[0].raw == "# Heading 1\n## Sub-Heading 1.1", "Ensure entry includes heading ancestory"
|
||||
assert entries[1].raw == "# Heading 2\n"
|
||||
assert len(entries[1]) == 2
|
||||
assert entries[1][0].raw == "# Heading 1\n## Sub-Heading 1.1", "Ensure entry includes heading ancestory"
|
||||
assert entries[1][1].raw == "# Heading 2\n"
|
||||
|
||||
|
||||
def test_extract_entries_with_non_incremental_heading_levels(tmp_path):
|
||||
|
@ -116,10 +120,11 @@ def test_extract_entries_with_non_incremental_heading_levels(tmp_path):
|
|||
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3)
|
||||
|
||||
# Assert
|
||||
assert len(entries) == 3
|
||||
assert entries[0].raw == "# Heading 1\n#### Sub-Heading 1.1", "Ensure entry includes heading ancestory"
|
||||
assert entries[1].raw == "# Heading 1\n## Sub-Heading 1.2", "Ensure entry includes heading ancestory"
|
||||
assert entries[2].raw == "# Heading 2\n"
|
||||
assert len(entries) == 2
|
||||
assert len(entries[1]) == 3
|
||||
assert entries[1][0].raw == "# Heading 1\n#### Sub-Heading 1.1", "Ensure entry includes heading ancestory"
|
||||
assert entries[1][1].raw == "# Heading 1\n## Sub-Heading 1.2", "Ensure entry includes heading ancestory"
|
||||
assert entries[1][2].raw == "# Heading 2\n"
|
||||
|
||||
|
||||
def test_extract_entries_with_text_before_headings(tmp_path):
|
||||
|
@ -141,10 +146,13 @@ body line 2
|
|||
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3)
|
||||
|
||||
# Assert
|
||||
assert len(entries) == 3
|
||||
assert entries[0].raw == "\nText before headings"
|
||||
assert entries[1].raw == "# Heading 1\nbody line 1"
|
||||
assert entries[2].raw == "# Heading 1\n## Heading 2\nbody line 2\n", "Ensure raw entry includes heading ancestory"
|
||||
assert len(entries) == 2
|
||||
assert len(entries[1]) == 3
|
||||
assert entries[1][0].raw == "\nText before headings"
|
||||
assert entries[1][1].raw == "# Heading 1\nbody line 1"
|
||||
assert (
|
||||
entries[1][2].raw == "# Heading 1\n## Heading 2\nbody line 2\n"
|
||||
), "Ensure raw entry includes heading ancestory"
|
||||
|
||||
|
||||
def test_parse_markdown_file_into_single_entry_if_small(tmp_path):
|
||||
|
@ -165,8 +173,9 @@ body line 1.1
|
|||
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=12)
|
||||
|
||||
# Assert
|
||||
assert len(entries) == 1
|
||||
assert entries[0].raw == entry
|
||||
assert len(entries) == 2
|
||||
assert len(entries[1]) == 1
|
||||
assert entries[1][0].raw == entry
|
||||
|
||||
|
||||
def test_parse_markdown_entry_with_children_as_single_entry_if_small(tmp_path):
|
||||
|
@ -191,13 +200,14 @@ longer body line 2.1
|
|||
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=12)
|
||||
|
||||
# Assert
|
||||
assert len(entries) == 3
|
||||
assert len(entries) == 2
|
||||
assert len(entries[1]) == 3
|
||||
assert (
|
||||
entries[0].raw == "# Heading 1\nbody line 1\n## Subheading 1.1\nbody line 1.1"
|
||||
entries[1][0].raw == "# Heading 1\nbody line 1\n## Subheading 1.1\nbody line 1.1"
|
||||
), "First entry includes children headings"
|
||||
assert entries[1].raw == "# Heading 2\nbody line 2", "Second entry does not include children headings"
|
||||
assert entries[1][1].raw == "# Heading 2\nbody line 2", "Second entry does not include children headings"
|
||||
assert (
|
||||
entries[2].raw == "# Heading 2\n## Subheading 2.1\nlonger body line 2.1\n"
|
||||
entries[1][2].raw == "# Heading 2\n## Subheading 2.1\nlonger body line 2.1\n"
|
||||
), "Third entry is second entries child heading"
|
||||
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ import pytest
|
|||
from faker import Faker
|
||||
from freezegun import freeze_time
|
||||
|
||||
from khoj.database.models import Agent, KhojUser
|
||||
from khoj.database.models import Agent, Entry, KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.utils import message_to_log
|
||||
from khoj.routers.helpers import aget_relevant_information_sources
|
||||
|
@ -305,6 +305,200 @@ def test_answer_not_known_using_notes_command(client_offline_chat, default_user2
|
|||
assert response_message == prompts.no_notes_found.format()
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.chatquality
|
||||
def test_summarize_one_file(client_offline_chat, default_user2: KhojUser):
|
||||
message_list = []
|
||||
conversation = create_conversation(message_list, default_user2)
|
||||
# post "Xi Li.markdown" file to the file filters
|
||||
file_list = (
|
||||
Entry.objects.filter(user=default_user2, file_source="computer")
|
||||
.distinct("file_path")
|
||||
.values_list("file_path", flat=True)
|
||||
)
|
||||
# pick the file that has "Xi Li.markdown" in the name
|
||||
summarization_file = ""
|
||||
for file in file_list:
|
||||
if "Birthday Gift for Xiu turning 4.markdown" in file:
|
||||
summarization_file = file
|
||||
break
|
||||
assert summarization_file != ""
|
||||
|
||||
response = client_offline_chat.post(
|
||||
"api/chat/conversation/file-filters",
|
||||
json={"filename": summarization_file, "conversation_id": str(conversation.id)},
|
||||
)
|
||||
query = urllib.parse.quote("/summarize")
|
||||
response = client_offline_chat.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
|
||||
response_message = response.content.decode("utf-8")
|
||||
# Assert
|
||||
assert response_message != ""
|
||||
assert response_message != "No files selected for summarization. Please add files using the section on the left."
|
||||
assert response_message != "Only one file can be selected for summarization."
|
||||
|
||||
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.chatquality
|
||||
def test_summarize_extra_text(client_offline_chat, default_user2: KhojUser):
|
||||
message_list = []
|
||||
conversation = create_conversation(message_list, default_user2)
|
||||
# post "Xi Li.markdown" file to the file filters
|
||||
file_list = (
|
||||
Entry.objects.filter(user=default_user2, file_source="computer")
|
||||
.distinct("file_path")
|
||||
.values_list("file_path", flat=True)
|
||||
)
|
||||
# pick the file that has "Xi Li.markdown" in the name
|
||||
summarization_file = ""
|
||||
for file in file_list:
|
||||
if "Birthday Gift for Xiu turning 4.markdown" in file:
|
||||
summarization_file = file
|
||||
break
|
||||
assert summarization_file != ""
|
||||
|
||||
response = client_offline_chat.post(
|
||||
"api/chat/conversation/file-filters",
|
||||
json={"filename": summarization_file, "conversation_id": str(conversation.id)},
|
||||
)
|
||||
query = urllib.parse.quote("/summarize tell me about Xiu")
|
||||
response = client_offline_chat.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
|
||||
response_message = response.content.decode("utf-8")
|
||||
# Assert
|
||||
assert response_message != ""
|
||||
assert response_message != "No files selected for summarization. Please add files using the section on the left."
|
||||
assert response_message != "Only one file can be selected for summarization."
|
||||
|
||||
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.chatquality
|
||||
def test_summarize_multiple_files(client_offline_chat, default_user2: KhojUser):
|
||||
message_list = []
|
||||
conversation = create_conversation(message_list, default_user2)
|
||||
# post "Xi Li.markdown" file to the file filters
|
||||
file_list = (
|
||||
Entry.objects.filter(user=default_user2, file_source="computer")
|
||||
.distinct("file_path")
|
||||
.values_list("file_path", flat=True)
|
||||
)
|
||||
|
||||
response = client_offline_chat.post(
|
||||
"api/chat/conversation/file-filters", json={"filename": file_list[0], "conversation_id": str(conversation.id)}
|
||||
)
|
||||
response = client_offline_chat.post(
|
||||
"api/chat/conversation/file-filters", json={"filename": file_list[1], "conversation_id": str(conversation.id)}
|
||||
)
|
||||
|
||||
query = urllib.parse.quote("/summarize")
|
||||
response = client_offline_chat.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
assert response_message == "Only one file can be selected for summarization."
|
||||
|
||||
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.chatquality
|
||||
def test_summarize_no_files(client_offline_chat, default_user2: KhojUser):
|
||||
message_list = []
|
||||
conversation = create_conversation(message_list, default_user2)
|
||||
query = urllib.parse.quote("/summarize")
|
||||
response = client_offline_chat.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
|
||||
response_message = response.content.decode("utf-8")
|
||||
# Assert
|
||||
assert response_message == "No files selected for summarization. Please add files using the section on the left."
|
||||
|
||||
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.chatquality
|
||||
def test_summarize_different_conversation(client_offline_chat, default_user2: KhojUser):
|
||||
message_list = []
|
||||
conversation1 = create_conversation(message_list, default_user2)
|
||||
conversation2 = create_conversation(message_list, default_user2)
|
||||
|
||||
file_list = (
|
||||
Entry.objects.filter(user=default_user2, file_source="computer")
|
||||
.distinct("file_path")
|
||||
.values_list("file_path", flat=True)
|
||||
)
|
||||
summarization_file = ""
|
||||
for file in file_list:
|
||||
if "Birthday Gift for Xiu turning 4.markdown" in file:
|
||||
summarization_file = file
|
||||
break
|
||||
assert summarization_file != ""
|
||||
|
||||
# add file filter to conversation 1.
|
||||
response = client_offline_chat.post(
|
||||
"api/chat/conversation/file-filters",
|
||||
json={"filename": summarization_file, "conversation_id": str(conversation1.id)},
|
||||
)
|
||||
|
||||
query = urllib.parse.quote("/summarize")
|
||||
response = client_offline_chat.get(f"/api/chat?q={query}&conversation_id={conversation2.id}&stream=true")
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
assert response_message == "No files selected for summarization. Please add files using the section on the left."
|
||||
|
||||
# now make sure that the file filter is still in conversation 1
|
||||
response = client_offline_chat.get(f"/api/chat?q={query}&conversation_id={conversation1.id}&stream=true")
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
assert response_message != ""
|
||||
assert response_message != "No files selected for summarization. Please add files using the section on the left."
|
||||
assert response_message != "Only one file can be selected for summarization."
|
||||
|
||||
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.chatquality
|
||||
def test_summarize_nonexistant_file(client_offline_chat, default_user2: KhojUser):
|
||||
message_list = []
|
||||
conversation = create_conversation(message_list, default_user2)
|
||||
# post "imaginary.markdown" file to the file filters
|
||||
response = client_offline_chat.post(
|
||||
"api/chat/conversation/file-filters",
|
||||
json={"filename": "imaginary.markdown", "conversation_id": str(conversation.id)},
|
||||
)
|
||||
query = urllib.parse.quote("/summarize")
|
||||
response = client_offline_chat.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
|
||||
response_message = response.content.decode("utf-8")
|
||||
# Assert
|
||||
assert response_message == "No files selected for summarization. Please add files using the section on the left."
|
||||
|
||||
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.chatquality
|
||||
def test_summarize_diff_user_file(
|
||||
client_offline_chat, default_user: KhojUser, pdf_configured_user1, default_user2: KhojUser
|
||||
):
|
||||
message_list = []
|
||||
conversation = create_conversation(message_list, default_user2)
|
||||
# Get the pdf file called singlepage.pdf
|
||||
file_list = (
|
||||
Entry.objects.filter(user=default_user, file_source="computer")
|
||||
.distinct("file_path")
|
||||
.values_list("file_path", flat=True)
|
||||
)
|
||||
summarization_file = ""
|
||||
for file in file_list:
|
||||
if "singlepage.pdf" in file:
|
||||
summarization_file = file
|
||||
break
|
||||
assert summarization_file != ""
|
||||
# add singlepage.pdf to the file filters
|
||||
response = client_offline_chat.post(
|
||||
"api/chat/conversation/file-filters",
|
||||
json={"filename": summarization_file, "conversation_id": str(conversation.id)},
|
||||
)
|
||||
query = urllib.parse.quote("/summarize")
|
||||
response = client_offline_chat.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
|
||||
response_message = response.content.decode("utf-8")
|
||||
# Assert
|
||||
assert response_message == "No files selected for summarization. Please add files using the section on the left."
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
|
@ -522,7 +716,7 @@ async def test_get_correct_tools_online(client_offline_chat):
|
|||
user_query = "What's the weather in Patagonia this week?"
|
||||
|
||||
# Act
|
||||
tools = await aget_relevant_information_sources(user_query, {})
|
||||
tools = await aget_relevant_information_sources(user_query, {}, is_task=False)
|
||||
|
||||
# Assert
|
||||
tools = [tool.value for tool in tools]
|
||||
|
@ -537,7 +731,7 @@ async def test_get_correct_tools_notes(client_offline_chat):
|
|||
user_query = "Where did I go for my first battleship training?"
|
||||
|
||||
# Act
|
||||
tools = await aget_relevant_information_sources(user_query, {})
|
||||
tools = await aget_relevant_information_sources(user_query, {}, is_task=False)
|
||||
|
||||
# Assert
|
||||
tools = [tool.value for tool in tools]
|
||||
|
@ -552,7 +746,7 @@ async def test_get_correct_tools_online_or_general_and_notes(client_offline_chat
|
|||
user_query = "What's the highest point in Patagonia and have I been there?"
|
||||
|
||||
# Act
|
||||
tools = await aget_relevant_information_sources(user_query, {})
|
||||
tools = await aget_relevant_information_sources(user_query, {}, is_task=False)
|
||||
|
||||
# Assert
|
||||
tools = [tool.value for tool in tools]
|
||||
|
@ -569,7 +763,7 @@ async def test_get_correct_tools_general(client_offline_chat):
|
|||
user_query = "How many noble gases are there?"
|
||||
|
||||
# Act
|
||||
tools = await aget_relevant_information_sources(user_query, {})
|
||||
tools = await aget_relevant_information_sources(user_query, {}, is_task=False)
|
||||
|
||||
# Assert
|
||||
tools = [tool.value for tool in tools]
|
||||
|
@ -593,7 +787,7 @@ async def test_get_correct_tools_with_chat_history(client_offline_chat, default_
|
|||
chat_history = create_conversation(chat_log, default_user2)
|
||||
|
||||
# Act
|
||||
tools = await aget_relevant_information_sources(user_query, chat_history)
|
||||
tools = await aget_relevant_information_sources(user_query, chat_history, is_task=False)
|
||||
|
||||
# Assert
|
||||
tools = [tool.value for tool in tools]
|
||||
|
|
|
@ -5,7 +5,7 @@ from urllib.parse import quote
|
|||
import pytest
|
||||
from freezegun import freeze_time
|
||||
|
||||
from khoj.database.models import Agent, KhojUser
|
||||
from khoj.database.models import Agent, Entry, KhojUser, LocalPdfConfig
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.utils import message_to_log
|
||||
from khoj.routers.helpers import aget_relevant_information_sources
|
||||
|
@ -289,6 +289,198 @@ def test_answer_not_known_using_notes_command(chat_client_no_background, default
|
|||
assert response_message == prompts.no_entries_found.format()
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.chatquality
|
||||
def test_summarize_one_file(chat_client, default_user2: KhojUser):
|
||||
message_list = []
|
||||
conversation = create_conversation(message_list, default_user2)
|
||||
# post "Xi Li.markdown" file to the file filters
|
||||
file_list = (
|
||||
Entry.objects.filter(user=default_user2, file_source="computer")
|
||||
.distinct("file_path")
|
||||
.values_list("file_path", flat=True)
|
||||
)
|
||||
# pick the file that has "Xi Li.markdown" in the name
|
||||
summarization_file = ""
|
||||
for file in file_list:
|
||||
if "Birthday Gift for Xiu turning 4.markdown" in file:
|
||||
summarization_file = file
|
||||
break
|
||||
assert summarization_file != ""
|
||||
|
||||
response = chat_client.post(
|
||||
"api/chat/conversation/file-filters",
|
||||
json={"filename": summarization_file, "conversation_id": str(conversation.id)},
|
||||
)
|
||||
query = urllib.parse.quote("/summarize")
|
||||
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
|
||||
response_message = response.content.decode("utf-8")
|
||||
# Assert
|
||||
assert response_message != ""
|
||||
assert response_message != "No files selected for summarization. Please add files using the section on the left."
|
||||
assert response_message != "Only one file can be selected for summarization."
|
||||
|
||||
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.chatquality
|
||||
def test_summarize_extra_text(chat_client, default_user2: KhojUser):
|
||||
message_list = []
|
||||
conversation = create_conversation(message_list, default_user2)
|
||||
# post "Xi Li.markdown" file to the file filters
|
||||
file_list = (
|
||||
Entry.objects.filter(user=default_user2, file_source="computer")
|
||||
.distinct("file_path")
|
||||
.values_list("file_path", flat=True)
|
||||
)
|
||||
# pick the file that has "Xi Li.markdown" in the name
|
||||
summarization_file = ""
|
||||
for file in file_list:
|
||||
if "Birthday Gift for Xiu turning 4.markdown" in file:
|
||||
summarization_file = file
|
||||
break
|
||||
assert summarization_file != ""
|
||||
|
||||
response = chat_client.post(
|
||||
"api/chat/conversation/file-filters",
|
||||
json={"filename": summarization_file, "conversation_id": str(conversation.id)},
|
||||
)
|
||||
query = urllib.parse.quote("/summarize tell me about Xiu")
|
||||
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
|
||||
response_message = response.content.decode("utf-8")
|
||||
# Assert
|
||||
assert response_message != ""
|
||||
assert response_message != "No files selected for summarization. Please add files using the section on the left."
|
||||
assert response_message != "Only one file can be selected for summarization."
|
||||
|
||||
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.chatquality
|
||||
def test_summarize_multiple_files(chat_client, default_user2: KhojUser):
|
||||
message_list = []
|
||||
conversation = create_conversation(message_list, default_user2)
|
||||
# post "Xi Li.markdown" file to the file filters
|
||||
file_list = (
|
||||
Entry.objects.filter(user=default_user2, file_source="computer")
|
||||
.distinct("file_path")
|
||||
.values_list("file_path", flat=True)
|
||||
)
|
||||
|
||||
response = chat_client.post(
|
||||
"api/chat/conversation/file-filters", json={"filename": file_list[0], "conversation_id": str(conversation.id)}
|
||||
)
|
||||
response = chat_client.post(
|
||||
"api/chat/conversation/file-filters", json={"filename": file_list[1], "conversation_id": str(conversation.id)}
|
||||
)
|
||||
|
||||
query = urllib.parse.quote("/summarize")
|
||||
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
assert response_message == "Only one file can be selected for summarization."
|
||||
|
||||
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.chatquality
|
||||
def test_summarize_no_files(chat_client, default_user2: KhojUser):
|
||||
message_list = []
|
||||
conversation = create_conversation(message_list, default_user2)
|
||||
query = urllib.parse.quote("/summarize")
|
||||
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
|
||||
response_message = response.content.decode("utf-8")
|
||||
# Assert
|
||||
assert response_message == "No files selected for summarization. Please add files using the section on the left."
|
||||
|
||||
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.chatquality
|
||||
def test_summarize_different_conversation(chat_client, default_user2: KhojUser):
|
||||
message_list = []
|
||||
conversation1 = create_conversation(message_list, default_user2)
|
||||
conversation2 = create_conversation(message_list, default_user2)
|
||||
|
||||
file_list = (
|
||||
Entry.objects.filter(user=default_user2, file_source="computer")
|
||||
.distinct("file_path")
|
||||
.values_list("file_path", flat=True)
|
||||
)
|
||||
summarization_file = ""
|
||||
for file in file_list:
|
||||
if "Birthday Gift for Xiu turning 4.markdown" in file:
|
||||
summarization_file = file
|
||||
break
|
||||
assert summarization_file != ""
|
||||
|
||||
# add file filter to conversation 1.
|
||||
response = chat_client.post(
|
||||
"api/chat/conversation/file-filters",
|
||||
json={"filename": summarization_file, "conversation_id": str(conversation1.id)},
|
||||
)
|
||||
|
||||
query = urllib.parse.quote("/summarize")
|
||||
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation2.id}&stream=true")
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
assert response_message == "No files selected for summarization. Please add files using the section on the left."
|
||||
|
||||
# now make sure that the file filter is still in conversation 1
|
||||
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation1.id}&stream=true")
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
assert response_message != ""
|
||||
assert response_message != "No files selected for summarization. Please add files using the section on the left."
|
||||
assert response_message != "Only one file can be selected for summarization."
|
||||
|
||||
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.chatquality
|
||||
def test_summarize_nonexistant_file(chat_client, default_user2: KhojUser):
|
||||
message_list = []
|
||||
conversation = create_conversation(message_list, default_user2)
|
||||
# post "imaginary.markdown" file to the file filters
|
||||
response = chat_client.post(
|
||||
"api/chat/conversation/file-filters",
|
||||
json={"filename": "imaginary.markdown", "conversation_id": str(conversation.id)},
|
||||
)
|
||||
query = urllib.parse.quote("/summarize")
|
||||
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
|
||||
response_message = response.content.decode("utf-8")
|
||||
# Assert
|
||||
assert response_message == "No files selected for summarization. Please add files using the section on the left."
|
||||
|
||||
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.chatquality
|
||||
def test_summarize_diff_user_file(chat_client, default_user: KhojUser, pdf_configured_user1, default_user2: KhojUser):
|
||||
message_list = []
|
||||
conversation = create_conversation(message_list, default_user2)
|
||||
# Get the pdf file called singlepage.pdf
|
||||
file_list = (
|
||||
Entry.objects.filter(user=default_user, file_source="computer")
|
||||
.distinct("file_path")
|
||||
.values_list("file_path", flat=True)
|
||||
)
|
||||
summarization_file = ""
|
||||
for file in file_list:
|
||||
if "singlepage.pdf" in file:
|
||||
summarization_file = file
|
||||
break
|
||||
assert summarization_file != ""
|
||||
# add singlepage.pdf to the file filters
|
||||
response = chat_client.post(
|
||||
"api/chat/conversation/file-filters",
|
||||
json={"filename": summarization_file, "conversation_id": str(conversation.id)},
|
||||
)
|
||||
query = urllib.parse.quote("/summarize")
|
||||
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
|
||||
response_message = response.content.decode("utf-8")
|
||||
# Assert
|
||||
assert response_message == "No files selected for summarization. Please add files using the section on the left."
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering time aware questions yet")
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
|
|
|
@ -33,10 +33,12 @@ def test_configure_indexing_heading_only_entries(tmp_path):
|
|||
# Assert
|
||||
if index_heading_entries:
|
||||
# Entry with empty body indexed when index_heading_entries set to True
|
||||
assert len(entries) == 1
|
||||
assert len(entries) == 2
|
||||
assert len(entries[1]) == 1
|
||||
else:
|
||||
# Entry with empty body ignored when index_heading_entries set to False
|
||||
assert is_none_or_empty(entries)
|
||||
assert len(entries) == 2
|
||||
assert is_none_or_empty(entries[1])
|
||||
|
||||
|
||||
def test_entry_split_when_exceeds_max_tokens():
|
||||
|
@ -55,9 +57,9 @@ def test_entry_split_when_exceeds_max_tokens():
|
|||
# Act
|
||||
# Extract Entries from specified Org files
|
||||
entries = OrgToEntries.extract_org_entries(org_files=data)
|
||||
|
||||
assert len(entries) == 2
|
||||
# Split each entry from specified Org files by max tokens
|
||||
entries = TextToEntries.split_entries_by_max_tokens(entries, max_tokens=5)
|
||||
entries = TextToEntries.split_entries_by_max_tokens(entries[1], max_tokens=5)
|
||||
|
||||
# Assert
|
||||
assert len(entries) == 2
|
||||
|
@ -114,11 +116,12 @@ body line 1.1
|
|||
# Act
|
||||
# Extract Entries from specified Org files
|
||||
extracted_entries = OrgToEntries.extract_org_entries(org_files=data, max_tokens=12)
|
||||
for entry in extracted_entries:
|
||||
assert len(extracted_entries) == 2
|
||||
for entry in extracted_entries[1]:
|
||||
entry.raw = clean(entry.raw)
|
||||
|
||||
# Assert
|
||||
assert len(extracted_entries) == 1
|
||||
assert len(extracted_entries[1]) == 1
|
||||
assert entry.raw == expected_entry
|
||||
|
||||
|
||||
|
@ -165,10 +168,11 @@ longer body line 2.1
|
|||
extracted_entries = OrgToEntries.extract_org_entries(org_files=data, max_tokens=12)
|
||||
|
||||
# Assert
|
||||
assert len(extracted_entries) == 3
|
||||
assert extracted_entries[0].compiled == first_expected_entry, "First entry includes children headings"
|
||||
assert extracted_entries[1].compiled == second_expected_entry, "Second entry does not include children headings"
|
||||
assert extracted_entries[2].compiled == third_expected_entry, "Third entry is second entries child heading"
|
||||
assert len(extracted_entries) == 2
|
||||
assert len(extracted_entries[1]) == 3
|
||||
assert extracted_entries[1][0].compiled == first_expected_entry, "First entry includes children headings"
|
||||
assert extracted_entries[1][1].compiled == second_expected_entry, "Second entry does not include children headings"
|
||||
assert extracted_entries[1][2].compiled == third_expected_entry, "Third entry is second entries child heading"
|
||||
|
||||
|
||||
def test_separate_sibling_org_entries_if_all_cannot_fit_in_token_limit(tmp_path):
|
||||
|
@ -226,10 +230,11 @@ body line 3.1
|
|||
extracted_entries = OrgToEntries.extract_org_entries(org_files=data, max_tokens=30)
|
||||
|
||||
# Assert
|
||||
assert len(extracted_entries) == 3
|
||||
assert extracted_entries[0].compiled == first_expected_entry, "First entry includes children headings"
|
||||
assert extracted_entries[1].compiled == second_expected_entry, "Second entry includes children headings"
|
||||
assert extracted_entries[2].compiled == third_expected_entry, "Third entry includes children headings"
|
||||
assert len(extracted_entries) == 2
|
||||
assert len(extracted_entries[1]) == 3
|
||||
assert extracted_entries[1][0].compiled == first_expected_entry, "First entry includes children headings"
|
||||
assert extracted_entries[1][1].compiled == second_expected_entry, "Second entry includes children headings"
|
||||
assert extracted_entries[1][2].compiled == third_expected_entry, "Third entry includes children headings"
|
||||
|
||||
|
||||
def test_entry_with_body_to_entry(tmp_path):
|
||||
|
@ -251,7 +256,8 @@ def test_entry_with_body_to_entry(tmp_path):
|
|||
entries = OrgToEntries.extract_org_entries(org_files=data, max_tokens=3)
|
||||
|
||||
# Assert
|
||||
assert len(entries) == 1
|
||||
assert len(entries) == 2
|
||||
assert len(entries[1]) == 1
|
||||
|
||||
|
||||
def test_file_with_entry_after_intro_text_to_entry(tmp_path):
|
||||
|
@ -273,6 +279,7 @@ Intro text
|
|||
|
||||
# Assert
|
||||
assert len(entries) == 2
|
||||
assert len(entries[1]) == 2
|
||||
|
||||
|
||||
def test_file_with_no_headings_to_entry(tmp_path):
|
||||
|
@ -291,7 +298,8 @@ def test_file_with_no_headings_to_entry(tmp_path):
|
|||
entries = OrgToEntries.extract_org_entries(org_files=data, max_tokens=3)
|
||||
|
||||
# Assert
|
||||
assert len(entries) == 1
|
||||
assert len(entries) == 2
|
||||
assert len(entries[1]) == 1
|
||||
|
||||
|
||||
def test_get_org_files(tmp_path):
|
||||
|
@ -349,13 +357,14 @@ def test_extract_entries_with_different_level_headings(tmp_path):
|
|||
# Act
|
||||
# Extract Entries from specified Org files
|
||||
entries = OrgToEntries.extract_org_entries(org_files=data, index_heading_entries=True, max_tokens=3)
|
||||
for entry in entries:
|
||||
assert len(entries) == 2
|
||||
for entry in entries[1]:
|
||||
entry.raw = clean(f"{entry.raw}")
|
||||
|
||||
# Assert
|
||||
assert len(entries) == 2
|
||||
assert entries[0].raw == "* Heading 1\n** Sub-Heading 1.1\n", "Ensure entry includes heading ancestory"
|
||||
assert entries[1].raw == "* Heading 2\n"
|
||||
assert len(entries[1]) == 2
|
||||
assert entries[1][0].raw == "* Heading 1\n** Sub-Heading 1.1\n", "Ensure entry includes heading ancestory"
|
||||
assert entries[1][1].raw == "* Heading 2\n"
|
||||
|
||||
|
||||
# Helper Functions
|
||||
|
|
|
@ -17,7 +17,8 @@ def test_single_page_pdf_to_jsonl():
|
|||
entries = PdfToEntries.extract_pdf_entries(pdf_files=data)
|
||||
|
||||
# Assert
|
||||
assert len(entries) == 1
|
||||
assert len(entries) == 2
|
||||
assert len(entries[1]) == 1
|
||||
|
||||
|
||||
def test_multi_page_pdf_to_jsonl():
|
||||
|
@ -31,7 +32,8 @@ def test_multi_page_pdf_to_jsonl():
|
|||
entries = PdfToEntries.extract_pdf_entries(pdf_files=data)
|
||||
|
||||
# Assert
|
||||
assert len(entries) == 6
|
||||
assert len(entries) == 2
|
||||
assert len(entries[1]) == 6
|
||||
|
||||
|
||||
def test_ocr_page_pdf_to_jsonl():
|
||||
|
@ -43,9 +45,9 @@ def test_ocr_page_pdf_to_jsonl():
|
|||
|
||||
data = {"tests/data/pdf/ocr_samples.pdf": pdf_bytes}
|
||||
entries = PdfToEntries.extract_pdf_entries(pdf_files=data)
|
||||
|
||||
assert len(entries) == 1
|
||||
assert "playing on a strip of marsh" in entries[0].raw
|
||||
assert len(entries) == 2
|
||||
assert len(entries[1]) == 1
|
||||
assert "playing on a strip of marsh" in entries[1][0].raw
|
||||
|
||||
|
||||
def test_get_pdf_files(tmp_path):
|
||||
|
|
|
@ -25,15 +25,16 @@ def test_plaintext_file(tmp_path):
|
|||
entries = PlaintextToEntries.extract_plaintext_entries(data)
|
||||
|
||||
# Convert each entry.file to absolute path to make them JSON serializable
|
||||
for entry in entries:
|
||||
for entry in entries[1]:
|
||||
entry.file = str(Path(entry.file).absolute())
|
||||
|
||||
# Assert
|
||||
assert len(entries) == 1
|
||||
assert len(entries) == 2
|
||||
assert len(entries[1]) == 1
|
||||
# Ensure raw entry with no headings do not get heading prefix prepended
|
||||
assert not entries[0].raw.startswith("#")
|
||||
assert not entries[1][0].raw.startswith("#")
|
||||
# Ensure compiled entry has filename prepended as top level heading
|
||||
assert entries[0].compiled == f"{plaintextfile}\n{raw_entry}"
|
||||
assert entries[1][0].compiled == f"{plaintextfile}\n{raw_entry}"
|
||||
|
||||
|
||||
def test_get_plaintext_files(tmp_path):
|
||||
|
@ -94,8 +95,9 @@ def test_parse_html_plaintext_file(content_config, default_user: KhojUser):
|
|||
entries = PlaintextToEntries.extract_plaintext_entries(extracted_plaintext_files)
|
||||
|
||||
# Assert
|
||||
assert len(entries) == 1
|
||||
assert "<div>" not in entries[0].raw
|
||||
assert len(entries) == 2
|
||||
assert len(entries[1]) == 1
|
||||
assert "<div>" not in entries[1][0].raw
|
||||
|
||||
|
||||
def test_large_plaintext_file_split_into_multiple_entries(tmp_path):
|
||||
|
@ -113,13 +115,20 @@ def test_large_plaintext_file_split_into_multiple_entries(tmp_path):
|
|||
|
||||
# Act
|
||||
# Extract Entries from specified plaintext files
|
||||
normal_entries = PlaintextToEntries.extract_plaintext_entries(normal_data)
|
||||
large_entries = PlaintextToEntries.extract_plaintext_entries(large_data)
|
||||
|
||||
# assert
|
||||
assert len(normal_entries) == 2
|
||||
assert len(large_entries) == 2
|
||||
|
||||
normal_entries = PlaintextToEntries.split_entries_by_max_tokens(
|
||||
PlaintextToEntries.extract_plaintext_entries(normal_data),
|
||||
normal_entries[1],
|
||||
max_tokens=max_tokens,
|
||||
raw_is_compiled=True,
|
||||
)
|
||||
large_entries = PlaintextToEntries.split_entries_by_max_tokens(
|
||||
PlaintextToEntries.extract_plaintext_entries(large_data), max_tokens=max_tokens, raw_is_compiled=True
|
||||
large_entries[1], max_tokens=max_tokens, raw_is_compiled=True
|
||||
)
|
||||
|
||||
# Assert
|
||||
|
|
Loading…
Reference in a new issue