diff --git a/src/khoj/processor/content/markdown/markdown_to_entries.py b/src/khoj/processor/content/markdown/markdown_to_entries.py index 7274cf1c..73e5bf47 100644 --- a/src/khoj/processor/content/markdown/markdown_to_entries.py +++ b/src/khoj/processor/content/markdown/markdown_to_entries.py @@ -1,14 +1,13 @@ import logging import re from pathlib import Path -from typing import List, Tuple +from typing import Dict, List, Tuple import urllib3 from khoj.database.models import Entry as DbEntry from khoj.database.models import KhojUser from khoj.processor.content.text_to_entries import TextToEntries -from khoj.utils.constants import empty_escape_sequences from khoj.utils.helpers import timer from khoj.utils.rawconfig import Entry @@ -31,15 +30,14 @@ class MarkdownToEntries(TextToEntries): else: deletion_file_names = None + max_tokens = 256 # Extract Entries from specified Markdown files - with timer("Parse entries from Markdown files into dictionaries", logger): - current_entries = MarkdownToEntries.convert_markdown_entries_to_maps( - *MarkdownToEntries.extract_markdown_entries(files) - ) + with timer("Extract entries from specified Markdown files", logger): + current_entries = MarkdownToEntries.extract_markdown_entries(files, max_tokens) # Split entries by max tokens supported by model with timer("Split entries by max token size supported by model", logger): - current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256) + current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens) # Identify, mark and merge any new entries with previous entries with timer("Identify new or updated entries", logger): @@ -57,48 +55,84 @@ class MarkdownToEntries(TextToEntries): return num_new_embeddings, num_deleted_embeddings @staticmethod - def extract_markdown_entries(markdown_files): + def extract_markdown_entries(markdown_files, max_tokens=256) -> List[Entry]: "Extract entries by heading from specified Markdown files" - - # Regex to extract Markdown Entries by Heading - - entries = [] - entry_to_file_map = [] + entries: List[str] = [] + entry_to_file_map: List[Tuple[str, str]] = [] for markdown_file in markdown_files: try: markdown_content = markdown_files[markdown_file] entries, entry_to_file_map = MarkdownToEntries.process_single_markdown_file( - markdown_content, markdown_file, entries, entry_to_file_map + markdown_content, markdown_file, entries, entry_to_file_map, max_tokens ) except Exception as e: - logger.warning(f"Unable to process file: {markdown_file}. This file will not be indexed.") - logger.warning(e, exc_info=True) + logger.error( + f"Unable to process file: {markdown_file}. 
This file will not be indexed.\n{e}", exc_info=True + ) - return entries, dict(entry_to_file_map) + return MarkdownToEntries.convert_markdown_entries_to_maps(entries, dict(entry_to_file_map)) @staticmethod def process_single_markdown_file( - markdown_content: str, markdown_file: Path, entries: List, entry_to_file_map: List - ): - markdown_heading_regex = r"^#" + markdown_content: str, + markdown_file: str, + entries: List[str], + entry_to_file_map: List[Tuple[str, str]], + max_tokens=256, + ancestry: Dict[int, str] = {}, + ) -> Tuple[List[str], List[Tuple[str, str]]]: + # Prepend the markdown section's heading ancestry + ancestry_string = "\n".join([f"{'#' * key} {ancestry[key]}" for key in sorted(ancestry.keys())]) + markdown_content_with_ancestry = f"{ancestry_string}{markdown_content}" - markdown_entries_per_file = [] - any_headings = re.search(markdown_heading_regex, markdown_content, flags=re.MULTILINE) - for entry in re.split(markdown_heading_regex, markdown_content, flags=re.MULTILINE): - # Add heading level as the regex split removed it from entries with headings - prefix = "#" if entry.startswith("#") else "# " if any_headings else "" - stripped_entry = entry.strip(empty_escape_sequences) - if stripped_entry != "": - markdown_entries_per_file.append(f"{prefix}{stripped_entry}") + # If content is small or content has no children headings, save it as a single entry + if len(TextToEntries.tokenizer(markdown_content_with_ancestry)) <= max_tokens or not re.search( + rf"^#{{{len(ancestry)+1},}}\s", markdown_content, flags=re.MULTILINE + ): + entry_to_file_map += [(markdown_content_with_ancestry, markdown_file)] + entries.extend([markdown_content_with_ancestry]) + return entries, entry_to_file_map + + # Split by next heading level present in the entry + next_heading_level = len(ancestry) + sections: List[str] = [] + while len(sections) < 2: + next_heading_level += 1 + sections = re.split(rf"(\n|^)(?=[#]{{{next_heading_level}}} .+\n?)", markdown_content, flags=re.MULTILINE) + + for section in sections: + # Skip empty sections + if section.strip() == "": + continue + + # Extract the section body and (when present) the heading + current_ancestry = ancestry.copy() + first_line = [line for line in section.split("\n") if line.strip() != ""][0] + if re.search(rf"^#{{{next_heading_level}}} ", first_line): + # Extract the section body without the heading + current_section_body = "\n".join(section.split(first_line)[1:]) + # Parse the section heading into current section ancestry + current_section_title = first_line[next_heading_level:].strip() + current_ancestry[next_heading_level] = current_section_title + else: + current_section_body = section + + # Recurse down children of the current entry + MarkdownToEntries.process_single_markdown_file( + current_section_body, + markdown_file, + entries, + entry_to_file_map, + max_tokens, + current_ancestry, + ) - entry_to_file_map += zip(markdown_entries_per_file, [markdown_file] * len(markdown_entries_per_file)) - entries.extend(markdown_entries_per_file) return entries, entry_to_file_map @staticmethod def convert_markdown_entries_to_maps(parsed_entries: List[str], entry_to_file_map) -> List[Entry]: "Convert each Markdown entries into a dictionary" - entries = [] + entries: List[Entry] = [] for parsed_entry in parsed_entries: raw_filename = entry_to_file_map[parsed_entry] @@ -108,13 +142,12 @@ class MarkdownToEntries(TextToEntries): entry_filename = urllib3.util.parse_url(raw_filename).url else: entry_filename = str(Path(raw_filename)) - stem = 
Path(raw_filename).stem heading = parsed_entry.splitlines()[0] if re.search("^#+\s", parsed_entry) else "" # Append base filename to compiled entry for context to model # Increment heading level for heading entries and make filename as its top level heading - prefix = f"# {stem}\n#" if heading else f"# {stem}\n" - compiled_entry = f"{entry_filename}\n{prefix}{parsed_entry}" + prefix = f"# {entry_filename}\n#" if heading else f"# {entry_filename}\n" + compiled_entry = f"{prefix}{parsed_entry}" entries.append( Entry( compiled=compiled_entry, @@ -127,8 +160,3 @@ class MarkdownToEntries(TextToEntries): logger.debug(f"Converted {len(parsed_entries)} markdown entries to dictionaries") return entries - - @staticmethod - def convert_markdown_maps_to_jsonl(entries: List[Entry]): - "Convert each Markdown entry to JSON and collate as JSONL" - return "".join([f"{entry.to_json()}\n" for entry in entries]) diff --git a/src/khoj/processor/content/org_mode/org_to_entries.py b/src/khoj/processor/content/org_mode/org_to_entries.py index 0e115f78..2dcaa744 100644 --- a/src/khoj/processor/content/org_mode/org_to_entries.py +++ b/src/khoj/processor/content/org_mode/org_to_entries.py @@ -1,10 +1,12 @@ import logging +import re from pathlib import Path -from typing import Iterable, List, Tuple +from typing import Dict, List, Tuple from khoj.database.models import Entry as DbEntry from khoj.database.models import KhojUser from khoj.processor.content.org_mode import orgnode +from khoj.processor.content.org_mode.orgnode import Orgnode from khoj.processor.content.text_to_entries import TextToEntries from khoj.utils import state from khoj.utils.helpers import timer @@ -21,9 +23,6 @@ class OrgToEntries(TextToEntries): def process( self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False ) -> Tuple[int, int]: - # Extract required fields from config - index_heading_entries = False - if not full_corpus: deletion_file_names = set([file for file in files if files[file] == ""]) files_to_process = set(files) - deletion_file_names @@ -32,14 +31,12 @@ class OrgToEntries(TextToEntries): deletion_file_names = None # Extract Entries from specified Org files - with timer("Parse entries from org files into OrgNode objects", logger): - entry_nodes, file_to_entries = self.extract_org_entries(files) - - with timer("Convert OrgNodes into list of entries", logger): - current_entries = self.convert_org_nodes_to_entries(entry_nodes, file_to_entries, index_heading_entries) + max_tokens = 256 + with timer("Extract entries from specified Org files", logger): + current_entries = self.extract_org_entries(files, max_tokens=max_tokens) with timer("Split entries by max token size supported by model", logger): - current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256) + current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=max_tokens) # Identify, mark and merge any new entries with previous entries with timer("Identify new or updated entries", logger): @@ -57,93 +54,165 @@ class OrgToEntries(TextToEntries): return num_new_embeddings, num_deleted_embeddings @staticmethod - def extract_org_entries(org_files: dict[str, str]): + def extract_org_entries( + org_files: dict[str, str], index_heading_entries: bool = False, max_tokens=256 + ) -> List[Entry]: "Extract entries from specified Org files" - entries = [] - entry_to_file_map: List[Tuple[orgnode.Orgnode, str]] = [] + entries, entry_to_file_map = OrgToEntries.extract_org_nodes(org_files, 
max_tokens)
+        return OrgToEntries.convert_org_nodes_to_entries(entries, entry_to_file_map, index_heading_entries)
+
+    @staticmethod
+    def extract_org_nodes(org_files: dict[str, str], max_tokens) -> Tuple[List[List[Orgnode]], Dict[Orgnode, str]]:
+        "Extract org nodes from specified org files"
+        entries: List[List[Orgnode]] = []
+        entry_to_file_map: List[Tuple[Orgnode, str]] = []
         for org_file in org_files:
