Update Text Chunking Strategy to Improve Search Context (#645)

## Major
- Parse markdown, org parent entries as a single entry if they fit within the max token limit
- Parse a whole file as a single entry if it fits within the max token limit
- Add parent heading ancestry to extracted markdown entries for context
- Chunk text in preference order of paragraph, sentence, word, character (see the sketch below)

## Minor
- Create wrapper function to get entries from org, md, pdf & text files
- Remove unused Entry-to-JSONL converters from the text-to-entries classes and tests
- Dedupe code by using a single function to process an org file into entries

Resolves #620
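
The chunking preference order above maps onto langchain's `RecursiveCharacterTextSplitter`, configured as in the `text_to_entries.py` hunk further down. A minimal standalone sketch, not part of this change, with a plain whitespace tokenizer standing in for `TextToEntries.tokenizer` and an illustrative sample text:

```python
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Whitespace tokenizer, standing in for TextToEntries.tokenizer
def tokenizer(text: str) -> list:
    return text.split()

# Illustrative sample text (not from the change): two short paragraphs
sample_text = "First paragraph, sentence one. Sentence two.\n\nSecond paragraph with a few more words."

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=8,  # max tokens per chunk (256 in the actual change)
    chunk_overlap=0,
    # Chunking preference order: paragraph > line > sentence > word > character
    separators=["\n\n", "\n", "!", "?", ".", " ", "\t", ""],
    keep_separator=True,
    length_function=lambda chunk: len(tokenizer(chunk)),
)

# With an 8 token budget the sample splits cleanly at the paragraph boundary
for chunk in text_splitter.split_text(sample_text):
    print(repr(chunk))
```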
Debanjum 2024-04-08 13:56:38 +05:30 committed by GitHub
commit 11ce3e2268
15 changed files with 704 additions and 393 deletions

View file

@ -1,14 +1,13 @@
import logging
import re
from pathlib import Path
from typing import List, Tuple
from typing import Dict, List, Tuple
import urllib3
from khoj.database.models import Entry as DbEntry
from khoj.database.models import KhojUser
from khoj.processor.content.text_to_entries import TextToEntries
from khoj.utils.constants import empty_escape_sequences
from khoj.utils.helpers import timer
from khoj.utils.rawconfig import Entry
@ -31,15 +30,14 @@ class MarkdownToEntries(TextToEntries):
else:
deletion_file_names = None
max_tokens = 256
# Extract Entries from specified Markdown files
with timer("Parse entries from Markdown files into dictionaries", logger):
current_entries = MarkdownToEntries.convert_markdown_entries_to_maps(
*MarkdownToEntries.extract_markdown_entries(files)
)
with timer("Extract entries from specified Markdown files", logger):
current_entries = MarkdownToEntries.extract_markdown_entries(files, max_tokens)
# Split entries by max tokens supported by model
with timer("Split entries by max token size supported by model", logger):
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens)
# Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger):
@ -57,48 +55,84 @@ class MarkdownToEntries(TextToEntries):
return num_new_embeddings, num_deleted_embeddings
@staticmethod
def extract_markdown_entries(markdown_files):
def extract_markdown_entries(markdown_files, max_tokens=256) -> List[Entry]:
"Extract entries by heading from specified Markdown files"
# Regex to extract Markdown Entries by Heading
entries = []
entry_to_file_map = []
entries: List[str] = []
entry_to_file_map: List[Tuple[str, str]] = []
for markdown_file in markdown_files:
try:
markdown_content = markdown_files[markdown_file]
entries, entry_to_file_map = MarkdownToEntries.process_single_markdown_file(
markdown_content, markdown_file, entries, entry_to_file_map
markdown_content, markdown_file, entries, entry_to_file_map, max_tokens
)
except Exception as e:
logger.warning(f"Unable to process file: {markdown_file}. This file will not be indexed.")
logger.warning(e, exc_info=True)
logger.error(
f"Unable to process file: {markdown_file}. This file will not be indexed.\n{e}", exc_info=True
)
return entries, dict(entry_to_file_map)
return MarkdownToEntries.convert_markdown_entries_to_maps(entries, dict(entry_to_file_map))
@staticmethod
def process_single_markdown_file(
markdown_content: str, markdown_file: Path, entries: List, entry_to_file_map: List
):
markdown_heading_regex = r"^#"
markdown_content: str,
markdown_file: str,
entries: List[str],
entry_to_file_map: List[Tuple[str, str]],
max_tokens=256,
ancestry: Dict[int, str] = {},
) -> Tuple[List[str], List[Tuple[str, str]]]:
# Prepend the markdown section's heading ancestry
ancestry_string = "\n".join([f"{'#' * key} {ancestry[key]}" for key in sorted(ancestry.keys())])
markdown_content_with_ancestry = f"{ancestry_string}{markdown_content}"
markdown_entries_per_file = []
any_headings = re.search(markdown_heading_regex, markdown_content, flags=re.MULTILINE)
for entry in re.split(markdown_heading_regex, markdown_content, flags=re.MULTILINE):
# Add heading level as the regex split removed it from entries with headings
prefix = "#" if entry.startswith("#") else "# " if any_headings else ""
stripped_entry = entry.strip(empty_escape_sequences)
if stripped_entry != "":
markdown_entries_per_file.append(f"{prefix}{stripped_entry}")
# If content is small or has no child headings, save it as a single entry
if len(TextToEntries.tokenizer(markdown_content_with_ancestry)) <= max_tokens or not re.search(
rf"^#{{{len(ancestry)+1},}}\s", markdown_content, flags=re.MULTILINE
):
entry_to_file_map += [(markdown_content_with_ancestry, markdown_file)]
entries.extend([markdown_content_with_ancestry])
return entries, entry_to_file_map
# Split by next heading level present in the entry
next_heading_level = len(ancestry)
sections: List[str] = []
while len(sections) < 2:
next_heading_level += 1
sections = re.split(rf"(\n|^)(?=[#]{{{next_heading_level}}} .+\n?)", markdown_content, flags=re.MULTILINE)
for section in sections:
# Skip empty sections
if section.strip() == "":
continue
# Extract the section body and (when present) the heading
current_ancestry = ancestry.copy()
first_line = [line for line in section.split("\n") if line.strip() != ""][0]
if re.search(rf"^#{{{next_heading_level}}} ", first_line):
# Extract the section body without the heading
current_section_body = "\n".join(section.split(first_line)[1:])
# Parse the section heading into current section ancestry
current_section_title = first_line[next_heading_level:].strip()
current_ancestry[next_heading_level] = current_section_title
else:
current_section_body = section
# Recurse down children of the current entry
MarkdownToEntries.process_single_markdown_file(
current_section_body,
markdown_file,
entries,
entry_to_file_map,
max_tokens,
current_ancestry,
)
entry_to_file_map += zip(markdown_entries_per_file, [markdown_file] * len(markdown_entries_per_file))
entries.extend(markdown_entries_per_file)
return entries, entry_to_file_map
@staticmethod
def convert_markdown_entries_to_maps(parsed_entries: List[str], entry_to_file_map) -> List[Entry]:
"Convert each Markdown entry into a dictionary"
entries = []
entries: List[Entry] = []
for parsed_entry in parsed_entries:
raw_filename = entry_to_file_map[parsed_entry]
@ -108,13 +142,12 @@ class MarkdownToEntries(TextToEntries):
entry_filename = urllib3.util.parse_url(raw_filename).url
else:
entry_filename = str(Path(raw_filename))
stem = Path(raw_filename).stem
heading = parsed_entry.splitlines()[0] if re.search("^#+\s", parsed_entry) else ""
# Append base filename to compiled entry for context to model
# Increment heading level for heading entries and make filename as its top level heading
prefix = f"# {stem}\n#" if heading else f"# {stem}\n"
compiled_entry = f"{entry_filename}\n{prefix}{parsed_entry}"
prefix = f"# {entry_filename}\n#" if heading else f"# {entry_filename}\n"
compiled_entry = f"{prefix}{parsed_entry}"
entries.append(
Entry(
compiled=compiled_entry,
@ -127,8 +160,3 @@ class MarkdownToEntries(TextToEntries):
logger.debug(f"Converted {len(parsed_entries)} markdown entries to dictionaries")
return entries
@staticmethod
def convert_markdown_maps_to_jsonl(entries: List[Entry]):
"Convert each Markdown entry to JSON and collate as JSONL"
return "".join([f"{entry.to_json()}\n" for entry in entries])
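
Not part of the diff: a small sketch of the lookahead split that `process_single_markdown_file` uses to break a section on its next heading level while keeping each heading attached to its body. The sample markdown below is illustrative:

```python
import re

# Illustrative sample; the parser recurses with increasing heading levels until a split succeeds
markdown_content = "# Heading 1\nbody line 1\n## Sub-Heading 1.1\nbody line 1.1\n# Heading 2\nbody line 2\n"
next_heading_level = 1

# Lookahead split: the matched newline is consumed, the heading stays with its section
sections = re.split(
    rf"(\n|^)(?=[#]{{{next_heading_level}}} .+\n?)", markdown_content, flags=re.MULTILINE
)

# Empty strings and bare newlines captured by the split group are skipped, as in the parser
print([section for section in sections if section.strip()])
# ['# Heading 1\nbody line 1\n## Sub-Heading 1.1\nbody line 1.1', '# Heading 2\nbody line 2\n']
```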

View file

@ -1,10 +1,12 @@
import logging
import re
from pathlib import Path
from typing import Iterable, List, Tuple
from typing import Dict, List, Tuple
from khoj.database.models import Entry as DbEntry
from khoj.database.models import KhojUser
from khoj.processor.content.org_mode import orgnode
from khoj.processor.content.org_mode.orgnode import Orgnode
from khoj.processor.content.text_to_entries import TextToEntries
from khoj.utils import state
from khoj.utils.helpers import timer
@ -21,9 +23,6 @@ class OrgToEntries(TextToEntries):
def process(
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
) -> Tuple[int, int]:
# Extract required fields from config
index_heading_entries = False
if not full_corpus:
deletion_file_names = set([file for file in files if files[file] == ""])
files_to_process = set(files) - deletion_file_names
@ -32,14 +31,12 @@ class OrgToEntries(TextToEntries):
deletion_file_names = None
# Extract Entries from specified Org files
with timer("Parse entries from org files into OrgNode objects", logger):
entry_nodes, file_to_entries = self.extract_org_entries(files)
with timer("Convert OrgNodes into list of entries", logger):
current_entries = self.convert_org_nodes_to_entries(entry_nodes, file_to_entries, index_heading_entries)
max_tokens = 256
with timer("Extract entries from specified Org files", logger):
current_entries = self.extract_org_entries(files, max_tokens=max_tokens)
with timer("Split entries by max token size supported by model", logger):
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=max_tokens)
# Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger):
@ -57,93 +54,165 @@ class OrgToEntries(TextToEntries):
return num_new_embeddings, num_deleted_embeddings
@staticmethod
def extract_org_entries(org_files: dict[str, str]):
def extract_org_entries(
org_files: dict[str, str], index_heading_entries: bool = False, max_tokens=256
) -> List[Entry]:
"Extract entries from specified Org files"
entries = []
entry_to_file_map: List[Tuple[orgnode.Orgnode, str]] = []
entries, entry_to_file_map = OrgToEntries.extract_org_nodes(org_files, max_tokens)
return OrgToEntries.convert_org_nodes_to_entries(entries, entry_to_file_map, index_heading_entries)
@staticmethod
def extract_org_nodes(org_files: dict[str, str], max_tokens) -> Tuple[List[List[Orgnode]], Dict[Orgnode, str]]:
"Extract org nodes from specified org files"
entries: List[List[Orgnode]] = []
entry_to_file_map: List[Tuple[Orgnode, str]] = []
for org_file in org_files:
filename = org_file
file = org_files[org_file]
try:
org_file_entries = orgnode.makelist(file, filename)
entry_to_file_map += zip(org_file_entries, [org_file] * len(org_file_entries))
entries.extend(org_file_entries)
org_content = org_files[org_file]
entries, entry_to_file_map = OrgToEntries.process_single_org_file(
org_content, org_file, entries, entry_to_file_map, max_tokens
)
except Exception as e:
logger.warning(f"Unable to process file: {org_file}. This file will not be indexed.")
logger.warning(e, exc_info=True)
logger.error(f"Unable to process file: {org_file}. Skipped indexing it.\nError: {e}", exc_info=True)
return entries, dict(entry_to_file_map)
@staticmethod
def process_single_org_file(org_content: str, org_file: str, entries: List, entry_to_file_map: List):
# Process single org file. The org parser assumes that the file is a single org file and reads it from a buffer. We'll split the raw content of this file by new line to mimic the same behavior.
try:
org_file_entries = orgnode.makelist(org_content, org_file)
entry_to_file_map += zip(org_file_entries, [org_file] * len(org_file_entries))
entries.extend(org_file_entries)
return entries, entry_to_file_map
except Exception as e:
logger.error(f"Error processing file: {org_file} with error: {e}", exc_info=True)
def process_single_org_file(
org_content: str,
org_file: str,
entries: List[List[Orgnode]],
entry_to_file_map: List[Tuple[Orgnode, str]],
max_tokens=256,
ancestry: Dict[int, str] = {},
) -> Tuple[List[List[Orgnode]], List[Tuple[Orgnode, str]]]:
"""Parse org_content from org_file into OrgNode entries
Recurse down org file entries, one heading level at a time,
until we reach a leaf entry or the current entry tree fits within max_tokens.
Parse each recursion-terminating entry (tree) into (a list of) OrgNode objects.
"""
# Prepend the org section's heading ancestry
ancestry_string = "\n".join([f"{'*' * key} {ancestry[key]}" for key in sorted(ancestry.keys())])
org_content_with_ancestry = f"{ancestry_string}{org_content}"
# If content is small or has no child headings, save it as a single entry
# Note: This is the terminating condition for this recursive function
if len(TextToEntries.tokenizer(org_content_with_ancestry)) <= max_tokens or not re.search(
rf"^\*{{{len(ancestry)+1},}}\s", org_content, re.MULTILINE
):
orgnode_content_with_ancestry = orgnode.makelist(org_content_with_ancestry, org_file)
entry_to_file_map += zip(orgnode_content_with_ancestry, [org_file] * len(orgnode_content_with_ancestry))
entries.extend([orgnode_content_with_ancestry])
return entries, entry_to_file_map
# Split this entry tree into sections by the next heading level in it
# Increment heading level until able to split entry into sections
# A successful split will result in at least 2 sections
next_heading_level = len(ancestry)
sections: List[str] = []
while len(sections) < 2:
next_heading_level += 1
sections = re.split(rf"(\n|^)(?=[*]{{{next_heading_level}}} .+\n?)", org_content, flags=re.MULTILINE)
# Recurse down each non-empty section after parsing its body, heading and ancestry
for section in sections:
# Skip empty sections
if section.strip() == "":
continue
# Extract the section body and (when present) the heading
current_ancestry = ancestry.copy()
first_non_empty_line = [line for line in section.split("\n") if line.strip() != ""][0]
# If first non-empty line is a heading with expected heading level
if re.search(rf"^\*{{{next_heading_level}}}\s", first_non_empty_line):
# Extract the section body without the heading
current_section_body = "\n".join(section.split(first_non_empty_line)[1:])
# Parse the section heading into current section ancestry
current_section_title = first_non_empty_line[next_heading_level:].strip()
current_ancestry[next_heading_level] = current_section_title
# Else process the section as just body text
else:
current_section_body = section
# Recurse down children of the current entry
OrgToEntries.process_single_org_file(
current_section_body,
org_file,
entries,
entry_to_file_map,
max_tokens,
current_ancestry,
)
return entries, entry_to_file_map
@staticmethod
def convert_org_nodes_to_entries(
parsed_entries: List[orgnode.Orgnode], entry_to_file_map, index_heading_entries=False
parsed_entries: List[List[Orgnode]],
entry_to_file_map: Dict[Orgnode, str],
index_heading_entries: bool = False,
) -> List[Entry]:
"Convert Org-Mode nodes into list of Entry objects"
"""
Convert OrgNode lists into list of Entry objects
Each list of OrgNodes is a parsed parent org tree or leaf node.
Convert each list of these OrgNodes into a single Entry.
"""
entries: List[Entry] = []
for parsed_entry in parsed_entries:
if not parsed_entry.hasBody and not index_heading_entries:
# Ignore title notes i.e notes with just headings and empty body
continue
for entry_group in parsed_entries:
entry_heading, entry_compiled, entry_raw = "", "", ""
for parsed_entry in entry_group:
if not parsed_entry.hasBody and not index_heading_entries:
# Ignore title notes i.e notes with just headings and empty body
continue
todo_str = f"{parsed_entry.todo} " if parsed_entry.todo else ""
todo_str = f"{parsed_entry.todo} " if parsed_entry.todo else ""
# Prepend ancestor headings, filename as top heading to entry for context
ancestors_trail = " / ".join(parsed_entry.ancestors) or Path(entry_to_file_map[parsed_entry])
if parsed_entry.heading:
heading = f"* Path: {ancestors_trail}\n** {todo_str}{parsed_entry.heading}."
else:
heading = f"* Path: {ancestors_trail}."
# Set base level to current org-node tree's root heading level
if not entry_heading and parsed_entry.level > 0:
base_level = parsed_entry.level
# Indent entry by 1 heading level as ancestry is prepended as top level heading
heading = f"{'*' * (parsed_entry.level-base_level+2)} {todo_str}" if parsed_entry.level > 0 else ""
if parsed_entry.heading:
heading += f"{parsed_entry.heading}."
compiled = heading
if state.verbose > 2:
logger.debug(f"Title: {heading}")
# Prepend ancestor headings, filename as top heading to root parent entry for context
# Child nodes do not need the ancestors trail as the root parent node will have it
if not entry_heading:
ancestors_trail = " / ".join(parsed_entry.ancestors) or Path(entry_to_file_map[parsed_entry])
heading = f"* Path: {ancestors_trail}\n{heading}" if heading else f"* Path: {ancestors_trail}."
if parsed_entry.tags:
tags_str = " ".join(parsed_entry.tags)
compiled += f"\t {tags_str}."
if state.verbose > 2:
logger.debug(f"Tags: {tags_str}")
compiled = heading
if parsed_entry.closed:
compiled += f'\n Closed on {parsed_entry.closed.strftime("%Y-%m-%d")}.'
if state.verbose > 2:
logger.debug(f'Closed: {parsed_entry.closed.strftime("%Y-%m-%d")}')
if parsed_entry.tags:
tags_str = " ".join(parsed_entry.tags)
compiled += f"\t {tags_str}."
if parsed_entry.scheduled:
compiled += f'\n Scheduled for {parsed_entry.scheduled.strftime("%Y-%m-%d")}.'
if state.verbose > 2:
logger.debug(f'Scheduled: {parsed_entry.scheduled.strftime("%Y-%m-%d")}')
if parsed_entry.closed:
compiled += f'\n Closed on {parsed_entry.closed.strftime("%Y-%m-%d")}.'
if parsed_entry.hasBody:
compiled += f"\n {parsed_entry.body}"
if state.verbose > 2:
logger.debug(f"Body: {parsed_entry.body}")
if parsed_entry.scheduled:
compiled += f'\n Scheduled for {parsed_entry.scheduled.strftime("%Y-%m-%d")}.'
if compiled:
if parsed_entry.hasBody:
compiled += f"\n {parsed_entry.body}"
# Add the sub-entry contents to the entry
entry_compiled += f"{compiled}"
entry_raw += f"{parsed_entry}"
if not entry_heading:
entry_heading = heading
if entry_compiled:
entries.append(
Entry(
compiled=compiled,
raw=f"{parsed_entry}",
heading=f"{heading}",
compiled=entry_compiled,
raw=entry_raw,
heading=f"{entry_heading}",
file=f"{entry_to_file_map[parsed_entry]}",
)
)
return entries
@staticmethod
def convert_org_entries_to_jsonl(entries: Iterable[Entry]) -> str:
"Convert each Org-Mode entry to JSON and collate as JSONL"
return "".join([f"{entry_dict.to_json()}\n" for entry_dict in entries])
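
For reference, a tiny standalone sketch (not in the diff) of the heading-ancestry prefix and one of the recursion-terminating checks that `process_single_org_file` applies; the ancestry dict and body text here are hypothetical:

```python
import re

# Hypothetical ancestry accumulated while recursing down an org file
ancestry = {1: "Heading 2", 2: "Subheading 2.1"}
org_content = "longer body line 2.1\n"

# Heading ancestry prefix prepended for context, as in process_single_org_file
ancestry_string = "\n".join(f"{'*' * level} {ancestry[level]}" for level in sorted(ancestry))
print(ancestry_string)
# * Heading 2
# ** Subheading 2.1

# One of the two terminating conditions: no child headings deeper than the current ancestry
has_deeper_headings = bool(re.search(rf"^\*{{{len(ancestry)+1},}}\s", org_content, re.MULTILINE))
print(has_deeper_headings)  # False, so this subtree would be parsed as a single entry
```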

View file

@ -37,7 +37,7 @@ import datetime
import re
from os.path import relpath
from pathlib import Path
from typing import List
from typing import Dict, List, Tuple
indent_regex = re.compile(r"^ *")
@ -58,7 +58,7 @@ def makelist_with_filepath(filename):
return makelist(f, filename)
def makelist(file, filename):
def makelist(file, filename) -> List["Orgnode"]:
"""
Read an org-mode file and return a list of Orgnode objects
created from this file.
@ -80,16 +80,16 @@ def makelist(file, filename):
} # populated from #+SEQ_TODO line
level = ""
heading = ""
ancestor_headings = []
ancestor_headings: List[str] = []
bodytext = ""
introtext = ""
tags = list() # set of all tags in headline
closed_date = ""
sched_date = ""
deadline_date = ""
logbook = list()
tags: List[str] = list() # set of all tags in headline
closed_date: datetime.date = None
sched_date: datetime.date = None
deadline_date: datetime.date = None
logbook: List[Tuple[datetime.datetime, datetime.datetime]] = list()
nodelist: List[Orgnode] = list()
property_map = dict()
property_map: Dict[str, str] = dict()
in_properties_drawer = False
in_logbook_drawer = False
file_title = f"{filename}"
@ -102,13 +102,13 @@ def makelist(file, filename):
thisNode = Orgnode(level, heading, bodytext, tags, ancestor_headings)
if closed_date:
thisNode.closed = closed_date
closed_date = ""
closed_date = None
if sched_date:
thisNode.scheduled = sched_date
sched_date = ""
sched_date = None
if deadline_date:
thisNode.deadline = deadline_date
deadline_date = ""
deadline_date = None
if logbook:
thisNode.logbook = logbook
logbook = list()
@ -116,7 +116,7 @@ def makelist(file, filename):
nodelist.append(thisNode)
property_map = {"LINE": f"file:{normalize_filename(filename)}::{ctr}"}
previous_level = level
previous_heading = heading
previous_heading: str = heading
level = heading_search.group(1)
heading = heading_search.group(2)
bodytext = ""
@ -495,12 +495,13 @@ class Orgnode(object):
if self._priority:
n = n + "[#" + self._priority + "] "
n = n + self._heading
n = "%-60s " % n # hack - tags will start in column 62
closecolon = ""
for t in self._tags:
n = n + ":" + t
closecolon = ":"
n = n + closecolon
if self._tags:
n = "%-60s " % n # hack - tags will start in column 62
closecolon = ""
for t in self._tags:
n = n + ":" + t
closecolon = ":"
n = n + closecolon
n = n + "\n"
# Get body indentation from first line of body

View file

@ -32,8 +32,8 @@ class PdfToEntries(TextToEntries):
deletion_file_names = None
# Extract Entries from specified Pdf files
with timer("Parse entries from PDF files into dictionaries", logger):
current_entries = PdfToEntries.convert_pdf_entries_to_maps(*PdfToEntries.extract_pdf_entries(files))
with timer("Extract entries from specified PDF files", logger):
current_entries = PdfToEntries.extract_pdf_entries(files)
# Split entries by max tokens supported by model
with timer("Split entries by max token size supported by model", logger):
@ -55,11 +55,11 @@ class PdfToEntries(TextToEntries):
return num_new_embeddings, num_deleted_embeddings
@staticmethod
def extract_pdf_entries(pdf_files):
def extract_pdf_entries(pdf_files) -> List[Entry]:
"""Extract entries by page from specified PDF files"""
entries = []
entry_to_location_map = []
entries: List[str] = []
entry_to_location_map: List[Tuple[str, str]] = []
for pdf_file in pdf_files:
try:
# Write the PDF file to a temporary file, as it is stored in byte format in the pdf_file object and the PDF Loader expects a file path
@ -83,7 +83,7 @@ class PdfToEntries(TextToEntries):
if os.path.exists(f"{tmp_file}"):
os.remove(f"{tmp_file}")
return entries, dict(entry_to_location_map)
return PdfToEntries.convert_pdf_entries_to_maps(entries, dict(entry_to_location_map))
@staticmethod
def convert_pdf_entries_to_maps(parsed_entries: List[str], entry_to_file_map) -> List[Entry]:
@ -106,8 +106,3 @@ class PdfToEntries(TextToEntries):
logger.debug(f"Converted {len(parsed_entries)} PDF entries to dictionaries")
return entries
@staticmethod
def convert_pdf_maps_to_jsonl(entries: List[Entry]):
"Convert each PDF entry to JSON and collate as JSONL"
return "".join([f"{entry.to_json()}\n" for entry in entries])

View file

@ -42,8 +42,8 @@ class PlaintextToEntries(TextToEntries):
logger.warning(e, exc_info=True)
# Extract Entries from specified plaintext files
with timer("Parse entries from plaintext files", logger):
current_entries = PlaintextToEntries.convert_plaintext_entries_to_maps(files)
with timer("Parse entries from specified Plaintext files", logger):
current_entries = PlaintextToEntries.extract_plaintext_entries(files)
# Split entries by max tokens supported by model
with timer("Split entries by max token size supported by model", logger):
@ -74,7 +74,7 @@ class PlaintextToEntries(TextToEntries):
return soup.get_text(strip=True, separator="\n")
@staticmethod
def convert_plaintext_entries_to_maps(entry_to_file_map: dict) -> List[Entry]:
def extract_plaintext_entries(entry_to_file_map: dict[str, str]) -> List[Entry]:
"Convert each plaintext entry into a dictionary"
entries = []
for file, entry in entry_to_file_map.items():
@ -87,8 +87,3 @@ class PlaintextToEntries(TextToEntries):
)
)
return entries
@staticmethod
def convert_entries_to_jsonl(entries: List[Entry]):
"Convert each entry to JSON and collate as JSONL"
return "".join([f"{entry.to_json()}\n" for entry in entries])

View file

@ -1,10 +1,12 @@
import hashlib
import logging
import re
import uuid
from abc import ABC, abstractmethod
from itertools import repeat
from typing import Any, Callable, List, Set, Tuple
from langchain.text_splitter import RecursiveCharacterTextSplitter
from tqdm import tqdm
from khoj.database.adapters import EntryAdapters, get_user_search_model_or_default
@ -34,6 +36,27 @@ class TextToEntries(ABC):
def hash_func(key: str) -> Callable:
return lambda entry: hashlib.md5(bytes(getattr(entry, key), encoding="utf-8")).hexdigest()
@staticmethod
def remove_long_words(text: str, max_word_length: int = 500) -> str:
"Remove words longer than max_word_length from text."
# Split the string by words, keeping the delimiters
splits = re.split(r"(\s+)", text) + [""]
words_with_delimiters = list(zip(splits[::2], splits[1::2]))
# Filter out long words while preserving delimiters in text
filtered_text = [
f"{word}{delimiter}"
for word, delimiter in words_with_delimiters
if not word.strip() or len(word.strip()) <= max_word_length
]
return "".join(filtered_text)
@staticmethod
def tokenizer(text: str) -> List[str]:
"Tokenize text into words."
return text.split()
@staticmethod
def split_entries_by_max_tokens(
entries: List[Entry], max_tokens: int = 256, max_word_length: int = 500
@ -44,24 +67,30 @@ class TextToEntries(ABC):
if is_none_or_empty(entry.compiled):
continue
# Split entry into words
compiled_entry_words = [word for word in entry.compiled.split(" ") if word != ""]
# Drop long words instead of having entry truncated to maintain quality of entry processed by models
compiled_entry_words = [word for word in compiled_entry_words if len(word) <= max_word_length]
# Split entry into chunks of max_tokens
# Use chunking preference order: paragraphs > sentences > words > characters
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=max_tokens,
separators=["\n\n", "\n", "!", "?", ".", " ", "\t", ""],
keep_separator=True,
length_function=lambda chunk: len(TextToEntries.tokenizer(chunk)),
chunk_overlap=0,
)
chunked_entry_chunks = text_splitter.split_text(entry.compiled)
corpus_id = uuid.uuid4()
# Split entry into chunks of max tokens
for chunk_index in range(0, len(compiled_entry_words), max_tokens):
compiled_entry_words_chunk = compiled_entry_words[chunk_index : chunk_index + max_tokens]
compiled_entry_chunk = " ".join(compiled_entry_words_chunk)
# Create heading prefixed entry from each chunk
for chunk_index, compiled_entry_chunk in enumerate(chunked_entry_chunks):
# Prepend heading to all other chunks, the first chunk already has heading from original entry
if chunk_index > 0:
if chunk_index > 0 and entry.heading:
# Snip heading to avoid crossing max_tokens limit
# Keep last 100 characters of heading as the entry heading is more important than the filename
snipped_heading = entry.heading[-100:]
compiled_entry_chunk = f"{snipped_heading}.\n{compiled_entry_chunk}"
# Prepend snipped heading
compiled_entry_chunk = f"{snipped_heading}\n{compiled_entry_chunk}"
# Drop long words instead of having entry truncated to maintain quality of entry processed by models
compiled_entry_chunk = TextToEntries.remove_long_words(compiled_entry_chunk, max_word_length)
# Clean entry of unwanted characters like \0 character
compiled_entry_chunk = TextToEntries.clean_field(compiled_entry_chunk)
@ -160,7 +189,7 @@ class TextToEntries(ABC):
new_dates = []
with timer("Indexed dates from added entries in", logger):
for added_entry in added_entries:
dates_in_entries = zip(self.date_filter.extract_dates(added_entry.raw), repeat(added_entry))
dates_in_entries = zip(self.date_filter.extract_dates(added_entry.compiled), repeat(added_entry))
dates_to_create = [
EntryDates(date=date, entry=added_entry)
for date, added_entry in dates_in_entries
@ -244,11 +273,6 @@ class TextToEntries(ABC):
return entries_with_ids
@staticmethod
def convert_text_maps_to_jsonl(entries: List[Entry]) -> str:
# Convert each entry to JSON and write to JSONL file
return "".join([f"{entry.to_json()}\n" for entry in entries])
@staticmethod
def clean_field(field: str) -> str:
return field.replace("\0", "") if not is_none_or_empty(field) else ""
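
A short usage sketch (not part of the diff) of the new `remove_long_words` behavior; the helper is re-declared locally so the example stands alone, mirroring the logic added to `TextToEntries` above:

```python
import re

def remove_long_words(text: str, max_word_length: int = 500) -> str:
    "Drop words longer than max_word_length while preserving surrounding whitespace."
    # Split the string by whitespace runs, keeping the delimiters
    splits = re.split(r"(\s+)", text) + [""]
    words_with_delimiters = list(zip(splits[::2], splits[1::2]))
    # Filter out long words while preserving delimiters in text
    filtered_text = [
        f"{word}{delimiter}"
        for word, delimiter in words_with_delimiters
        if not word.strip() or len(word.strip()) <= max_word_length
    ]
    return "".join(filtered_text)

# Newlines count as word boundaries, so oversized "words" on their own lines are dropped too (#620)
print(remove_long_words("First Line\ndog=1\ncat=10\ncar=4\nbook=2", max_word_length=5))
# First Line
# dog=1
# car=4
```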

View file

@ -489,7 +489,7 @@ async def chat(
common: CommonQueryParams,
q: str,
n: Optional[int] = 5,
d: Optional[float] = 0.18,
d: Optional[float] = 0.22,
stream: Optional[bool] = False,
title: Optional[str] = None,
conversation_id: Optional[int] = None,

View file

@ -306,7 +306,7 @@ def test_notes_search(client, search_config: SearchConfig, sample_org_data, defa
user_query = quote("How to git install application?")
# Act
response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true&max_distance=0.18", headers=headers)
response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true&max_distance=0.22", headers=headers)
# Assert
assert response.status_code == 200
@ -325,7 +325,7 @@ def test_notes_search_no_results(client, search_config: SearchConfig, sample_org
user_query = quote("How to find my goat?")
# Act
response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true&max_distance=0.18", headers=headers)
response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true&max_distance=0.22", headers=headers)
# Assert
assert response.status_code == 200
@ -409,7 +409,7 @@ def test_notes_search_requires_parent_context(
user_query = quote("Install Khoj on Emacs")
# Act
response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true&max_distance=0.18", headers=headers)
response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true&max_distance=0.22", headers=headers)
# Assert
assert response.status_code == 200

View file

@ -1,4 +1,3 @@
import json
import os
from pathlib import Path
@ -7,8 +6,8 @@ from khoj.utils.fs_syncer import get_markdown_files
from khoj.utils.rawconfig import TextContentConfig
def test_markdown_file_with_no_headings_to_jsonl(tmp_path):
"Convert files with no heading to jsonl."
def test_extract_markdown_with_no_headings(tmp_path):
"Convert markdown file with no heading to entry format."
# Arrange
entry = f"""
- Bullet point 1
@ -17,30 +16,24 @@ def test_markdown_file_with_no_headings_to_jsonl(tmp_path):
data = {
f"{tmp_path}": entry,
}
expected_heading = f"# {tmp_path.stem}"
expected_heading = f"# {tmp_path}"
# Act
# Extract Entries from specified Markdown files
entry_nodes, file_to_entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data)
# Process Each Entry from All Notes Files
jsonl_string = MarkdownToEntries.convert_markdown_maps_to_jsonl(
MarkdownToEntries.convert_markdown_entries_to_maps(entry_nodes, file_to_entries)
)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3)
# Assert
assert len(jsonl_data) == 1
assert len(entries) == 1
# Ensure raw entry with no headings do not get heading prefix prepended
assert not jsonl_data[0]["raw"].startswith("#")
assert not entries[0].raw.startswith("#")
# Ensure compiled entry has filename prepended as top level heading
assert expected_heading in jsonl_data[0]["compiled"]
assert entries[0].compiled.startswith(expected_heading)
# Ensure compiled entry also includes the file name
assert str(tmp_path) in jsonl_data[0]["compiled"]
assert str(tmp_path) in entries[0].compiled
def test_single_markdown_entry_to_jsonl(tmp_path):
"Convert markdown entry from single file to jsonl."
def test_extract_single_markdown_entry(tmp_path):
"Convert markdown from single file to entry format."
# Arrange
entry = f"""### Heading
\t\r
@ -52,20 +45,14 @@ def test_single_markdown_entry_to_jsonl(tmp_path):
# Act
# Extract Entries from specified Markdown files
entries, entry_to_file_map = MarkdownToEntries.extract_markdown_entries(markdown_files=data)
# Process Each Entry from All Notes Files
jsonl_string = MarkdownToEntries.convert_markdown_maps_to_jsonl(
MarkdownToEntries.convert_markdown_entries_to_maps(entries, entry_to_file_map)
)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3)
# Assert
assert len(jsonl_data) == 1
assert len(entries) == 1
def test_multiple_markdown_entries_to_jsonl(tmp_path):
"Convert multiple markdown entries from single file to jsonl."
def test_extract_multiple_markdown_entries(tmp_path):
"Convert multiple markdown entries from a single file to entry format."
# Arrange
entry = f"""
### Heading 1
@ -81,19 +68,139 @@ def test_multiple_markdown_entries_to_jsonl(tmp_path):
# Act
# Extract Entries from specified Markdown files
entry_strings, entry_to_file_map = MarkdownToEntries.extract_markdown_entries(markdown_files=data)
entries = MarkdownToEntries.convert_markdown_entries_to_maps(entry_strings, entry_to_file_map)
# Process Each Entry from All Notes Files
jsonl_string = MarkdownToEntries.convert_markdown_maps_to_jsonl(entries)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3)
# Assert
assert len(jsonl_data) == 2
assert len(entries) == 2
# Ensure entry compiled strings include the markdown files they originate from
assert all([tmp_path.stem in entry.compiled for entry in entries])
def test_extract_entries_with_different_level_headings(tmp_path):
"Extract markdown entries with different level headings."
# Arrange
entry = f"""
# Heading 1
## Sub-Heading 1.1
# Heading 2
"""
data = {
f"{tmp_path}": entry,
}
# Act
# Extract Entries from specified Markdown files
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3)
# Assert
assert len(entries) == 2
assert entries[0].raw == "# Heading 1\n## Sub-Heading 1.1", "Ensure entry includes heading ancestry"
assert entries[1].raw == "# Heading 2\n"
def test_extract_entries_with_non_incremental_heading_levels(tmp_path):
"Extract markdown entries when a deeper child heading level appears before a shallower one."
# Arrange
entry = f"""
# Heading 1
#### Sub-Heading 1.1
## Sub-Heading 1.2
# Heading 2
"""
data = {
f"{tmp_path}": entry,
}
# Act
# Extract Entries from specified Markdown files
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3)
# Assert
assert len(entries) == 3
assert entries[0].raw == "# Heading 1\n#### Sub-Heading 1.1", "Ensure entry includes heading ancestry"
assert entries[1].raw == "# Heading 1\n## Sub-Heading 1.2", "Ensure entry includes heading ancestry"
assert entries[2].raw == "# Heading 2\n"
def test_extract_entries_with_text_before_headings(tmp_path):
"Extract markdown entries with some text before any headings."
# Arrange
entry = f"""
Text before headings
# Heading 1
body line 1
## Heading 2
body line 2
"""
data = {
f"{tmp_path}": entry,
}
# Act
# Extract Entries from specified Markdown files
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3)
# Assert
assert len(entries) == 3
assert entries[0].raw == "\nText before headings"
assert entries[1].raw == "# Heading 1\nbody line 1"
assert entries[2].raw == "# Heading 1\n## Heading 2\nbody line 2\n", "Ensure raw entry includes heading ancestry"
def test_parse_markdown_file_into_single_entry_if_small(tmp_path):
"Parse a markdown file into a single entry if it fits within the token limit."
# Arrange
entry = f"""
# Heading 1
body line 1
## Subheading 1.1
body line 1.1
"""
data = {
f"{tmp_path}": entry,
}
# Act
# Extract Entries from specified Markdown files
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=12)
# Assert
assert len(entries) == 1
assert entries[0].raw == entry
def test_parse_markdown_entry_with_children_as_single_entry_if_small(tmp_path):
"Parse a markdown entry with child headings as a single entry if it fits within the token limit."
# Arrange
entry = f"""
# Heading 1
body line 1
## Subheading 1.1
body line 1.1
# Heading 2
body line 2
## Subheading 2.1
longer body line 2.1
"""
data = {
f"{tmp_path}": entry,
}
# Act
# Extract Entries from specified Markdown files
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=12)
# Assert
assert len(entries) == 3
assert (
entries[0].raw == "# Heading 1\nbody line 1\n## Subheading 1.1\nbody line 1.1"
), "First entry includes child headings"
assert entries[1].raw == "# Heading 2\nbody line 2", "Second entry does not include child headings"
assert (
entries[2].raw == "# Heading 2\n## Subheading 2.1\nlonger body line 2.1\n"
), "Third entry is the second entry's child heading"
def test_get_markdown_files(tmp_path):
"Ensure Markdown files specified via input-filter, input-files extracted"
# Arrange
@ -131,27 +238,6 @@ def test_get_markdown_files(tmp_path):
assert set(extracted_org_files.keys()) == expected_files
def test_extract_entries_with_different_level_headings(tmp_path):
"Extract markdown entries with different level headings."
# Arrange
entry = f"""
# Heading 1
## Heading 2
"""
data = {
f"{tmp_path}": entry,
}
# Act
# Extract Entries from specified Markdown files
entries, _ = MarkdownToEntries.extract_markdown_entries(markdown_files=data)
# Assert
assert len(entries) == 2
assert entries[0] == "# Heading 1"
assert entries[1] == "## Heading 2"
# Helper Functions
def create_file(tmp_path: Path, entry=None, filename="test.md"):
markdown_file = tmp_path / filename

View file

@ -56,7 +56,7 @@ def test_index_update_with_user2_inaccessible_user1(client, api_user2: KhojApiUs
# Assert
assert update_response.status_code == 200
assert len(results) == 5
assert len(results) == 3
for result in results:
assert result["additional"]["file"] not in source_file_symbol

View file

@ -470,10 +470,6 @@ async def test_websearch_with_operators(chat_client):
["site:reddit.com" in response for response in responses]
), "Expected a search query to include site:reddit.com but got: " + str(responses)
assert any(
["after:2024/04/01" in response for response in responses]
), "Expected a search query to include after:2024/04/01 but got: " + str(responses)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio

View file

@ -1,5 +1,5 @@
import json
import os
import re
from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
from khoj.processor.content.text_to_entries import TextToEntries
@ -8,7 +8,7 @@ from khoj.utils.helpers import is_none_or_empty
from khoj.utils.rawconfig import Entry, TextContentConfig
def test_configure_heading_entry_to_jsonl(tmp_path):
def test_configure_indexing_heading_only_entries(tmp_path):
"""Ensure entries with empty body are ignored, unless explicitly configured to index heading entries.
Property drawers not considered Body. Ignore control characters for evaluating if Body empty."""
# Arrange
@ -26,24 +26,21 @@ def test_configure_heading_entry_to_jsonl(tmp_path):
for index_heading_entries in [True, False]:
# Act
# Extract entries into jsonl from specified Org files
jsonl_string = OrgToEntries.convert_org_entries_to_jsonl(
OrgToEntries.convert_org_nodes_to_entries(
*OrgToEntries.extract_org_entries(org_files=data), index_heading_entries=index_heading_entries
)
entries = OrgToEntries.extract_org_entries(
org_files=data, index_heading_entries=index_heading_entries, max_tokens=3
)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
# Assert
if index_heading_entries:
# Entry with empty body indexed when index_heading_entries set to True
assert len(jsonl_data) == 1
assert len(entries) == 1
else:
# Entry with empty body ignored when index_heading_entries set to False
assert is_none_or_empty(jsonl_data)
assert is_none_or_empty(entries)
def test_entry_split_when_exceeds_max_words():
"Ensure entries with compiled words exceeding max_words are split."
def test_entry_split_when_exceeds_max_tokens():
"Ensure entries whose compiled form exceeds max_tokens are split."
# Arrange
tmp_path = "/tmp/test.org"
entry = f"""*** Heading
@ -57,29 +54,26 @@ def test_entry_split_when_exceeds_max_words():
# Act
# Extract Entries from specified Org files
entries, entry_to_file_map = OrgToEntries.extract_org_entries(org_files=data)
entries = OrgToEntries.extract_org_entries(org_files=data)
# Split each entry from specified Org files by max words
jsonl_string = OrgToEntries.convert_org_entries_to_jsonl(
TextToEntries.split_entries_by_max_tokens(
OrgToEntries.convert_org_nodes_to_entries(entries, entry_to_file_map), max_tokens=4
)
)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
# Split each entry from specified Org files by max tokens
entries = TextToEntries.split_entries_by_max_tokens(entries, max_tokens=6)
# Assert
assert len(jsonl_data) == 2
# Ensure compiled entries split by max_words start with entry heading (for search context)
assert all([entry["compiled"].startswith(expected_heading) for entry in jsonl_data])
assert len(entries) == 2
# Ensure compiled entries split by max tokens start with entry heading (for search context)
assert all([entry.compiled.startswith(expected_heading) for entry in entries])
def test_entry_split_drops_large_words():
"Ensure entries drop words larger than the specified max word length from the compiled version."
# Arrange
entry_text = f"""*** Heading
\t\r
Body Line 1
"""
entry_text = f"""First Line
dog=1\n\r\t
cat=10
car=4
book=2
"""
entry = Entry(raw=entry_text, compiled=entry_text)
# Act
@ -87,11 +81,158 @@ def test_entry_split_drops_large_words():
processed_entry = TextToEntries.split_entries_by_max_tokens([entry], max_word_length=5)[0]
# Assert
# "Heading" dropped from compiled version because its over the set max word limit
assert len(processed_entry.compiled.split()) == len(entry_text.split()) - 1
# Ensure words larger than max word length are dropped
# Ensure newline characters are considered as word boundaries for splitting words. See #620
words_to_keep = ["First", "Line", "dog=1", "car=4"]
words_to_drop = ["cat=10", "book=2"]
assert all([word for word in words_to_keep if word in processed_entry.compiled])
assert not any([word for word in words_to_drop if word in processed_entry.compiled])
assert len(processed_entry.compiled.split()) == len(entry_text.split()) - 2
def test_entry_with_body_to_jsonl(tmp_path):
def test_parse_org_file_into_single_entry_if_small(tmp_path):
"Parse an org file into a single entry if it fits within the token limit."
# Arrange
original_entry = f"""
* Heading 1
body line 1
** Subheading 1.1
body line 1.1
"""
data = {
f"{tmp_path}": original_entry,
}
expected_entry = f"""
* Heading 1
body line 1
** Subheading 1.1
body line 1.1
""".lstrip()
# Act
# Extract Entries from specified Org files
extracted_entries = OrgToEntries.extract_org_entries(org_files=data, max_tokens=12)
for entry in extracted_entries:
entry.raw = clean(entry.raw)
# Assert
assert len(extracted_entries) == 1
assert entry.raw == expected_entry
def test_parse_org_entry_with_children_as_single_entry_if_small(tmp_path):
"Parse an org entry with child headings as a single entry only if it fits within the token limit."
# Arrange
entry = f"""
* Heading 1
body line 1
** Subheading 1.1
body line 1.1
* Heading 2
body line 2
** Subheading 2.1
longer body line 2.1
"""
data = {
f"{tmp_path}": entry,
}
first_expected_entry = f"""
* Path: {tmp_path}
** Heading 1.
body line 1
*** Subheading 1.1.
body line 1.1
""".lstrip()
second_expected_entry = f"""
* Path: {tmp_path}
** Heading 2.
body line 2
""".lstrip()
third_expected_entry = f"""
* Path: {tmp_path} / Heading 2
** Subheading 2.1.
longer body line 2.1
""".lstrip()
# Act
# Extract Entries from specified Org files
extracted_entries = OrgToEntries.extract_org_entries(org_files=data, max_tokens=12)
# Assert
assert len(extracted_entries) == 3
assert extracted_entries[0].compiled == first_expected_entry, "First entry includes child headings"
assert extracted_entries[1].compiled == second_expected_entry, "Second entry does not include child headings"
assert extracted_entries[2].compiled == third_expected_entry, "Third entry is the second entry's child heading"
def test_separate_sibling_org_entries_if_all_cannot_fit_in_token_limit(tmp_path):
"Parse org sibling entries as separate entries only if they cannot all fit within the token limit."
# Arrange
entry = f"""
* Heading 1
body line 1
** Subheading 1.1
body line 1.1
* Heading 2
body line 2
** Subheading 2.1
body line 2.1
* Heading 3
body line 3
** Subheading 3.1
body line 3.1
"""
data = {
f"{tmp_path}": entry,
}
first_expected_entry = f"""
* Path: {tmp_path}
** Heading 1.
body line 1
*** Subheading 1.1.
body line 1.1
""".lstrip()
second_expected_entry = f"""
* Path: {tmp_path}
** Heading 2.
body line 2
*** Subheading 2.1.
body line 2.1
""".lstrip()
third_expected_entry = f"""
* Path: {tmp_path}
** Heading 3.
body line 3
*** Subheading 3.1.
body line 3.1
""".lstrip()
# Act
# Extract Entries from specified Org files
# Max tokens = 30 lies between the 2-entry (24 token) and 3-entry (36 token) boundaries,
# where each sibling entry contains 12 tokens (12 tokens * 3 entries = 36 tokens)
extracted_entries = OrgToEntries.extract_org_entries(org_files=data, max_tokens=30)
# Assert
assert len(extracted_entries) == 3
assert extracted_entries[0].compiled == first_expected_entry, "First entry includes child headings"
assert extracted_entries[1].compiled == second_expected_entry, "Second entry includes child headings"
assert extracted_entries[2].compiled == third_expected_entry, "Third entry includes child headings"
def test_entry_with_body_to_entry(tmp_path):
"Ensure entries with valid body text are loaded."
# Arrange
entry = f"""*** Heading
@ -107,19 +248,13 @@ def test_entry_with_body_to_jsonl(tmp_path):
# Act
# Extract Entries from specified Org files
entries, entry_to_file_map = OrgToEntries.extract_org_entries(org_files=data)
# Process Each Entry from All Notes Files
jsonl_string = OrgToEntries.convert_org_entries_to_jsonl(
OrgToEntries.convert_org_nodes_to_entries(entries, entry_to_file_map)
)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
entries = OrgToEntries.extract_org_entries(org_files=data, max_tokens=3)
# Assert
assert len(jsonl_data) == 1
assert len(entries) == 1
def test_file_with_entry_after_intro_text_to_jsonl(tmp_path):
def test_file_with_entry_after_intro_text_to_entry(tmp_path):
"Ensure intro text before any headings is indexed."
# Arrange
entry = f"""
@ -134,18 +269,13 @@ Intro text
# Act
# Extract Entries from specified Org files
entry_nodes, file_to_entries = OrgToEntries.extract_org_entries(org_files=data)
# Process Each Entry from All Notes Files
entries = OrgToEntries.convert_org_nodes_to_entries(entry_nodes, file_to_entries)
jsonl_string = OrgToEntries.convert_org_entries_to_jsonl(entries)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
entries = OrgToEntries.extract_org_entries(org_files=data, max_tokens=3)
# Assert
assert len(jsonl_data) == 2
assert len(entries) == 2
def test_file_with_no_headings_to_jsonl(tmp_path):
def test_file_with_no_headings_to_entry(tmp_path):
"Ensure files with no heading, only body text are loaded."
# Arrange
entry = f"""
@ -158,15 +288,10 @@ def test_file_with_no_headings_to_jsonl(tmp_path):
# Act
# Extract Entries from specified Org files
entry_nodes, file_to_entries = OrgToEntries.extract_org_entries(org_files=data)
# Process Each Entry from All Notes Files
entries = OrgToEntries.convert_org_nodes_to_entries(entry_nodes, file_to_entries)
jsonl_string = OrgToEntries.convert_org_entries_to_jsonl(entries)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
entries = OrgToEntries.extract_org_entries(org_files=data, max_tokens=3)
# Assert
assert len(jsonl_data) == 1
assert len(entries) == 1
def test_get_org_files(tmp_path):
@ -214,7 +339,8 @@ def test_extract_entries_with_different_level_headings(tmp_path):
# Arrange
entry = f"""
* Heading 1
** Heading 2
** Sub-Heading 1.1
* Heading 2
"""
data = {
f"{tmp_path}": entry,
@ -222,12 +348,14 @@ def test_extract_entries_with_different_level_headings(tmp_path):
# Act
# Extract Entries from specified Org files
entries, _ = OrgToEntries.extract_org_entries(org_files=data)
entries = OrgToEntries.extract_org_entries(org_files=data, index_heading_entries=True, max_tokens=3)
for entry in entries:
entry.raw = clean(f"{entry.raw}")
# Assert
assert len(entries) == 2
assert f"{entries[0]}".startswith("* Heading 1")
assert f"{entries[1]}".startswith("** Heading 2")
assert entries[0].raw == "* Heading 1\n** Sub-Heading 1.1\n", "Ensure entry includes heading ancestry"
assert entries[1].raw == "* Heading 2\n"
# Helper Functions
@ -237,3 +365,8 @@ def create_file(tmp_path, entry=None, filename="test.org"):
if entry:
org_file.write_text(entry)
return org_file
def clean(entry):
"Remove properties from entry for easier comparison."
return re.sub(r"\n:PROPERTIES:(.*?):END:", "", entry, flags=re.DOTALL)

View file

@ -1,4 +1,3 @@
import json
import os
from khoj.processor.content.pdf.pdf_to_entries import PdfToEntries
@ -15,16 +14,10 @@ def test_single_page_pdf_to_jsonl():
pdf_bytes = f.read()
data = {"tests/data/pdf/singlepage.pdf": pdf_bytes}
entries, entry_to_file_map = PdfToEntries.extract_pdf_entries(pdf_files=data)
# Process Each Entry from All Pdf Files
jsonl_string = PdfToEntries.convert_pdf_maps_to_jsonl(
PdfToEntries.convert_pdf_entries_to_maps(entries, entry_to_file_map)
)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
entries = PdfToEntries.extract_pdf_entries(pdf_files=data)
# Assert
assert len(jsonl_data) == 1
assert len(entries) == 1
def test_multi_page_pdf_to_jsonl():
@ -35,16 +28,10 @@ def test_multi_page_pdf_to_jsonl():
pdf_bytes = f.read()
data = {"tests/data/pdf/multipage.pdf": pdf_bytes}
entries, entry_to_file_map = PdfToEntries.extract_pdf_entries(pdf_files=data)
# Process Each Entry from All Pdf Files
jsonl_string = PdfToEntries.convert_pdf_maps_to_jsonl(
PdfToEntries.convert_pdf_entries_to_maps(entries, entry_to_file_map)
)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
entries = PdfToEntries.extract_pdf_entries(pdf_files=data)
# Assert
assert len(jsonl_data) == 6
assert len(entries) == 6
def test_ocr_page_pdf_to_jsonl():
@ -55,10 +42,7 @@ def test_ocr_page_pdf_to_jsonl():
pdf_bytes = f.read()
data = {"tests/data/pdf/ocr_samples.pdf": pdf_bytes}
entries, entry_to_file_map = PdfToEntries.extract_pdf_entries(pdf_files=data)
# Process Each Entry from All Pdf Files
entries = PdfToEntries.convert_pdf_entries_to_maps(entries, entry_to_file_map)
entries = PdfToEntries.extract_pdf_entries(pdf_files=data)
assert len(entries) == 1
assert "playing on a strip of marsh" in entries[0].raw

View file

@ -1,4 +1,3 @@
import json
import os
from pathlib import Path
@ -11,10 +10,10 @@ from khoj.utils.rawconfig import TextContentConfig
def test_plaintext_file(tmp_path):
"Convert files with no heading to jsonl."
# Arrange
entry = f"""
raw_entry = f"""
Hi, I am a plaintext file and I have some plaintext words.
"""
plaintextfile = create_file(tmp_path, entry)
plaintextfile = create_file(tmp_path, raw_entry)
filename = plaintextfile.stem
@ -22,25 +21,21 @@ def test_plaintext_file(tmp_path):
# Extract Entries from specified plaintext files
data = {
f"{plaintextfile}": entry,
f"{plaintextfile}": raw_entry,
}
maps = PlaintextToEntries.convert_plaintext_entries_to_maps(entry_to_file_map=data)
entries = PlaintextToEntries.extract_plaintext_entries(entry_to_file_map=data)
# Convert each entry.file to absolute path to make them JSON serializable
for map in maps:
map.file = str(Path(map.file).absolute())
# Process Each Entry from All Notes Files
jsonl_string = PlaintextToEntries.convert_entries_to_jsonl(maps)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
for entry in entries:
entry.file = str(Path(entry.file).absolute())
# Assert
assert len(jsonl_data) == 1
assert len(entries) == 1
# Ensure raw entry with no headings do not get heading prefix prepended
assert not jsonl_data[0]["raw"].startswith("#")
assert not entries[0].raw.startswith("#")
# Ensure compiled entry has filename prepended as top level heading
assert jsonl_data[0]["compiled"] == f"{filename}\n{entry}"
assert entries[0].compiled == f"{filename}\n{raw_entry}"
def test_get_plaintext_files(tmp_path):
@ -98,11 +93,11 @@ def test_parse_html_plaintext_file(content_config, default_user: KhojUser):
extracted_plaintext_files = get_plaintext_files(config=config)
# Act
maps = PlaintextToEntries.convert_plaintext_entries_to_maps(extracted_plaintext_files)
entries = PlaintextToEntries.extract_plaintext_entries(extracted_plaintext_files)
# Assert
assert len(maps) == 1
assert "<div>" not in maps[0].raw
assert len(entries) == 1
assert "<div>" not in entries[0].raw
# Helper Functions

View file

@ -57,18 +57,21 @@ def test_get_org_files_with_org_suffixed_dir_doesnt_raise_error(tmp_path, defaul
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
def test_text_search_setup_with_empty_file_creates_no_entries(
org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog
org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser
):
# Arrange
existing_entries = Entry.objects.filter(user=default_user).count()
data = get_org_files(org_config_with_only_new_file)
# Act
# Generate notes embeddings during asymmetric setup
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
# Assert
assert "Deleted 8 entries. Created 0 new entries for user " in caplog.records[-1].message
updated_entries = Entry.objects.filter(user=default_user).count()
assert existing_entries == 2
assert updated_entries == 0
verify_embeddings(0, default_user)
@ -78,6 +81,7 @@ def test_text_indexer_deletes_embedding_before_regenerate(
content_config: ContentConfig, default_user: KhojUser, caplog
):
# Arrange
existing_entries = Entry.objects.filter(user=default_user).count()
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
data = get_org_files(org_config)
@ -87,30 +91,18 @@ def test_text_indexer_deletes_embedding_before_regenerate(
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
# Assert
updated_entries = Entry.objects.filter(user=default_user).count()
assert existing_entries == 2
assert updated_entries == 2
assert "Deleting all entries for file type org" in caplog.text
assert "Deleted 8 entries. Created 13 new entries for user " in caplog.records[-1].message
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
def test_text_search_setup_batch_processes(content_config: ContentConfig, default_user: KhojUser, caplog):
# Arrange
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
data = get_org_files(org_config)
# Act
# Generate notes embeddings during asymmetric setup
with caplog.at_level(logging.DEBUG):
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
# Assert
assert "Deleted 8 entries. Created 13 new entries for user " in caplog.records[-1].message
assert "Deleted 2 entries. Created 2 new entries for user " in caplog.records[-1].message
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
def test_text_index_same_if_content_unchanged(content_config: ContentConfig, default_user: KhojUser, caplog):
# Arrange
existing_entries = Entry.objects.filter(user=default_user)
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
data = get_org_files(org_config)
@ -127,6 +119,10 @@ def test_text_index_same_if_content_unchanged(content_config: ContentConfig, def
final_logs = caplog.text
# Assert
updated_entries = Entry.objects.filter(user=default_user)
for entry in updated_entries:
assert entry in existing_entries
assert len(existing_entries) == len(updated_entries)
assert "Deleting all entries for file type org" in initial_logs
assert "Deleting all entries for file type org" not in final_logs
@ -192,7 +188,7 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: LocalOrgCon
# Assert
assert (
"Deleted 0 entries. Created 2 new entries for user " in caplog.records[-1].message
"Deleted 0 entries. Created 3 new entries for user " in caplog.records[-1].message
), "new entry not split by max tokens"
@ -250,16 +246,15 @@ conda activate khoj
# Assert
assert (
"Deleted 0 entries. Created 2 new entries for user " in caplog.records[-1].message
"Deleted 0 entries. Created 3 new entries for user " in caplog.records[-1].message
), "new entry not split by max tokens"
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
def test_regenerate_index_with_new_entry(
content_config: ContentConfig, new_org_file: Path, default_user: KhojUser, caplog
):
def test_regenerate_index_with_new_entry(content_config: ContentConfig, new_org_file: Path, default_user: KhojUser):
# Arrange
existing_entries = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
initial_data = get_org_files(org_config)
@ -271,28 +266,34 @@ def test_regenerate_index_with_new_entry(
final_data = get_org_files(org_config)
# Act
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, initial_data, regenerate=True, user=default_user)
initial_logs = caplog.text
caplog.clear() # Clear logs
text_search.setup(OrgToEntries, initial_data, regenerate=True, user=default_user)
updated_entries1 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
# regenerate entries and embeddings to include the entry from the new file
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, final_data, regenerate=True, user=default_user)
final_logs = caplog.text
text_search.setup(OrgToEntries, final_data, regenerate=True, user=default_user)
updated_entries2 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
# Assert
assert "Deleted 8 entries. Created 13 new entries for user " in initial_logs
assert "Deleted 13 entries. Created 14 new entries for user " in final_logs
verify_embeddings(14, default_user)
for entry in updated_entries1:
assert entry in updated_entries2
assert not any([new_org_file.name in entry for entry in updated_entries1])
assert not any([new_org_file.name in entry for entry in existing_entries])
assert any([new_org_file.name in entry for entry in updated_entries2])
assert any(
["Saw a super cute video of a chihuahua doing the Tango on Youtube" in entry for entry in updated_entries2]
)
verify_embeddings(3, default_user)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
def test_update_index_with_duplicate_entries_in_stable_order(
org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog
org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser
):
# Arrange
existing_entries = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
# Insert org-mode entries with same compiled form into new org file
@ -304,30 +305,33 @@ def test_update_index_with_duplicate_entries_in_stable_order(
# Act
# generate embeddings, entries, notes model from scratch after adding new org-mode file
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
initial_logs = caplog.text
caplog.clear() # Clear logs
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
updated_entries1 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
data = get_org_files(org_config_with_only_new_file)
# update embeddings, entries, notes model with no new changes
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, data, regenerate=False, user=default_user)
final_logs = caplog.text
text_search.setup(OrgToEntries, data, regenerate=False, user=default_user)
updated_entries2 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
# Assert
# verify only 1 entry added even if there are multiple duplicate entries
assert "Deleted 8 entries. Created 1 new entries for user " in initial_logs
assert "Deleted 0 entries. Created 0 new entries for user " in final_logs
for entry in existing_entries:
assert entry not in updated_entries1
for entry in updated_entries1:
assert entry in updated_entries2
assert len(existing_entries) == 2
assert len(updated_entries1) == len(updated_entries2)
verify_embeddings(1, default_user)
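# The duplicate-entry assertions above rely on entries being deduplicated by a
# hash of their compiled text before indexing, so identical org entries
# collapse into a single row regardless of how many times they appear in the
# file. A rough sketch of that idea; the function name is an assumption, not
# the actual khoj code.
import hashlib

def dedupe_by_compiled_form(compiled_entries: list[str]) -> list[str]:
    seen: set[str] = set()
    unique_entries: list[str] = []
    for compiled in compiled_entries:
        digest = hashlib.md5(compiled.encode("utf-8")).hexdigest()
        if digest not in seen:
            seen.add(digest)
            unique_entries.append(compiled)
    return unique_entries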
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
def test_update_index_with_deleted_entry(org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog):
def test_update_index_with_deleted_entry(org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser):
# Arrange
existing_entries = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
# Insert org-mode entries with same compiled form into new org file
@ -344,33 +348,34 @@ def test_update_index_with_deleted_entry(org_config_with_only_new_file: LocalOrg
# Act
# load embeddings, entries, notes model after adding new org file with 2 entries
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, initial_data, regenerate=True, user=default_user)
initial_logs = caplog.text
caplog.clear() # Clear logs
text_search.setup(OrgToEntries, initial_data, regenerate=True, user=default_user)
updated_entries1 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, final_data, regenerate=False, user=default_user)
final_logs = caplog.text
text_search.setup(OrgToEntries, final_data, regenerate=False, user=default_user)
updated_entries2 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
# Assert
# verify one entry is deleted when it is removed from the indexed file
assert "Deleted 8 entries. Created 2 new entries for user " in initial_logs
assert "Deleted 1 entries. Created 0 new entries for user " in final_logs
for entry in existing_entries:
assert entry not in updated_entries1
# verify the entry in updated_entries2 is a subset of updated_entries1
for entry in updated_entries1:
assert entry not in updated_entries2
for entry in updated_entries2:
assert entry in updated_entries1[0]
verify_embeddings(1, default_user)
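# The deleted-entry flow above appears to be driven by diffing entry hashes
# between runs: hashes present in the database but absent from the
# re-extracted entries are deleted, while previously unseen hashes are
# embedded and created, which is what the "Deleted 1 ... Created 0 ..." style
# log lines report. A rough sketch of that reconciliation; the names here are
# assumptions, not the actual khoj code.
def reconcile_entry_hashes(old_hashes: set[str], new_hashes: set[str]) -> tuple[set[str], set[str]]:
    hashes_to_delete = old_hashes - new_hashes  # stale entries to remove
    hashes_to_create = new_hashes - old_hashes  # fresh entries to embed
    return hashes_to_delete, hashes_to_create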
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file: Path, default_user: KhojUser, caplog):
def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file: Path, default_user: KhojUser):
# Arrange
existing_entries = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
data = get_org_files(org_config)
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
initial_logs = caplog.text
caplog.clear() # Clear logs
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
# append org-mode entry to first org input file in config
with open(new_org_file, "w") as f:
@ -381,14 +386,14 @@ def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file
# Act
# update embeddings, entries with the newly added note
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, data, regenerate=False, user=default_user)
final_logs = caplog.text
text_search.setup(OrgToEntries, data, regenerate=False, user=default_user)
updated_entries1 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
# Assert
assert "Deleted 8 entries. Created 13 new entries for user " in initial_logs
assert "Deleted 0 entries. Created 1 new entries for user " in final_logs
verify_embeddings(14, default_user)
for entry in existing_entries:
assert entry not in updated_entries1
assert len(updated_entries1) == len(existing_entries) + 1
verify_embeddings(3, default_user)
# ----------------------------------------------------------------------------------------------------