-            filename = org_file
-            file = org_files[org_file]
             try:
-                org_file_entries = orgnode.makelist(file, filename)
-                entry_to_file_map += zip(org_file_entries, [org_file] * len(org_file_entries))
-                entries.extend(org_file_entries)
+                org_content = org_files[org_file]
+                entries, entry_to_file_map = OrgToEntries.process_single_org_file(
+                    org_content, org_file, entries, entry_to_file_map, max_tokens
+                )
             except Exception as e:
-                logger.warning(f"Unable to process file: {org_file}. This file will not be indexed.")
-                logger.warning(e, exc_info=True)
+                logger.error(f"Unable to process file: {org_file}. Skipped indexing it.\nError: {e}", exc_info=True)

         return entries, dict(entry_to_file_map)

     @staticmethod
-    def process_single_org_file(org_content: str, org_file: str, entries: List, entry_to_file_map: List):
-        # Process single org file. The org parser assumes that the file is a single org file and reads it from a buffer. We'll split the raw conetnt of this file by new line to mimic the same behavior.
-        try:
-            org_file_entries = orgnode.makelist(org_content, org_file)
-            entry_to_file_map += zip(org_file_entries, [org_file] * len(org_file_entries))
-            entries.extend(org_file_entries)
-            return entries, entry_to_file_map
-        except Exception as e:
-            logger.error(f"Error processing file: {org_file} with error: {e}", exc_info=True)
+    def process_single_org_file(
+        org_content: str,
+        org_file: str,
+        entries: List[List[Orgnode]],
+        entry_to_file_map: List[Tuple[Orgnode, str]],
+        max_tokens=256,
+        ancestry: Dict[int, str] = {},
+    ) -> Tuple[List[List[Orgnode]], List[Tuple[Orgnode, str]]]:
+        """Parse org_content from org_file into OrgNode entries.
+
+        Recurse down the org file entries, one heading level at a time,
+        until reaching a leaf entry or an entry tree that fits within max_tokens.
+
+        Parse each recursion-terminating entry tree into a list of OrgNode objects.
+ """ + # Prepend the org section's heading ancestry + ancestry_string = "\n".join([f"{'*' * key} {ancestry[key]}" for key in sorted(ancestry.keys())]) + org_content_with_ancestry = f"{ancestry_string}{org_content}" + + # If content is small or content has no children headings, save it as a single entry + # Note: This is the terminating condition for this recursive function + if len(TextToEntries.tokenizer(org_content_with_ancestry)) <= max_tokens or not re.search( + rf"^\*{{{len(ancestry)+1},}}\s", org_content, re.MULTILINE + ): + orgnode_content_with_ancestry = orgnode.makelist(org_content_with_ancestry, org_file) + entry_to_file_map += zip(orgnode_content_with_ancestry, [org_file] * len(orgnode_content_with_ancestry)) + entries.extend([orgnode_content_with_ancestry]) return entries, entry_to_file_map + # Split this entry tree into sections by the next heading level in it + # Increment heading level until able to split entry into sections + # A successful split will result in at least 2 sections + next_heading_level = len(ancestry) + sections: List[str] = [] + while len(sections) < 2: + next_heading_level += 1 + sections = re.split(rf"(\n|^)(?=[*]{{{next_heading_level}}} .+\n?)", org_content, flags=re.MULTILINE) + + # Recurse down each non-empty section after parsing its body, heading and ancestry + for section in sections: + # Skip empty sections + if section.strip() == "": + continue + + # Extract the section body and (when present) the heading + current_ancestry = ancestry.copy() + first_non_empty_line = [line for line in section.split("\n") if line.strip() != ""][0] + # If first non-empty line is a heading with expected heading level + if re.search(rf"^\*{{{next_heading_level}}}\s", first_non_empty_line): + # Extract the section body without the heading + current_section_body = "\n".join(section.split(first_non_empty_line)[1:]) + # Parse the section heading into current section ancestry + current_section_title = first_non_empty_line[next_heading_level:].strip() + current_ancestry[next_heading_level] = current_section_title + # Else process the section as just body text + else: + current_section_body = section + + # Recurse down children of the current entry + OrgToEntries.process_single_org_file( + current_section_body, + org_file, + entries, + entry_to_file_map, + max_tokens, + current_ancestry, + ) + + return entries, entry_to_file_map + @staticmethod def convert_org_nodes_to_entries( - parsed_entries: List[orgnode.Orgnode], entry_to_file_map, index_heading_entries=False + parsed_entries: List[List[Orgnode]], + entry_to_file_map: Dict[Orgnode, str], + index_heading_entries: bool = False, ) -> List[Entry]: - "Convert Org-Mode nodes into list of Entry objects" + """ + Convert OrgNode lists into list of Entry objects + + Each list of OrgNodes is a parsed parent org tree or leaf node. + Convert each list of these OrgNodes into a single Entry. 
+ """ entries: List[Entry] = [] - for parsed_entry in parsed_entries: - if not parsed_entry.hasBody and not index_heading_entries: - # Ignore title notes i.e notes with just headings and empty body - continue + for entry_group in parsed_entries: + entry_heading, entry_compiled, entry_raw = "", "", "" + for parsed_entry in entry_group: + if not parsed_entry.hasBody and not index_heading_entries: + # Ignore title notes i.e notes with just headings and empty body + continue - todo_str = f"{parsed_entry.todo} " if parsed_entry.todo else "" + todo_str = f"{parsed_entry.todo} " if parsed_entry.todo else "" - # Prepend ancestor headings, filename as top heading to entry for context - ancestors_trail = " / ".join(parsed_entry.ancestors) or Path(entry_to_file_map[parsed_entry]) - if parsed_entry.heading: - heading = f"* Path: {ancestors_trail}\n** {todo_str}{parsed_entry.heading}." - else: - heading = f"* Path: {ancestors_trail}." + # Set base level to current org-node tree's root heading level + if not entry_heading and parsed_entry.level > 0: + base_level = parsed_entry.level + # Indent entry by 1 heading level as ancestry is prepended as top level heading + heading = f"{'*' * (parsed_entry.level-base_level+2)} {todo_str}" if parsed_entry.level > 0 else "" + if parsed_entry.heading: + heading += f"{parsed_entry.heading}." - compiled = heading - if state.verbose > 2: - logger.debug(f"Title: {heading}") + # Prepend ancestor headings, filename as top heading to root parent entry for context + # Children nodes do not need ancestors trail as root parent node will have it + if not entry_heading: + ancestors_trail = " / ".join(parsed_entry.ancestors) or Path(entry_to_file_map[parsed_entry]) + heading = f"* Path: {ancestors_trail}\n{heading}" if heading else f"* Path: {ancestors_trail}." - if parsed_entry.tags: - tags_str = " ".join(parsed_entry.tags) - compiled += f"\t {tags_str}." - if state.verbose > 2: - logger.debug(f"Tags: {tags_str}") + compiled = heading - if parsed_entry.closed: - compiled += f'\n Closed on {parsed_entry.closed.strftime("%Y-%m-%d")}.' - if state.verbose > 2: - logger.debug(f'Closed: {parsed_entry.closed.strftime("%Y-%m-%d")}') + if parsed_entry.tags: + tags_str = " ".join(parsed_entry.tags) + compiled += f"\t {tags_str}." - if parsed_entry.scheduled: - compiled += f'\n Scheduled for {parsed_entry.scheduled.strftime("%Y-%m-%d")}.' - if state.verbose > 2: - logger.debug(f'Scheduled: {parsed_entry.scheduled.strftime("%Y-%m-%d")}') + if parsed_entry.closed: + compiled += f'\n Closed on {parsed_entry.closed.strftime("%Y-%m-%d")}.' - if parsed_entry.hasBody: - compiled += f"\n {parsed_entry.body}" - if state.verbose > 2: - logger.debug(f"Body: {parsed_entry.body}") + if parsed_entry.scheduled: + compiled += f'\n Scheduled for {parsed_entry.scheduled.strftime("%Y-%m-%d")}.' 
- if compiled: + if parsed_entry.hasBody: + compiled += f"\n {parsed_entry.body}" + + # Add the sub-entry contents to the entry + entry_compiled += f"{compiled}" + entry_raw += f"{parsed_entry}" + if not entry_heading: + entry_heading = heading + + if entry_compiled: entries.append( Entry( - compiled=compiled, - raw=f"{parsed_entry}", - heading=f"{heading}", + compiled=entry_compiled, + raw=entry_raw, + heading=f"{entry_heading}", file=f"{entry_to_file_map[parsed_entry]}", ) ) return entries - - @staticmethod - def convert_org_entries_to_jsonl(entries: Iterable[Entry]) -> str: - "Convert each Org-Mode entry to JSON and collate as JSONL" - return "".join([f"{entry_dict.to_json()}\n" for entry_dict in entries]) diff --git a/src/khoj/processor/content/org_mode/orgnode.py b/src/khoj/processor/content/org_mode/orgnode.py index 0c9bcb6e..8449cc9d 100644 --- a/src/khoj/processor/content/org_mode/orgnode.py +++ b/src/khoj/processor/content/org_mode/orgnode.py @@ -37,7 +37,7 @@ import datetime import re from os.path import relpath from pathlib import Path -from typing import List +from typing import Dict, List, Tuple indent_regex = re.compile(r"^ *") @@ -58,7 +58,7 @@ def makelist_with_filepath(filename): return makelist(f, filename) -def makelist(file, filename): +def makelist(file, filename) -> List["Orgnode"]: """ Read an org-mode file and return a list of Orgnode objects created from this file. @@ -80,16 +80,16 @@ def makelist(file, filename): } # populated from #+SEQ_TODO line level = "" heading = "" - ancestor_headings = [] + ancestor_headings: List[str] = [] bodytext = "" introtext = "" - tags = list() # set of all tags in headline - closed_date = "" - sched_date = "" - deadline_date = "" - logbook = list() + tags: List[str] = list() # set of all tags in headline + closed_date: datetime.date = None + sched_date: datetime.date = None + deadline_date: datetime.date = None + logbook: List[Tuple[datetime.datetime, datetime.datetime]] = list() nodelist: List[Orgnode] = list() - property_map = dict() + property_map: Dict[str, str] = dict() in_properties_drawer = False in_logbook_drawer = False file_title = f"{filename}" @@ -102,13 +102,13 @@ def makelist(file, filename): thisNode = Orgnode(level, heading, bodytext, tags, ancestor_headings) if closed_date: thisNode.closed = closed_date - closed_date = "" + closed_date = None if sched_date: thisNode.scheduled = sched_date - sched_date = "" + sched_date = None if deadline_date: thisNode.deadline = deadline_date - deadline_date = "" + deadline_date = None if logbook: thisNode.logbook = logbook logbook = list() @@ -116,7 +116,7 @@ def makelist(file, filename): nodelist.append(thisNode) property_map = {"LINE": f"file:{normalize_filename(filename)}::{ctr}"} previous_level = level - previous_heading = heading + previous_heading: str = heading level = heading_search.group(1) heading = heading_search.group(2) bodytext = "" @@ -495,12 +495,13 @@ class Orgnode(object): if self._priority: n = n + "[#" + self._priority + "] " n = n + self._heading - n = "%-60s " % n # hack - tags will start in column 62 - closecolon = "" - for t in self._tags: - n = n + ":" + t - closecolon = ":" - n = n + closecolon + if self._tags: + n = "%-60s " % n # hack - tags will start in column 62 + closecolon = "" + for t in self._tags: + n = n + ":" + t + closecolon = ":" + n = n + closecolon n = n + "\n" # Get body indentation from first line of body diff --git a/src/khoj/processor/content/pdf/pdf_to_entries.py b/src/khoj/processor/content/pdf/pdf_to_entries.py index 
3582cbe0..c59b305c 100644 --- a/src/khoj/processor/content/pdf/pdf_to_entries.py +++ b/src/khoj/processor/content/pdf/pdf_to_entries.py @@ -32,8 +32,8 @@ class PdfToEntries(TextToEntries): deletion_file_names = None # Extract Entries from specified Pdf files - with timer("Parse entries from PDF files into dictionaries", logger): - current_entries = PdfToEntries.convert_pdf_entries_to_maps(*PdfToEntries.extract_pdf_entries(files)) + with timer("Extract entries from specified PDF files", logger): + current_entries = PdfToEntries.extract_pdf_entries(files) # Split entries by max tokens supported by model with timer("Split entries by max token size supported by model", logger): @@ -55,11 +55,11 @@ class PdfToEntries(TextToEntries): return num_new_embeddings, num_deleted_embeddings @staticmethod - def extract_pdf_entries(pdf_files): + def extract_pdf_entries(pdf_files) -> List[Entry]: """Extract entries by page from specified PDF files""" - entries = [] - entry_to_location_map = [] + entries: List[str] = [] + entry_to_location_map: List[Tuple[str, str]] = [] for pdf_file in pdf_files: try: # Write the PDF file to a temporary file, as it is stored in byte format in the pdf_file object and the PDF Loader expects a file path @@ -83,7 +83,7 @@ class PdfToEntries(TextToEntries): if os.path.exists(f"{tmp_file}"): os.remove(f"{tmp_file}") - return entries, dict(entry_to_location_map) + return PdfToEntries.convert_pdf_entries_to_maps(entries, dict(entry_to_location_map)) @staticmethod def convert_pdf_entries_to_maps(parsed_entries: List[str], entry_to_file_map) -> List[Entry]: @@ -106,8 +106,3 @@ class PdfToEntries(TextToEntries): logger.debug(f"Converted {len(parsed_entries)} PDF entries to dictionaries") return entries - - @staticmethod - def convert_pdf_maps_to_jsonl(entries: List[Entry]): - "Convert each PDF entry to JSON and collate as JSONL" - return "".join([f"{entry.to_json()}\n" for entry in entries]) diff --git a/src/khoj/processor/content/plaintext/plaintext_to_entries.py b/src/khoj/processor/content/plaintext/plaintext_to_entries.py index 9ea8c654..4fb0dd2e 100644 --- a/src/khoj/processor/content/plaintext/plaintext_to_entries.py +++ b/src/khoj/processor/content/plaintext/plaintext_to_entries.py @@ -42,8 +42,8 @@ class PlaintextToEntries(TextToEntries): logger.warning(e, exc_info=True) # Extract Entries from specified plaintext files - with timer("Parse entries from plaintext files", logger): - current_entries = PlaintextToEntries.convert_plaintext_entries_to_maps(files) + with timer("Parse entries from specified Plaintext files", logger): + current_entries = PlaintextToEntries.extract_plaintext_entries(files) # Split entries by max tokens supported by model with timer("Split entries by max token size supported by model", logger): @@ -74,7 +74,7 @@ class PlaintextToEntries(TextToEntries): return soup.get_text(strip=True, separator="\n") @staticmethod - def convert_plaintext_entries_to_maps(entry_to_file_map: dict) -> List[Entry]: + def extract_plaintext_entries(entry_to_file_map: dict[str, str]) -> List[Entry]: "Convert each plaintext entries into a dictionary" entries = [] for file, entry in entry_to_file_map.items(): @@ -87,8 +87,3 @@ class PlaintextToEntries(TextToEntries): ) ) return entries - - @staticmethod - def convert_entries_to_jsonl(entries: List[Entry]): - "Convert each entry to JSON and collate as JSONL" - return "".join([f"{entry.to_json()}\n" for entry in entries]) diff --git a/src/khoj/processor/content/text_to_entries.py b/src/khoj/processor/content/text_to_entries.py index 
f8bf30dc..edd814f6 100644 --- a/src/khoj/processor/content/text_to_entries.py +++ b/src/khoj/processor/content/text_to_entries.py @@ -1,10 +1,12 @@ import hashlib import logging +import re import uuid from abc import ABC, abstractmethod from itertools import repeat from typing import Any, Callable, List, Set, Tuple +from langchain.text_splitter import RecursiveCharacterTextSplitter from tqdm import tqdm from khoj.database.adapters import EntryAdapters, get_user_search_model_or_default @@ -34,6 +36,27 @@ class TextToEntries(ABC): def hash_func(key: str) -> Callable: return lambda entry: hashlib.md5(bytes(getattr(entry, key), encoding="utf-8")).hexdigest() + @staticmethod + def remove_long_words(text: str, max_word_length: int = 500) -> str: + "Remove words longer than max_word_length from text." + # Split the string by words, keeping the delimiters + splits = re.split(r"(\s+)", text) + [""] + words_with_delimiters = list(zip(splits[::2], splits[1::2])) + + # Filter out long words while preserving delimiters in text + filtered_text = [ + f"{word}{delimiter}" + for word, delimiter in words_with_delimiters + if not word.strip() or len(word.strip()) <= max_word_length + ] + + return "".join(filtered_text) + + @staticmethod + def tokenizer(text: str) -> List[str]: + "Tokenize text into words." + return text.split() + @staticmethod def split_entries_by_max_tokens( entries: List[Entry], max_tokens: int = 256, max_word_length: int = 500 @@ -44,24 +67,30 @@ class TextToEntries(ABC): if is_none_or_empty(entry.compiled): continue - # Split entry into words - compiled_entry_words = [word for word in entry.compiled.split(" ") if word != ""] - - # Drop long words instead of having entry truncated to maintain quality of entry processed by models - compiled_entry_words = [word for word in compiled_entry_words if len(word) <= max_word_length] + # Split entry into chunks of max_tokens + # Use chunking preference order: paragraphs > sentences > words > characters + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=max_tokens, + separators=["\n\n", "\n", "!", "?", ".", " ", "\t", ""], + keep_separator=True, + length_function=lambda chunk: len(TextToEntries.tokenizer(chunk)), + chunk_overlap=0, + ) + chunked_entry_chunks = text_splitter.split_text(entry.compiled) corpus_id = uuid.uuid4() - # Split entry into chunks of max tokens - for chunk_index in range(0, len(compiled_entry_words), max_tokens): - compiled_entry_words_chunk = compiled_entry_words[chunk_index : chunk_index + max_tokens] - compiled_entry_chunk = " ".join(compiled_entry_words_chunk) - + # Create heading prefixed entry from each chunk + for chunk_index, compiled_entry_chunk in enumerate(chunked_entry_chunks): # Prepend heading to all other chunks, the first chunk already has heading from original entry - if chunk_index > 0: + if chunk_index > 0 and entry.heading: # Snip heading to avoid crossing max_tokens limit # Keep last 100 characters of heading as entry heading more important than filename snipped_heading = entry.heading[-100:] - compiled_entry_chunk = f"{snipped_heading}.\n{compiled_entry_chunk}" + # Prepend snipped heading + compiled_entry_chunk = f"{snipped_heading}\n{compiled_entry_chunk}" + + # Drop long words instead of having entry truncated to maintain quality of entry processed by models + compiled_entry_chunk = TextToEntries.remove_long_words(compiled_entry_chunk, max_word_length) # Clean entry of unwanted characters like \0 character compiled_entry_chunk = TextToEntries.clean_field(compiled_entry_chunk) @@ -160,7 +189,7 
@@ class TextToEntries(ABC): new_dates = [] with timer("Indexed dates from added entries in", logger): for added_entry in added_entries: - dates_in_entries = zip(self.date_filter.extract_dates(added_entry.raw), repeat(added_entry)) + dates_in_entries = zip(self.date_filter.extract_dates(added_entry.compiled), repeat(added_entry)) dates_to_create = [ EntryDates(date=date, entry=added_entry) for date, added_entry in dates_in_entries @@ -244,11 +273,6 @@ class TextToEntries(ABC): return entries_with_ids - @staticmethod - def convert_text_maps_to_jsonl(entries: List[Entry]) -> str: - # Convert each entry to JSON and write to JSONL file - return "".join([f"{entry.to_json()}\n" for entry in entries]) - @staticmethod def clean_field(field: str) -> str: return field.replace("\0", "") if not is_none_or_empty(field) else "" diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 57f26163..c528356e 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -489,7 +489,7 @@ async def chat( common: CommonQueryParams, q: str, n: Optional[int] = 5, - d: Optional[float] = 0.18, + d: Optional[float] = 0.22, stream: Optional[bool] = False, title: Optional[str] = None, conversation_id: Optional[int] = None, diff --git a/tests/test_client.py b/tests/test_client.py index d1f07a4b..fb73ca3c 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -306,7 +306,7 @@ def test_notes_search(client, search_config: SearchConfig, sample_org_data, defa user_query = quote("How to git install application?") # Act - response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true&max_distance=0.18", headers=headers) + response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true&max_distance=0.22", headers=headers) # Assert assert response.status_code == 200 @@ -325,7 +325,7 @@ def test_notes_search_no_results(client, search_config: SearchConfig, sample_org user_query = quote("How to find my goat?") # Act - response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true&max_distance=0.18", headers=headers) + response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true&max_distance=0.22", headers=headers) # Assert assert response.status_code == 200 @@ -409,7 +409,7 @@ def test_notes_search_requires_parent_context( user_query = quote("Install Khoj on Emacs") # Act - response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true&max_distance=0.18", headers=headers) + response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true&max_distance=0.22", headers=headers) # Assert assert response.status_code == 200 diff --git a/tests/test_markdown_to_entries.py b/tests/test_markdown_to_entries.py index 12ea238e..d63f026a 100644 --- a/tests/test_markdown_to_entries.py +++ b/tests/test_markdown_to_entries.py @@ -1,4 +1,3 @@ -import json import os from pathlib import Path @@ -7,8 +6,8 @@ from khoj.utils.fs_syncer import get_markdown_files from khoj.utils.rawconfig import TextContentConfig -def test_markdown_file_with_no_headings_to_jsonl(tmp_path): - "Convert files with no heading to jsonl." +def test_extract_markdown_with_no_headings(tmp_path): + "Convert markdown file with no heading to entry format." 
# Arrange entry = f""" - Bullet point 1 @@ -17,30 +16,24 @@ def test_markdown_file_with_no_headings_to_jsonl(tmp_path): data = { f"{tmp_path}": entry, } - expected_heading = f"# {tmp_path.stem}" + expected_heading = f"# {tmp_path}" # Act # Extract Entries from specified Markdown files - entry_nodes, file_to_entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data) - - # Process Each Entry from All Notes Files - jsonl_string = MarkdownToEntries.convert_markdown_maps_to_jsonl( - MarkdownToEntries.convert_markdown_entries_to_maps(entry_nodes, file_to_entries) - ) - jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()] + entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3) # Assert - assert len(jsonl_data) == 1 + assert len(entries) == 1 # Ensure raw entry with no headings do not get heading prefix prepended - assert not jsonl_data[0]["raw"].startswith("#") + assert not entries[0].raw.startswith("#") # Ensure compiled entry has filename prepended as top level heading - assert expected_heading in jsonl_data[0]["compiled"] + assert entries[0].compiled.startswith(expected_heading) # Ensure compiled entry also includes the file name - assert str(tmp_path) in jsonl_data[0]["compiled"] + assert str(tmp_path) in entries[0].compiled -def test_single_markdown_entry_to_jsonl(tmp_path): - "Convert markdown entry from single file to jsonl." +def test_extract_single_markdown_entry(tmp_path): + "Convert markdown from single file to entry format." # Arrange entry = f"""### Heading \t\r @@ -52,20 +45,14 @@ def test_single_markdown_entry_to_jsonl(tmp_path): # Act # Extract Entries from specified Markdown files - entries, entry_to_file_map = MarkdownToEntries.extract_markdown_entries(markdown_files=data) - - # Process Each Entry from All Notes Files - jsonl_string = MarkdownToEntries.convert_markdown_maps_to_jsonl( - MarkdownToEntries.convert_markdown_entries_to_maps(entries, entry_to_file_map) - ) - jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()] + entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3) # Assert - assert len(jsonl_data) == 1 + assert len(entries) == 1 -def test_multiple_markdown_entries_to_jsonl(tmp_path): - "Convert multiple markdown entries from single file to jsonl." +def test_extract_multiple_markdown_entries(tmp_path): + "Convert multiple markdown from single file to entry format." # Arrange entry = f""" ### Heading 1 @@ -81,19 +68,139 @@ def test_multiple_markdown_entries_to_jsonl(tmp_path): # Act # Extract Entries from specified Markdown files - entry_strings, entry_to_file_map = MarkdownToEntries.extract_markdown_entries(markdown_files=data) - entries = MarkdownToEntries.convert_markdown_entries_to_maps(entry_strings, entry_to_file_map) - - # Process Each Entry from All Notes Files - jsonl_string = MarkdownToEntries.convert_markdown_maps_to_jsonl(entries) - jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()] + entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3) # Assert - assert len(jsonl_data) == 2 + assert len(entries) == 2 # Ensure entry compiled strings include the markdown files they originate from assert all([tmp_path.stem in entry.compiled for entry in entries]) +def test_extract_entries_with_different_level_headings(tmp_path): + "Extract markdown entries with different level headings." 
+ # Arrange + entry = f""" +# Heading 1 +## Sub-Heading 1.1 +# Heading 2 +""" + data = { + f"{tmp_path}": entry, + } + + # Act + # Extract Entries from specified Markdown files + entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3) + + # Assert + assert len(entries) == 2 + assert entries[0].raw == "# Heading 1\n## Sub-Heading 1.1", "Ensure entry includes heading ancestory" + assert entries[1].raw == "# Heading 2\n" + + +def test_extract_entries_with_non_incremental_heading_levels(tmp_path): + "Extract markdown entries when deeper child level before shallower child level." + # Arrange + entry = f""" +# Heading 1 +#### Sub-Heading 1.1 +## Sub-Heading 1.2 +# Heading 2 +""" + data = { + f"{tmp_path}": entry, + } + + # Act + # Extract Entries from specified Markdown files + entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3) + + # Assert + assert len(entries) == 3 + assert entries[0].raw == "# Heading 1\n#### Sub-Heading 1.1", "Ensure entry includes heading ancestory" + assert entries[1].raw == "# Heading 1\n## Sub-Heading 1.2", "Ensure entry includes heading ancestory" + assert entries[2].raw == "# Heading 2\n" + + +def test_extract_entries_with_text_before_headings(tmp_path): + "Extract markdown entries with some text before any headings." + # Arrange + entry = f""" +Text before headings +# Heading 1 +body line 1 +## Heading 2 +body line 2 +""" + data = { + f"{tmp_path}": entry, + } + + # Act + # Extract Entries from specified Markdown files + entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3) + + # Assert + assert len(entries) == 3 + assert entries[0].raw == "\nText before headings" + assert entries[1].raw == "# Heading 1\nbody line 1" + assert entries[2].raw == "# Heading 1\n## Heading 2\nbody line 2\n", "Ensure raw entry includes heading ancestory" + + +def test_parse_markdown_file_into_single_entry_if_small(tmp_path): + "Parse markdown file into single entry if it fits within the token limits." + # Arrange + entry = f""" +# Heading 1 +body line 1 +## Subheading 1.1 +body line 1.1 +""" + data = { + f"{tmp_path}": entry, + } + + # Act + # Extract Entries from specified Markdown files + entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=12) + + # Assert + assert len(entries) == 1 + assert entries[0].raw == entry + + +def test_parse_markdown_entry_with_children_as_single_entry_if_small(tmp_path): + "Parse markdown entry with child headings as single entry if it fits within the tokens limits." 
+ # Arrange + entry = f""" +# Heading 1 +body line 1 +## Subheading 1.1 +body line 1.1 +# Heading 2 +body line 2 +## Subheading 2.1 +longer body line 2.1 +""" + data = { + f"{tmp_path}": entry, + } + + # Act + # Extract Entries from specified Markdown files + entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=12) + + # Assert + assert len(entries) == 3 + assert ( + entries[0].raw == "# Heading 1\nbody line 1\n## Subheading 1.1\nbody line 1.1" + ), "First entry includes children headings" + assert entries[1].raw == "# Heading 2\nbody line 2", "Second entry does not include children headings" + assert ( + entries[2].raw == "# Heading 2\n## Subheading 2.1\nlonger body line 2.1\n" + ), "Third entry is second entries child heading" + + def test_get_markdown_files(tmp_path): "Ensure Markdown files specified via input-filter, input-files extracted" # Arrange @@ -131,27 +238,6 @@ def test_get_markdown_files(tmp_path): assert set(extracted_org_files.keys()) == expected_files -def test_extract_entries_with_different_level_headings(tmp_path): - "Extract markdown entries with different level headings." - # Arrange - entry = f""" -# Heading 1 -## Heading 2 -""" - data = { - f"{tmp_path}": entry, - } - - # Act - # Extract Entries from specified Markdown files - entries, _ = MarkdownToEntries.extract_markdown_entries(markdown_files=data) - - # Assert - assert len(entries) == 2 - assert entries[0] == "# Heading 1" - assert entries[1] == "## Heading 2" - - # Helper Functions def create_file(tmp_path: Path, entry=None, filename="test.md"): markdown_file = tmp_path / filename diff --git a/tests/test_multiple_users.py b/tests/test_multiple_users.py index bb0f99d8..4e8e456a 100644 --- a/tests/test_multiple_users.py +++ b/tests/test_multiple_users.py @@ -56,7 +56,7 @@ def test_index_update_with_user2_inaccessible_user1(client, api_user2: KhojApiUs # Assert assert update_response.status_code == 200 - assert len(results) == 5 + assert len(results) == 3 for result in results: assert result["additional"]["file"] not in source_file_symbol diff --git a/tests/test_openai_chat_actors.py b/tests/test_openai_chat_actors.py index 5c2855b2..df9d8f07 100644 --- a/tests/test_openai_chat_actors.py +++ b/tests/test_openai_chat_actors.py @@ -470,10 +470,6 @@ async def test_websearch_with_operators(chat_client): ["site:reddit.com" in response for response in responses] ), "Expected a search query to include site:reddit.com but got: " + str(responses) - assert any( - ["after:2024/04/01" in response for response in responses] - ), "Expected a search query to include after:2024/04/01 but got: " + str(responses) - # ---------------------------------------------------------------------------------------------------- @pytest.mark.anyio diff --git a/tests/test_org_to_entries.py b/tests/test_org_to_entries.py index f7c19b79..f01f50f3 100644 --- a/tests/test_org_to_entries.py +++ b/tests/test_org_to_entries.py @@ -1,5 +1,5 @@ -import json import os +import re from khoj.processor.content.org_mode.org_to_entries import OrgToEntries from khoj.processor.content.text_to_entries import TextToEntries @@ -8,7 +8,7 @@ from khoj.utils.helpers import is_none_or_empty from khoj.utils.rawconfig import Entry, TextContentConfig -def test_configure_heading_entry_to_jsonl(tmp_path): +def test_configure_indexing_heading_only_entries(tmp_path): """Ensure entries with empty body are ignored, unless explicitly configured to index heading entries. Property drawers not considered Body. 
Ignore control characters for evaluating if Body empty.""" # Arrange @@ -26,24 +26,21 @@ def test_configure_heading_entry_to_jsonl(tmp_path): for index_heading_entries in [True, False]: # Act # Extract entries into jsonl from specified Org files - jsonl_string = OrgToEntries.convert_org_entries_to_jsonl( - OrgToEntries.convert_org_nodes_to_entries( - *OrgToEntries.extract_org_entries(org_files=data), index_heading_entries=index_heading_entries - ) + entries = OrgToEntries.extract_org_entries( + org_files=data, index_heading_entries=index_heading_entries, max_tokens=3 ) - jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()] # Assert if index_heading_entries: # Entry with empty body indexed when index_heading_entries set to True - assert len(jsonl_data) == 1 + assert len(entries) == 1 else: # Entry with empty body ignored when index_heading_entries set to False - assert is_none_or_empty(jsonl_data) + assert is_none_or_empty(entries) -def test_entry_split_when_exceeds_max_words(): - "Ensure entries with compiled words exceeding max_words are split." +def test_entry_split_when_exceeds_max_tokens(): + "Ensure entries with compiled words exceeding max_tokens are split." # Arrange tmp_path = "/tmp/test.org" entry = f"""*** Heading @@ -57,29 +54,26 @@ def test_entry_split_when_exceeds_max_words(): # Act # Extract Entries from specified Org files - entries, entry_to_file_map = OrgToEntries.extract_org_entries(org_files=data) + entries = OrgToEntries.extract_org_entries(org_files=data) - # Split each entry from specified Org files by max words - jsonl_string = OrgToEntries.convert_org_entries_to_jsonl( - TextToEntries.split_entries_by_max_tokens( - OrgToEntries.convert_org_nodes_to_entries(entries, entry_to_file_map), max_tokens=4 - ) - ) - jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()] + # Split each entry from specified Org files by max tokens + entries = TextToEntries.split_entries_by_max_tokens(entries, max_tokens=6) # Assert - assert len(jsonl_data) == 2 - # Ensure compiled entries split by max_words start with entry heading (for search context) - assert all([entry["compiled"].startswith(expected_heading) for entry in jsonl_data]) + assert len(entries) == 2 + # Ensure compiled entries split by max tokens start with entry heading (for search context) + assert all([entry.compiled.startswith(expected_heading) for entry in entries]) def test_entry_split_drops_large_words(): "Ensure entries drops words larger than specified max word length from compiled version." # Arrange - entry_text = f"""*** Heading - \t\r - Body Line 1 - """ + entry_text = f"""First Line +dog=1\n\r\t +cat=10 +car=4 +book=2 +""" entry = Entry(raw=entry_text, compiled=entry_text) # Act @@ -87,11 +81,158 @@ def test_entry_split_drops_large_words(): processed_entry = TextToEntries.split_entries_by_max_tokens([entry], max_word_length=5)[0] # Assert - # "Heading" dropped from compiled version because its over the set max word limit - assert len(processed_entry.compiled.split()) == len(entry_text.split()) - 1 + # Ensure words larger than max word length are dropped + # Ensure newline characters are considered as word boundaries for splitting words. 
See #620 + words_to_keep = ["First", "Line", "dog=1", "car=4"] + words_to_drop = ["cat=10", "book=2"] + assert all([word for word in words_to_keep if word in processed_entry.compiled]) + assert not any([word for word in words_to_drop if word in processed_entry.compiled]) + assert len(processed_entry.compiled.split()) == len(entry_text.split()) - 2 -def test_entry_with_body_to_jsonl(tmp_path): +def test_parse_org_file_into_single_entry_if_small(tmp_path): + "Parse org file into single entry if it fits within the token limits." + # Arrange + original_entry = f""" +* Heading 1 +body line 1 +** Subheading 1.1 +body line 1.1 +""" + data = { + f"{tmp_path}": original_entry, + } + expected_entry = f""" +* Heading 1 +body line 1 + +** Subheading 1.1 +body line 1.1 + +""".lstrip() + + # Act + # Extract Entries from specified Org files + extracted_entries = OrgToEntries.extract_org_entries(org_files=data, max_tokens=12) + for entry in extracted_entries: + entry.raw = clean(entry.raw) + + # Assert + assert len(extracted_entries) == 1 + assert entry.raw == expected_entry + + +def test_parse_org_entry_with_children_as_single_entry_if_small(tmp_path): + "Parse org entry with child headings as single entry only if it fits within the tokens limits." + # Arrange + entry = f""" +* Heading 1 +body line 1 +** Subheading 1.1 +body line 1.1 +* Heading 2 +body line 2 +** Subheading 2.1 +longer body line 2.1 +""" + data = { + f"{tmp_path}": entry, + } + first_expected_entry = f""" +* Path: {tmp_path} +** Heading 1. + body line 1 + +*** Subheading 1.1. + body line 1.1 + +""".lstrip() + second_expected_entry = f""" +* Path: {tmp_path} +** Heading 2. + body line 2 + +""".lstrip() + third_expected_entry = f""" +* Path: {tmp_path} / Heading 2 +** Subheading 2.1. + longer body line 2.1 + +""".lstrip() + + # Act + # Extract Entries from specified Org files + extracted_entries = OrgToEntries.extract_org_entries(org_files=data, max_tokens=12) + + # Assert + assert len(extracted_entries) == 3 + assert extracted_entries[0].compiled == first_expected_entry, "First entry includes children headings" + assert extracted_entries[1].compiled == second_expected_entry, "Second entry does not include children headings" + assert extracted_entries[2].compiled == third_expected_entry, "Third entry is second entries child heading" + + +def test_separate_sibling_org_entries_if_all_cannot_fit_in_token_limit(tmp_path): + "Parse org sibling entries as separate entries only if it fits within the tokens limits." + # Arrange + entry = f""" +* Heading 1 +body line 1 +** Subheading 1.1 +body line 1.1 +* Heading 2 +body line 2 +** Subheading 2.1 +body line 2.1 +* Heading 3 +body line 3 +** Subheading 3.1 +body line 3.1 +""" + data = { + f"{tmp_path}": entry, + } + first_expected_entry = f""" +* Path: {tmp_path} +** Heading 1. + body line 1 + +*** Subheading 1.1. + body line 1.1 + +""".lstrip() + second_expected_entry = f""" +* Path: {tmp_path} +** Heading 2. + body line 2 + +*** Subheading 2.1. + body line 2.1 + +""".lstrip() + third_expected_entry = f""" +* Path: {tmp_path} +** Heading 3. + body line 3 + +*** Subheading 3.1. 
+ body line 3.1 + +""".lstrip() + + # Act + # Extract Entries from specified Org files + # Max tokens = 30 is in the middle of 2 entry (24 tokens) and 3 entry (36 tokens) tokens boundary + # Where each sibling entry contains 12 tokens per sibling entry * 3 entries = 36 tokens + extracted_entries = OrgToEntries.extract_org_entries(org_files=data, max_tokens=30) + + # Assert + assert len(extracted_entries) == 3 + assert extracted_entries[0].compiled == first_expected_entry, "First entry includes children headings" + assert extracted_entries[1].compiled == second_expected_entry, "Second entry includes children headings" + assert extracted_entries[2].compiled == third_expected_entry, "Third entry includes children headings" + + +def test_entry_with_body_to_entry(tmp_path): "Ensure entries with valid body text are loaded." # Arrange entry = f"""*** Heading @@ -107,19 +248,13 @@ def test_entry_with_body_to_jsonl(tmp_path): # Act # Extract Entries from specified Org files - entries, entry_to_file_map = OrgToEntries.extract_org_entries(org_files=data) - - # Process Each Entry from All Notes Files - jsonl_string = OrgToEntries.convert_org_entries_to_jsonl( - OrgToEntries.convert_org_nodes_to_entries(entries, entry_to_file_map) - ) - jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()] + entries = OrgToEntries.extract_org_entries(org_files=data, max_tokens=3) # Assert - assert len(jsonl_data) == 1 + assert len(entries) == 1 -def test_file_with_entry_after_intro_text_to_jsonl(tmp_path): +def test_file_with_entry_after_intro_text_to_entry(tmp_path): "Ensure intro text before any headings is indexed." # Arrange entry = f""" @@ -134,18 +269,13 @@ Intro text # Act # Extract Entries from specified Org files - entry_nodes, file_to_entries = OrgToEntries.extract_org_entries(org_files=data) - - # Process Each Entry from All Notes Files - entries = OrgToEntries.convert_org_nodes_to_entries(entry_nodes, file_to_entries) - jsonl_string = OrgToEntries.convert_org_entries_to_jsonl(entries) - jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()] + entries = OrgToEntries.extract_org_entries(org_files=data, max_tokens=3) # Assert - assert len(jsonl_data) == 2 + assert len(entries) == 2 -def test_file_with_no_headings_to_jsonl(tmp_path): +def test_file_with_no_headings_to_entry(tmp_path): "Ensure files with no heading, only body text are loaded." 
# Arrange entry = f""" @@ -158,15 +288,10 @@ def test_file_with_no_headings_to_jsonl(tmp_path): # Act # Extract Entries from specified Org files - entry_nodes, file_to_entries = OrgToEntries.extract_org_entries(org_files=data) - - # Process Each Entry from All Notes Files - entries = OrgToEntries.convert_org_nodes_to_entries(entry_nodes, file_to_entries) - jsonl_string = OrgToEntries.convert_org_entries_to_jsonl(entries) - jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()] + entries = OrgToEntries.extract_org_entries(org_files=data, max_tokens=3) # Assert - assert len(jsonl_data) == 1 + assert len(entries) == 1 def test_get_org_files(tmp_path): @@ -214,7 +339,8 @@ def test_extract_entries_with_different_level_headings(tmp_path): # Arrange entry = f""" * Heading 1 -** Heading 2 +** Sub-Heading 1.1 +* Heading 2 """ data = { f"{tmp_path}": entry, @@ -222,12 +348,14 @@ def test_extract_entries_with_different_level_headings(tmp_path): # Act # Extract Entries from specified Org files - entries, _ = OrgToEntries.extract_org_entries(org_files=data) + entries = OrgToEntries.extract_org_entries(org_files=data, index_heading_entries=True, max_tokens=3) + for entry in entries: + entry.raw = clean(f"{entry.raw}") # Assert assert len(entries) == 2 - assert f"{entries[0]}".startswith("* Heading 1") - assert f"{entries[1]}".startswith("** Heading 2") + assert entries[0].raw == "* Heading 1\n** Sub-Heading 1.1\n", "Ensure entry includes heading ancestory" + assert entries[1].raw == "* Heading 2\n" # Helper Functions @@ -237,3 +365,8 @@ def create_file(tmp_path, entry=None, filename="test.org"): if entry: org_file.write_text(entry) return org_file + + +def clean(entry): + "Remove properties from entry for easier comparison." + return re.sub(r"\n:PROPERTIES:(.*?):END:", "", entry, flags=re.DOTALL) diff --git a/tests/test_pdf_to_entries.py b/tests/test_pdf_to_entries.py index 62decdd7..a8c6aa43 100644 --- a/tests/test_pdf_to_entries.py +++ b/tests/test_pdf_to_entries.py @@ -1,4 +1,3 @@ -import json import os from khoj.processor.content.pdf.pdf_to_entries import PdfToEntries @@ -15,16 +14,10 @@ def test_single_page_pdf_to_jsonl(): pdf_bytes = f.read() data = {"tests/data/pdf/singlepage.pdf": pdf_bytes} - entries, entry_to_file_map = PdfToEntries.extract_pdf_entries(pdf_files=data) - - # Process Each Entry from All Pdf Files - jsonl_string = PdfToEntries.convert_pdf_maps_to_jsonl( - PdfToEntries.convert_pdf_entries_to_maps(entries, entry_to_file_map) - ) - jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()] + entries = PdfToEntries.extract_pdf_entries(pdf_files=data) # Assert - assert len(jsonl_data) == 1 + assert len(entries) == 1 def test_multi_page_pdf_to_jsonl(): @@ -35,16 +28,10 @@ def test_multi_page_pdf_to_jsonl(): pdf_bytes = f.read() data = {"tests/data/pdf/multipage.pdf": pdf_bytes} - entries, entry_to_file_map = PdfToEntries.extract_pdf_entries(pdf_files=data) - - # Process Each Entry from All Pdf Files - jsonl_string = PdfToEntries.convert_pdf_maps_to_jsonl( - PdfToEntries.convert_pdf_entries_to_maps(entries, entry_to_file_map) - ) - jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()] + entries = PdfToEntries.extract_pdf_entries(pdf_files=data) # Assert - assert len(jsonl_data) == 6 + assert len(entries) == 6 def test_ocr_page_pdf_to_jsonl(): @@ -55,10 +42,7 @@ def test_ocr_page_pdf_to_jsonl(): pdf_bytes = f.read() data = {"tests/data/pdf/ocr_samples.pdf": pdf_bytes} - entries, entry_to_file_map = 
PdfToEntries.extract_pdf_entries(pdf_files=data) - - # Process Each Entry from All Pdf Files - entries = PdfToEntries.convert_pdf_entries_to_maps(entries, entry_to_file_map) + entries = PdfToEntries.extract_pdf_entries(pdf_files=data) assert len(entries) == 1 assert "playing on a strip of marsh" in entries[0].raw diff --git a/tests/test_plaintext_to_entries.py b/tests/test_plaintext_to_entries.py index 53585177..41d841fc 100644 --- a/tests/test_plaintext_to_entries.py +++ b/tests/test_plaintext_to_entries.py @@ -1,4 +1,3 @@ -import json import os from pathlib import Path @@ -11,10 +10,10 @@ from khoj.utils.rawconfig import TextContentConfig def test_plaintext_file(tmp_path): "Convert files with no heading to jsonl." # Arrange - entry = f""" + raw_entry = f""" Hi, I am a plaintext file and I have some plaintext words. """ - plaintextfile = create_file(tmp_path, entry) + plaintextfile = create_file(tmp_path, raw_entry) filename = plaintextfile.stem @@ -22,25 +21,21 @@ def test_plaintext_file(tmp_path): # Extract Entries from specified plaintext files data = { - f"{plaintextfile}": entry, + f"{plaintextfile}": raw_entry, } - maps = PlaintextToEntries.convert_plaintext_entries_to_maps(entry_to_file_map=data) + entries = PlaintextToEntries.extract_plaintext_entries(entry_to_file_map=data) # Convert each entry.file to absolute path to make them JSON serializable - for map in maps: - map.file = str(Path(map.file).absolute()) - - # Process Each Entry from All Notes Files - jsonl_string = PlaintextToEntries.convert_entries_to_jsonl(maps) - jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()] + for entry in entries: + entry.file = str(Path(entry.file).absolute()) # Assert - assert len(jsonl_data) == 1 + assert len(entries) == 1 # Ensure raw entry with no headings do not get heading prefix prepended - assert not jsonl_data[0]["raw"].startswith("#") + assert not entries[0].raw.startswith("#") # Ensure compiled entry has filename prepended as top level heading - assert jsonl_data[0]["compiled"] == f"{filename}\n{entry}" + assert entries[0].compiled == f"{filename}\n{raw_entry}" def test_get_plaintext_files(tmp_path): @@ -98,11 +93,11 @@ def test_parse_html_plaintext_file(content_config, default_user: KhojUser): extracted_plaintext_files = get_plaintext_files(config=config) # Act - maps = PlaintextToEntries.convert_plaintext_entries_to_maps(extracted_plaintext_files) + entries = PlaintextToEntries.extract_plaintext_entries(extracted_plaintext_files) # Assert - assert len(maps) == 1 - assert "
" not in maps[0].raw + assert len(entries) == 1 + assert "
" not in entries[0].raw # Helper Functions diff --git a/tests/test_text_search.py b/tests/test_text_search.py index 791ce91b..915425bf 100644 --- a/tests/test_text_search.py +++ b/tests/test_text_search.py @@ -57,18 +57,21 @@ def test_get_org_files_with_org_suffixed_dir_doesnt_raise_error(tmp_path, defaul # ---------------------------------------------------------------------------------------------------- @pytest.mark.django_db def test_text_search_setup_with_empty_file_creates_no_entries( - org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog + org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser ): # Arrange + existing_entries = Entry.objects.filter(user=default_user).count() data = get_org_files(org_config_with_only_new_file) # Act # Generate notes embeddings during asymmetric setup - with caplog.at_level(logging.INFO): - text_search.setup(OrgToEntries, data, regenerate=True, user=default_user) + text_search.setup(OrgToEntries, data, regenerate=True, user=default_user) # Assert - assert "Deleted 8 entries. Created 0 new entries for user " in caplog.records[-1].message + updated_entries = Entry.objects.filter(user=default_user).count() + + assert existing_entries == 2 + assert updated_entries == 0 verify_embeddings(0, default_user) @@ -78,6 +81,7 @@ def test_text_indexer_deletes_embedding_before_regenerate( content_config: ContentConfig, default_user: KhojUser, caplog ): # Arrange + existing_entries = Entry.objects.filter(user=default_user).count() org_config = LocalOrgConfig.objects.filter(user=default_user).first() data = get_org_files(org_config) @@ -87,30 +91,18 @@ def test_text_indexer_deletes_embedding_before_regenerate( text_search.setup(OrgToEntries, data, regenerate=True, user=default_user) # Assert + updated_entries = Entry.objects.filter(user=default_user).count() + assert existing_entries == 2 + assert updated_entries == 2 assert "Deleting all entries for file type org" in caplog.text - assert "Deleted 8 entries. Created 13 new entries for user " in caplog.records[-1].message - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.django_db -def test_text_search_setup_batch_processes(content_config: ContentConfig, default_user: KhojUser, caplog): - # Arrange - org_config = LocalOrgConfig.objects.filter(user=default_user).first() - data = get_org_files(org_config) - - # Act - # Generate notes embeddings during asymmetric setup - with caplog.at_level(logging.DEBUG): - text_search.setup(OrgToEntries, data, regenerate=True, user=default_user) - - # Assert - assert "Deleted 8 entries. Created 13 new entries for user " in caplog.records[-1].message + assert "Deleted 2 entries. 
Created 2 new entries for user " in caplog.records[-1].message # ---------------------------------------------------------------------------------------------------- @pytest.mark.django_db def test_text_index_same_if_content_unchanged(content_config: ContentConfig, default_user: KhojUser, caplog): # Arrange + existing_entries = Entry.objects.filter(user=default_user) org_config = LocalOrgConfig.objects.filter(user=default_user).first() data = get_org_files(org_config) @@ -127,6 +119,10 @@ def test_text_index_same_if_content_unchanged(content_config: ContentConfig, def final_logs = caplog.text # Assert + updated_entries = Entry.objects.filter(user=default_user) + for entry in updated_entries: + assert entry in existing_entries + assert len(existing_entries) == len(updated_entries) assert "Deleting all entries for file type org" in initial_logs assert "Deleting all entries for file type org" not in final_logs @@ -192,7 +188,7 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: LocalOrgCon # Assert assert ( - "Deleted 0 entries. Created 2 new entries for user " in caplog.records[-1].message + "Deleted 0 entries. Created 3 new entries for user " in caplog.records[-1].message ), "new entry not split by max tokens" @@ -250,16 +246,15 @@ conda activate khoj # Assert assert ( - "Deleted 0 entries. Created 2 new entries for user " in caplog.records[-1].message + "Deleted 0 entries. Created 3 new entries for user " in caplog.records[-1].message ), "new entry not split by max tokens" # ---------------------------------------------------------------------------------------------------- @pytest.mark.django_db -def test_regenerate_index_with_new_entry( - content_config: ContentConfig, new_org_file: Path, default_user: KhojUser, caplog -): +def test_regenerate_index_with_new_entry(content_config: ContentConfig, new_org_file: Path, default_user: KhojUser): # Arrange + existing_entries = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True)) org_config = LocalOrgConfig.objects.filter(user=default_user).first() initial_data = get_org_files(org_config) @@ -271,28 +266,34 @@ def test_regenerate_index_with_new_entry( final_data = get_org_files(org_config) # Act - with caplog.at_level(logging.INFO): - text_search.setup(OrgToEntries, initial_data, regenerate=True, user=default_user) - initial_logs = caplog.text - caplog.clear() # Clear logs + text_search.setup(OrgToEntries, initial_data, regenerate=True, user=default_user) + updated_entries1 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True)) # regenerate notes jsonl, model embeddings and model to include entry from new file - with caplog.at_level(logging.INFO): - text_search.setup(OrgToEntries, final_data, regenerate=True, user=default_user) - final_logs = caplog.text + text_search.setup(OrgToEntries, final_data, regenerate=True, user=default_user) + updated_entries2 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True)) # Assert - assert "Deleted 8 entries. Created 13 new entries for user " in initial_logs - assert "Deleted 13 entries. 
Created 14 new entries for user " in final_logs - verify_embeddings(14, default_user) + for entry in updated_entries1: + assert entry in updated_entries2 + + assert not any([new_org_file.name in entry for entry in updated_entries1]) + assert not any([new_org_file.name in entry for entry in existing_entries]) + assert any([new_org_file.name in entry for entry in updated_entries2]) + + assert any( + ["Saw a super cute video of a chihuahua doing the Tango on Youtube" in entry for entry in updated_entries2] + ) + verify_embeddings(3, default_user) # ---------------------------------------------------------------------------------------------------- @pytest.mark.django_db def test_update_index_with_duplicate_entries_in_stable_order( - org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog + org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser ): # Arrange + existing_entries = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True)) new_file_to_index = Path(org_config_with_only_new_file.input_files[0]) # Insert org-mode entries with same compiled form into new org file @@ -304,30 +305,33 @@ def test_update_index_with_duplicate_entries_in_stable_order( # Act # generate embeddings, entries, notes model from scratch after adding new org-mode file - with caplog.at_level(logging.INFO): - text_search.setup(OrgToEntries, data, regenerate=True, user=default_user) - initial_logs = caplog.text - caplog.clear() # Clear logs + text_search.setup(OrgToEntries, data, regenerate=True, user=default_user) + updated_entries1 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True)) data = get_org_files(org_config_with_only_new_file) # update embeddings, entries, notes model with no new changes - with caplog.at_level(logging.INFO): - text_search.setup(OrgToEntries, data, regenerate=False, user=default_user) - final_logs = caplog.text + text_search.setup(OrgToEntries, data, regenerate=False, user=default_user) + updated_entries2 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True)) # Assert # verify only 1 entry added even if there are multiple duplicate entries - assert "Deleted 8 entries. Created 1 new entries for user " in initial_logs - assert "Deleted 0 entries. 
Created 0 new entries for user " in final_logs + for entry in existing_entries: + assert entry not in updated_entries1 + for entry in updated_entries1: + assert entry in updated_entries2 + + assert len(existing_entries) == 2 + assert len(updated_entries1) == len(updated_entries2) verify_embeddings(1, default_user) # ---------------------------------------------------------------------------------------------------- @pytest.mark.django_db -def test_update_index_with_deleted_entry(org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog): +def test_update_index_with_deleted_entry(org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser): # Arrange + existing_entries = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True)) new_file_to_index = Path(org_config_with_only_new_file.input_files[0]) # Insert org-mode entries with same compiled form into new org file @@ -344,33 +348,34 @@ def test_update_index_with_deleted_entry(org_config_with_only_new_file: LocalOrg # Act # load embeddings, entries, notes model after adding new org file with 2 entries - with caplog.at_level(logging.INFO): - text_search.setup(OrgToEntries, initial_data, regenerate=True, user=default_user) - initial_logs = caplog.text - caplog.clear() # Clear logs + text_search.setup(OrgToEntries, initial_data, regenerate=True, user=default_user) + updated_entries1 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True)) - with caplog.at_level(logging.INFO): - text_search.setup(OrgToEntries, final_data, regenerate=False, user=default_user) - final_logs = caplog.text + text_search.setup(OrgToEntries, final_data, regenerate=False, user=default_user) + updated_entries2 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True)) # Assert - # verify only 1 entry added even if there are multiple duplicate entries - assert "Deleted 8 entries. Created 2 new entries for user " in initial_logs - assert "Deleted 1 entries. 
Created 0 new entries for user " in final_logs + for entry in existing_entries: + assert entry not in updated_entries1 + + # verify the entry in updated_entries2 is a subset of updated_entries1 + for entry in updated_entries1: + assert entry not in updated_entries2 + + for entry in updated_entries2: + assert entry in updated_entries1[0] verify_embeddings(1, default_user) # ---------------------------------------------------------------------------------------------------- @pytest.mark.django_db -def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file: Path, default_user: KhojUser, caplog): +def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file: Path, default_user: KhojUser): # Arrange + existing_entries = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True)) org_config = LocalOrgConfig.objects.filter(user=default_user).first() data = get_org_files(org_config) - with caplog.at_level(logging.INFO): - text_search.setup(OrgToEntries, data, regenerate=True, user=default_user) - initial_logs = caplog.text - caplog.clear() # Clear logs + text_search.setup(OrgToEntries, data, regenerate=True, user=default_user) # append org-mode entry to first org input file in config with open(new_org_file, "w") as f: @@ -381,14 +386,14 @@ def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file # Act # update embeddings, entries with the newly added note - with caplog.at_level(logging.INFO): - text_search.setup(OrgToEntries, data, regenerate=False, user=default_user) - final_logs = caplog.text + text_search.setup(OrgToEntries, data, regenerate=False, user=default_user) + updated_entries1 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True)) # Assert - assert "Deleted 8 entries. Created 13 new entries for user " in initial_logs - assert "Deleted 0 entries. Created 1 new entries for user " in final_logs - verify_embeddings(14, default_user) + for entry in existing_entries: + assert entry not in updated_entries1 + assert len(updated_entries1) == len(existing_entries) + 1 + verify_embeddings(3, default_user) # ----------------------------------------------------------------------------------------------------