Mirror of https://github.com/khoj-ai/khoj.git (synced 2025-02-17 08:04:21 +00:00)
Update Text Chunking Strategy to Improve Search Context (#645)
## Major
- Parse markdown, org parent entries as single entry if fit within max tokens
- Parse a file as single entry if it fits with max token limits
- Add parent heading ancestry to extracted markdown entries for context
- Chunk text in preference order of para, sentence, word, character

## Minor
- Create wrapper function to get entries from org, md, pdf & text files
- Remove unused Entry to Jsonl converter from text to entry class, tests
- Dedupe code by using single func to process an org file into entries

Resolves #620
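The chunking preference order described above maps onto a RecursiveCharacterTextSplitter configuration in the TextToEntries.split_entries_by_max_tokens hunk further down. A minimal sketch of that configuration, assuming the whitespace tokenizer added in this change; the sample text is made up:

from langchain.text_splitter import RecursiveCharacterTextSplitter

# Prefer splitting on paragraphs, then lines/sentences, then words, then characters
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=256,  # max tokens per chunk
    separators=["\n\n", "\n", "!", "?", ".", " ", "\t", ""],
    keep_separator=True,
    length_function=lambda chunk: len(chunk.split()),  # whitespace word count as token length
    chunk_overlap=0,
)
chunks = text_splitter.split_text("# Notes\n\nParagraph one.\n\nParagraph two.")  # made-up sample text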
This commit is contained in commit 11ce3e2268. 15 changed files with 704 additions and 393 deletions.
@@ -1,14 +1,13 @@
 import logging
 import re
 from pathlib import Path
-from typing import List, Tuple
+from typing import Dict, List, Tuple

 import urllib3

 from khoj.database.models import Entry as DbEntry
 from khoj.database.models import KhojUser
 from khoj.processor.content.text_to_entries import TextToEntries
-from khoj.utils.constants import empty_escape_sequences
 from khoj.utils.helpers import timer
 from khoj.utils.rawconfig import Entry

@@ -31,15 +30,14 @@ class MarkdownToEntries(TextToEntries):
         else:
             deletion_file_names = None

+        max_tokens = 256
         # Extract Entries from specified Markdown files
-        with timer("Parse entries from Markdown files into dictionaries", logger):
-            current_entries = MarkdownToEntries.convert_markdown_entries_to_maps(
-                *MarkdownToEntries.extract_markdown_entries(files)
-            )
+        with timer("Extract entries from specified Markdown files", logger):
+            current_entries = MarkdownToEntries.extract_markdown_entries(files, max_tokens)

         # Split entries by max tokens supported by model
         with timer("Split entries by max token size supported by model", logger):
-            current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
+            current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens)

         # Identify, mark and merge any new entries with previous entries
         with timer("Identify new or updated entries", logger):
@@ -57,48 +55,84 @@ class MarkdownToEntries(TextToEntries):
         return num_new_embeddings, num_deleted_embeddings

     @staticmethod
-    def extract_markdown_entries(markdown_files):
+    def extract_markdown_entries(markdown_files, max_tokens=256) -> List[Entry]:
         "Extract entries by heading from specified Markdown files"
-        # Regex to extract Markdown Entries by Heading
-
-        entries = []
-        entry_to_file_map = []
+        entries: List[str] = []
+        entry_to_file_map: List[Tuple[str, str]] = []
         for markdown_file in markdown_files:
             try:
                 markdown_content = markdown_files[markdown_file]
                 entries, entry_to_file_map = MarkdownToEntries.process_single_markdown_file(
-                    markdown_content, markdown_file, entries, entry_to_file_map
+                    markdown_content, markdown_file, entries, entry_to_file_map, max_tokens
                 )
             except Exception as e:
-                logger.warning(f"Unable to process file: {markdown_file}. This file will not be indexed.")
-                logger.warning(e, exc_info=True)
+                logger.error(
+                    f"Unable to process file: {markdown_file}. This file will not be indexed.\n{e}", exc_info=True
+                )

-        return entries, dict(entry_to_file_map)
+        return MarkdownToEntries.convert_markdown_entries_to_maps(entries, dict(entry_to_file_map))

     @staticmethod
     def process_single_markdown_file(
-        markdown_content: str, markdown_file: Path, entries: List, entry_to_file_map: List
-    ):
-        markdown_heading_regex = r"^#"
+        markdown_content: str,
+        markdown_file: str,
+        entries: List[str],
+        entry_to_file_map: List[Tuple[str, str]],
+        max_tokens=256,
+        ancestry: Dict[int, str] = {},
+    ) -> Tuple[List[str], List[Tuple[str, str]]]:
+        # Prepend the markdown section's heading ancestry
+        ancestry_string = "\n".join([f"{'#' * key} {ancestry[key]}" for key in sorted(ancestry.keys())])
+        markdown_content_with_ancestry = f"{ancestry_string}{markdown_content}"

-        markdown_entries_per_file = []
-        any_headings = re.search(markdown_heading_regex, markdown_content, flags=re.MULTILINE)
-        for entry in re.split(markdown_heading_regex, markdown_content, flags=re.MULTILINE):
-            # Add heading level as the regex split removed it from entries with headings
-            prefix = "#" if entry.startswith("#") else "# " if any_headings else ""
-            stripped_entry = entry.strip(empty_escape_sequences)
-            if stripped_entry != "":
-                markdown_entries_per_file.append(f"{prefix}{stripped_entry}")
+        # If content is small or content has no children headings, save it as a single entry
+        if len(TextToEntries.tokenizer(markdown_content_with_ancestry)) <= max_tokens or not re.search(
+            rf"^#{{{len(ancestry)+1},}}\s", markdown_content, flags=re.MULTILINE
+        ):
+            entry_to_file_map += [(markdown_content_with_ancestry, markdown_file)]
+            entries.extend([markdown_content_with_ancestry])
+            return entries, entry_to_file_map

-        entry_to_file_map += zip(markdown_entries_per_file, [markdown_file] * len(markdown_entries_per_file))
-        entries.extend(markdown_entries_per_file)
+        # Split by next heading level present in the entry
+        next_heading_level = len(ancestry)
+        sections: List[str] = []
+        while len(sections) < 2:
+            next_heading_level += 1
+            sections = re.split(rf"(\n|^)(?=[#]{{{next_heading_level}}} .+\n?)", markdown_content, flags=re.MULTILINE)
+
+        for section in sections:
+            # Skip empty sections
+            if section.strip() == "":
+                continue
+
+            # Extract the section body and (when present) the heading
+            current_ancestry = ancestry.copy()
+            first_line = [line for line in section.split("\n") if line.strip() != ""][0]
+            if re.search(rf"^#{{{next_heading_level}}} ", first_line):
+                # Extract the section body without the heading
+                current_section_body = "\n".join(section.split(first_line)[1:])
+                # Parse the section heading into current section ancestry
+                current_section_title = first_line[next_heading_level:].strip()
+                current_ancestry[next_heading_level] = current_section_title
+            else:
+                current_section_body = section
+
+            # Recurse down children of the current entry
+            MarkdownToEntries.process_single_markdown_file(
+                current_section_body,
+                markdown_file,
+                entries,
+                entry_to_file_map,
+                max_tokens,
+                current_ancestry,
+            )
+
         return entries, entry_to_file_map

     @staticmethod
     def convert_markdown_entries_to_maps(parsed_entries: List[str], entry_to_file_map) -> List[Entry]:
         "Convert each Markdown entries into a dictionary"
-        entries = []
+        entries: List[Entry] = []
         for parsed_entry in parsed_entries:
             raw_filename = entry_to_file_map[parsed_entry]

@@ -108,13 +142,12 @@ class MarkdownToEntries(TextToEntries):
                 entry_filename = urllib3.util.parse_url(raw_filename).url
             else:
                 entry_filename = str(Path(raw_filename))
-            stem = Path(raw_filename).stem

             heading = parsed_entry.splitlines()[0] if re.search("^#+\s", parsed_entry) else ""
             # Append base filename to compiled entry for context to model
             # Increment heading level for heading entries and make filename as its top level heading
-            prefix = f"# {stem}\n#" if heading else f"# {stem}\n"
-            compiled_entry = f"{entry_filename}\n{prefix}{parsed_entry}"
+            prefix = f"# {entry_filename}\n#" if heading else f"# {entry_filename}\n"
+            compiled_entry = f"{prefix}{parsed_entry}"
             entries.append(
                 Entry(
                     compiled=compiled_entry,
@@ -127,8 +160,3 @@ class MarkdownToEntries(TextToEntries):
         logger.debug(f"Converted {len(parsed_entries)} markdown entries to dictionaries")

         return entries
-
-    @staticmethod
-    def convert_markdown_maps_to_jsonl(entries: List[Entry]):
-        "Convert each Markdown entry to JSON and collate as JSONL"
-        return "".join([f"{entry.to_json()}\n" for entry in entries])
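For context on how the reworked markdown extraction above is meant to be driven, a small usage sketch along the lines of the updated tests further below; the file path and contents are made up, and the module path is assumed by analogy with the org importer:

from khoj.processor.content.markdown.markdown_to_entries import MarkdownToEntries

# Map of a (made-up) markdown file path to its content
data = {"/tmp/notes.md": "# Heading 1\nbody line 1\n## Subheading 1.1\nbody line 1.1\n"}

# A small max_tokens forces a split per heading, with heading ancestry prepended to each entry;
# a large enough max_tokens keeps the whole file as a single entry.
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3)
for entry in entries:
    print(entry.heading, "->", len(entry.compiled.split()), "tokens")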
@@ -1,10 +1,12 @@
 import logging
+import re
 from pathlib import Path
-from typing import Iterable, List, Tuple
+from typing import Dict, List, Tuple

 from khoj.database.models import Entry as DbEntry
 from khoj.database.models import KhojUser
 from khoj.processor.content.org_mode import orgnode
+from khoj.processor.content.org_mode.orgnode import Orgnode
 from khoj.processor.content.text_to_entries import TextToEntries
 from khoj.utils import state
 from khoj.utils.helpers import timer
@@ -21,9 +23,6 @@ class OrgToEntries(TextToEntries):
     def process(
         self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
     ) -> Tuple[int, int]:
-        # Extract required fields from config
-        index_heading_entries = False
-
         if not full_corpus:
             deletion_file_names = set([file for file in files if files[file] == ""])
             files_to_process = set(files) - deletion_file_names
@@ -32,14 +31,12 @@ class OrgToEntries(TextToEntries):
             deletion_file_names = None

         # Extract Entries from specified Org files
-        with timer("Parse entries from org files into OrgNode objects", logger):
-            entry_nodes, file_to_entries = self.extract_org_entries(files)
-
-        with timer("Convert OrgNodes into list of entries", logger):
-            current_entries = self.convert_org_nodes_to_entries(entry_nodes, file_to_entries, index_heading_entries)
+        max_tokens = 256
+        with timer("Extract entries from specified Org files", logger):
+            current_entries = self.extract_org_entries(files, max_tokens=max_tokens)

         with timer("Split entries by max token size supported by model", logger):
-            current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
+            current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=max_tokens)

         # Identify, mark and merge any new entries with previous entries
         with timer("Identify new or updated entries", logger):
@@ -57,93 +54,165 @@ class OrgToEntries(TextToEntries):
         return num_new_embeddings, num_deleted_embeddings

     @staticmethod
-    def extract_org_entries(org_files: dict[str, str]):
+    def extract_org_entries(
+        org_files: dict[str, str], index_heading_entries: bool = False, max_tokens=256
+    ) -> List[Entry]:
         "Extract entries from specified Org files"
-        entries = []
-        entry_to_file_map: List[Tuple[orgnode.Orgnode, str]] = []
+        entries, entry_to_file_map = OrgToEntries.extract_org_nodes(org_files, max_tokens)
+        return OrgToEntries.convert_org_nodes_to_entries(entries, entry_to_file_map, index_heading_entries)
+
+    @staticmethod
+    def extract_org_nodes(org_files: dict[str, str], max_tokens) -> Tuple[List[List[Orgnode]], Dict[Orgnode, str]]:
+        "Extract org nodes from specified org files"
+        entries: List[List[Orgnode]] = []
+        entry_to_file_map: List[Tuple[Orgnode, str]] = []
         for org_file in org_files:
-            filename = org_file
-            file = org_files[org_file]
             try:
-                org_file_entries = orgnode.makelist(file, filename)
-                entry_to_file_map += zip(org_file_entries, [org_file] * len(org_file_entries))
-                entries.extend(org_file_entries)
+                org_content = org_files[org_file]
+                entries, entry_to_file_map = OrgToEntries.process_single_org_file(
+                    org_content, org_file, entries, entry_to_file_map, max_tokens
+                )
             except Exception as e:
-                logger.warning(f"Unable to process file: {org_file}. This file will not be indexed.")
-                logger.warning(e, exc_info=True)
+                logger.error(f"Unable to process file: {org_file}. Skipped indexing it.\nError; {e}", exc_info=True)

         return entries, dict(entry_to_file_map)

     @staticmethod
-    def process_single_org_file(org_content: str, org_file: str, entries: List, entry_to_file_map: List):
-        # Process single org file. The org parser assumes that the file is a single org file and reads it from a buffer. We'll split the raw conetnt of this file by new line to mimic the same behavior.
-        try:
-            org_file_entries = orgnode.makelist(org_content, org_file)
-            entry_to_file_map += zip(org_file_entries, [org_file] * len(org_file_entries))
-            entries.extend(org_file_entries)
-            return entries, entry_to_file_map
-        except Exception as e:
-            logger.error(f"Error processing file: {org_file} with error: {e}", exc_info=True)
+    def process_single_org_file(
+        org_content: str,
+        org_file: str,
+        entries: List[List[Orgnode]],
+        entry_to_file_map: List[Tuple[Orgnode, str]],
+        max_tokens=256,
+        ancestry: Dict[int, str] = {},
+    ) -> Tuple[List[List[Orgnode]], List[Tuple[Orgnode, str]]]:
+        """Parse org_content from org_file into OrgNode entries
+
+        Recurse down org file entries, one heading level at a time,
+        until reach a leaf entry or the current entry tree fits max_tokens.
+
+        Parse recursion terminating entry (trees) into (a list of) OrgNode objects.
+        """
+        # Prepend the org section's heading ancestry
+        ancestry_string = "\n".join([f"{'*' * key} {ancestry[key]}" for key in sorted(ancestry.keys())])
+        org_content_with_ancestry = f"{ancestry_string}{org_content}"
+
+        # If content is small or content has no children headings, save it as a single entry
+        # Note: This is the terminating condition for this recursive function
+        if len(TextToEntries.tokenizer(org_content_with_ancestry)) <= max_tokens or not re.search(
+            rf"^\*{{{len(ancestry)+1},}}\s", org_content, re.MULTILINE
+        ):
+            orgnode_content_with_ancestry = orgnode.makelist(org_content_with_ancestry, org_file)
+            entry_to_file_map += zip(orgnode_content_with_ancestry, [org_file] * len(orgnode_content_with_ancestry))
+            entries.extend([orgnode_content_with_ancestry])
+            return entries, entry_to_file_map
+
+        # Split this entry tree into sections by the next heading level in it
+        # Increment heading level until able to split entry into sections
+        # A successful split will result in at least 2 sections
+        next_heading_level = len(ancestry)
+        sections: List[str] = []
+        while len(sections) < 2:
+            next_heading_level += 1
+            sections = re.split(rf"(\n|^)(?=[*]{{{next_heading_level}}} .+\n?)", org_content, flags=re.MULTILINE)
+
+        # Recurse down each non-empty section after parsing its body, heading and ancestry
+        for section in sections:
+            # Skip empty sections
+            if section.strip() == "":
+                continue
+
+            # Extract the section body and (when present) the heading
+            current_ancestry = ancestry.copy()
+            first_non_empty_line = [line for line in section.split("\n") if line.strip() != ""][0]
+            # If first non-empty line is a heading with expected heading level
+            if re.search(rf"^\*{{{next_heading_level}}}\s", first_non_empty_line):
+                # Extract the section body without the heading
+                current_section_body = "\n".join(section.split(first_non_empty_line)[1:])
+                # Parse the section heading into current section ancestry
+                current_section_title = first_non_empty_line[next_heading_level:].strip()
+                current_ancestry[next_heading_level] = current_section_title
+            # Else process the section as just body text
+            else:
+                current_section_body = section
+
+            # Recurse down children of the current entry
+            OrgToEntries.process_single_org_file(
+                current_section_body,
+                org_file,
+                entries,
+                entry_to_file_map,
+                max_tokens,
+                current_ancestry,
+            )
+
+        return entries, entry_to_file_map

     @staticmethod
     def convert_org_nodes_to_entries(
-        parsed_entries: List[orgnode.Orgnode], entry_to_file_map, index_heading_entries=False
+        parsed_entries: List[List[Orgnode]],
+        entry_to_file_map: Dict[Orgnode, str],
+        index_heading_entries: bool = False,
     ) -> List[Entry]:
-        "Convert Org-Mode nodes into list of Entry objects"
+        """
+        Convert OrgNode lists into list of Entry objects
+
+        Each list of OrgNodes is a parsed parent org tree or leaf node.
+        Convert each list of these OrgNodes into a single Entry.
+        """
         entries: List[Entry] = []
-        for parsed_entry in parsed_entries:
-            if not parsed_entry.hasBody and not index_heading_entries:
-                # Ignore title notes i.e notes with just headings and empty body
-                continue
-
-            todo_str = f"{parsed_entry.todo} " if parsed_entry.todo else ""
-
-            # Prepend ancestor headings, filename as top heading to entry for context
-            ancestors_trail = " / ".join(parsed_entry.ancestors) or Path(entry_to_file_map[parsed_entry])
-            if parsed_entry.heading:
-                heading = f"* Path: {ancestors_trail}\n** {todo_str}{parsed_entry.heading}."
-            else:
-                heading = f"* Path: {ancestors_trail}."
-
-            compiled = heading
-            if state.verbose > 2:
-                logger.debug(f"Title: {heading}")
-
-            if parsed_entry.tags:
-                tags_str = " ".join(parsed_entry.tags)
-                compiled += f"\t {tags_str}."
-                if state.verbose > 2:
-                    logger.debug(f"Tags: {tags_str}")
-
-            if parsed_entry.closed:
-                compiled += f'\n Closed on {parsed_entry.closed.strftime("%Y-%m-%d")}.'
-                if state.verbose > 2:
-                    logger.debug(f'Closed: {parsed_entry.closed.strftime("%Y-%m-%d")}')
-
-            if parsed_entry.scheduled:
-                compiled += f'\n Scheduled for {parsed_entry.scheduled.strftime("%Y-%m-%d")}.'
-                if state.verbose > 2:
-                    logger.debug(f'Scheduled: {parsed_entry.scheduled.strftime("%Y-%m-%d")}')
-
-            if parsed_entry.hasBody:
-                compiled += f"\n {parsed_entry.body}"
-                if state.verbose > 2:
-                    logger.debug(f"Body: {parsed_entry.body}")
-
-            if compiled:
-                entries.append(
-                    Entry(
-                        compiled=compiled,
-                        raw=f"{parsed_entry}",
-                        heading=f"{heading}",
-                        file=f"{entry_to_file_map[parsed_entry]}",
-                    )
-                )
+        for entry_group in parsed_entries:
+            entry_heading, entry_compiled, entry_raw = "", "", ""
+            for parsed_entry in entry_group:
+                if not parsed_entry.hasBody and not index_heading_entries:
+                    # Ignore title notes i.e notes with just headings and empty body
+                    continue
+
+                todo_str = f"{parsed_entry.todo} " if parsed_entry.todo else ""
+
+                # Set base level to current org-node tree's root heading level
+                if not entry_heading and parsed_entry.level > 0:
+                    base_level = parsed_entry.level
+                # Indent entry by 1 heading level as ancestry is prepended as top level heading
+                heading = f"{'*' * (parsed_entry.level-base_level+2)} {todo_str}" if parsed_entry.level > 0 else ""
+                if parsed_entry.heading:
+                    heading += f"{parsed_entry.heading}."
+
+                # Prepend ancestor headings, filename as top heading to root parent entry for context
+                # Children nodes do not need ancestors trail as root parent node will have it
+                if not entry_heading:
+                    ancestors_trail = " / ".join(parsed_entry.ancestors) or Path(entry_to_file_map[parsed_entry])
+                    heading = f"* Path: {ancestors_trail}\n{heading}" if heading else f"* Path: {ancestors_trail}."
+
+                compiled = heading
+
+                if parsed_entry.tags:
+                    tags_str = " ".join(parsed_entry.tags)
+                    compiled += f"\t {tags_str}."
+
+                if parsed_entry.closed:
+                    compiled += f'\n Closed on {parsed_entry.closed.strftime("%Y-%m-%d")}.'
+
+                if parsed_entry.scheduled:
+                    compiled += f'\n Scheduled for {parsed_entry.scheduled.strftime("%Y-%m-%d")}.'
+
+                if parsed_entry.hasBody:
+                    compiled += f"\n {parsed_entry.body}"
+
+                # Add the sub-entry contents to the entry
+                entry_compiled += f"{compiled}"
+                entry_raw += f"{parsed_entry}"
+                if not entry_heading:
+                    entry_heading = heading
+
+            if entry_compiled:
+                entries.append(
+                    Entry(
+                        compiled=entry_compiled,
+                        raw=entry_raw,
+                        heading=f"{entry_heading}",
+                        file=f"{entry_to_file_map[parsed_entry]}",
+                    )
+                )

         return entries
-
-    @staticmethod
-    def convert_org_entries_to_jsonl(entries: Iterable[Entry]) -> str:
-        "Convert each Org-Mode entry to JSON and collate as JSONL"
-        return "".join([f"{entry_dict.to_json()}\n" for entry_dict in entries])
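Analogously, a usage sketch for the reworked org extraction, mirroring the updated org tests near the end of this diff; the org content here is made up:

from khoj.processor.content.org_mode.org_to_entries import OrgToEntries

# Map of a (made-up) org file path to its content
data = {"/tmp/notes.org": "* Heading 1\nbody line 1\n** Subheading 1.1\nbody line 1.1\n"}

# Entry trees that fit within max_tokens are indexed as a single entry;
# larger trees are split recursively, one heading level at a time.
entries = OrgToEntries.extract_org_entries(org_files=data, index_heading_entries=False, max_tokens=3)
print(len(entries))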
@@ -37,7 +37,7 @@ import datetime
 import re
 from os.path import relpath
 from pathlib import Path
-from typing import List
+from typing import Dict, List, Tuple

 indent_regex = re.compile(r"^ *")

@@ -58,7 +58,7 @@ def makelist_with_filepath(filename):
         return makelist(f, filename)


-def makelist(file, filename):
+def makelist(file, filename) -> List["Orgnode"]:
     """
     Read an org-mode file and return a list of Orgnode objects
     created from this file.
@@ -80,16 +80,16 @@ def makelist(file, filename):
     }  # populated from #+SEQ_TODO line
     level = ""
     heading = ""
-    ancestor_headings = []
+    ancestor_headings: List[str] = []
     bodytext = ""
     introtext = ""
-    tags = list()  # set of all tags in headline
-    closed_date = ""
-    sched_date = ""
-    deadline_date = ""
-    logbook = list()
+    tags: List[str] = list()  # set of all tags in headline
+    closed_date: datetime.date = None
+    sched_date: datetime.date = None
+    deadline_date: datetime.date = None
+    logbook: List[Tuple[datetime.datetime, datetime.datetime]] = list()
     nodelist: List[Orgnode] = list()
-    property_map = dict()
+    property_map: Dict[str, str] = dict()
     in_properties_drawer = False
     in_logbook_drawer = False
     file_title = f"{filename}"
@@ -102,13 +102,13 @@ def makelist(file, filename):
             thisNode = Orgnode(level, heading, bodytext, tags, ancestor_headings)
             if closed_date:
                 thisNode.closed = closed_date
-                closed_date = ""
+                closed_date = None
             if sched_date:
                 thisNode.scheduled = sched_date
-                sched_date = ""
+                sched_date = None
             if deadline_date:
                 thisNode.deadline = deadline_date
-                deadline_date = ""
+                deadline_date = None
             if logbook:
                 thisNode.logbook = logbook
                 logbook = list()
@@ -116,7 +116,7 @@ def makelist(file, filename):
             nodelist.append(thisNode)
             property_map = {"LINE": f"file:{normalize_filename(filename)}::{ctr}"}
             previous_level = level
-            previous_heading = heading
+            previous_heading: str = heading
             level = heading_search.group(1)
             heading = heading_search.group(2)
             bodytext = ""
@@ -495,12 +495,13 @@ class Orgnode(object):
         if self._priority:
             n = n + "[#" + self._priority + "] "
         n = n + self._heading
-        n = "%-60s " % n  # hack - tags will start in column 62
-        closecolon = ""
-        for t in self._tags:
-            n = n + ":" + t
-            closecolon = ":"
-        n = n + closecolon
+        if self._tags:
+            n = "%-60s " % n  # hack - tags will start in column 62
+            closecolon = ""
+            for t in self._tags:
+                n = n + ":" + t
+                closecolon = ":"
+            n = n + closecolon
         n = n + "\n"

         # Get body indentation from first line of body
@@ -32,8 +32,8 @@ class PdfToEntries(TextToEntries):
             deletion_file_names = None

         # Extract Entries from specified Pdf files
-        with timer("Parse entries from PDF files into dictionaries", logger):
-            current_entries = PdfToEntries.convert_pdf_entries_to_maps(*PdfToEntries.extract_pdf_entries(files))
+        with timer("Extract entries from specified PDF files", logger):
+            current_entries = PdfToEntries.extract_pdf_entries(files)

         # Split entries by max tokens supported by model
         with timer("Split entries by max token size supported by model", logger):
@@ -55,11 +55,11 @@ class PdfToEntries(TextToEntries):
         return num_new_embeddings, num_deleted_embeddings

     @staticmethod
-    def extract_pdf_entries(pdf_files):
+    def extract_pdf_entries(pdf_files) -> List[Entry]:
         """Extract entries by page from specified PDF files"""

-        entries = []
-        entry_to_location_map = []
+        entries: List[str] = []
+        entry_to_location_map: List[Tuple[str, str]] = []
         for pdf_file in pdf_files:
             try:
                 # Write the PDF file to a temporary file, as it is stored in byte format in the pdf_file object and the PDF Loader expects a file path
@@ -83,7 +83,7 @@ class PdfToEntries(TextToEntries):
                 if os.path.exists(f"{tmp_file}"):
                     os.remove(f"{tmp_file}")

-        return entries, dict(entry_to_location_map)
+        return PdfToEntries.convert_pdf_entries_to_maps(entries, dict(entry_to_location_map))

     @staticmethod
     def convert_pdf_entries_to_maps(parsed_entries: List[str], entry_to_file_map) -> List[Entry]:
@@ -106,8 +106,3 @@ class PdfToEntries(TextToEntries):
         logger.debug(f"Converted {len(parsed_entries)} PDF entries to dictionaries")

         return entries
-
-    @staticmethod
-    def convert_pdf_maps_to_jsonl(entries: List[Entry]):
-        "Convert each PDF entry to JSON and collate as JSONL"
-        return "".join([f"{entry.to_json()}\n" for entry in entries])
@@ -42,8 +42,8 @@ class PlaintextToEntries(TextToEntries):
                 logger.warning(e, exc_info=True)

         # Extract Entries from specified plaintext files
-        with timer("Parse entries from plaintext files", logger):
-            current_entries = PlaintextToEntries.convert_plaintext_entries_to_maps(files)
+        with timer("Parse entries from specified Plaintext files", logger):
+            current_entries = PlaintextToEntries.extract_plaintext_entries(files)

         # Split entries by max tokens supported by model
         with timer("Split entries by max token size supported by model", logger):
@@ -74,7 +74,7 @@ class PlaintextToEntries(TextToEntries):
         return soup.get_text(strip=True, separator="\n")

     @staticmethod
-    def convert_plaintext_entries_to_maps(entry_to_file_map: dict) -> List[Entry]:
+    def extract_plaintext_entries(entry_to_file_map: dict[str, str]) -> List[Entry]:
         "Convert each plaintext entries into a dictionary"
         entries = []
         for file, entry in entry_to_file_map.items():
@@ -87,8 +87,3 @@ class PlaintextToEntries(TextToEntries):
                 )
             )
         return entries
-
-    @staticmethod
-    def convert_entries_to_jsonl(entries: List[Entry]):
-        "Convert each entry to JSON and collate as JSONL"
-        return "".join([f"{entry.to_json()}\n" for entry in entries])
@@ -1,10 +1,12 @@
 import hashlib
 import logging
+import re
 import uuid
 from abc import ABC, abstractmethod
 from itertools import repeat
 from typing import Any, Callable, List, Set, Tuple

+from langchain.text_splitter import RecursiveCharacterTextSplitter
 from tqdm import tqdm

 from khoj.database.adapters import EntryAdapters, get_user_search_model_or_default
@@ -34,6 +36,27 @@ class TextToEntries(ABC):
     def hash_func(key: str) -> Callable:
         return lambda entry: hashlib.md5(bytes(getattr(entry, key), encoding="utf-8")).hexdigest()

+    @staticmethod
+    def remove_long_words(text: str, max_word_length: int = 500) -> str:
+        "Remove words longer than max_word_length from text."
+        # Split the string by words, keeping the delimiters
+        splits = re.split(r"(\s+)", text) + [""]
+        words_with_delimiters = list(zip(splits[::2], splits[1::2]))
+
+        # Filter out long words while preserving delimiters in text
+        filtered_text = [
+            f"{word}{delimiter}"
+            for word, delimiter in words_with_delimiters
+            if not word.strip() or len(word.strip()) <= max_word_length
+        ]
+
+        return "".join(filtered_text)
+
+    @staticmethod
+    def tokenizer(text: str) -> List[str]:
+        "Tokenize text into words."
+        return text.split()
+
     @staticmethod
     def split_entries_by_max_tokens(
         entries: List[Entry], max_tokens: int = 256, max_word_length: int = 500
@@ -44,24 +67,30 @@ class TextToEntries(ABC):
             if is_none_or_empty(entry.compiled):
                 continue

-            # Split entry into words
-            compiled_entry_words = [word for word in entry.compiled.split(" ") if word != ""]
-
-            # Drop long words instead of having entry truncated to maintain quality of entry processed by models
-            compiled_entry_words = [word for word in compiled_entry_words if len(word) <= max_word_length]
+            # Split entry into chunks of max_tokens
+            # Use chunking preference order: paragraphs > sentences > words > characters
+            text_splitter = RecursiveCharacterTextSplitter(
+                chunk_size=max_tokens,
+                separators=["\n\n", "\n", "!", "?", ".", " ", "\t", ""],
+                keep_separator=True,
+                length_function=lambda chunk: len(TextToEntries.tokenizer(chunk)),
+                chunk_overlap=0,
+            )
+            chunked_entry_chunks = text_splitter.split_text(entry.compiled)
             corpus_id = uuid.uuid4()

-            # Split entry into chunks of max tokens
-            for chunk_index in range(0, len(compiled_entry_words), max_tokens):
-                compiled_entry_words_chunk = compiled_entry_words[chunk_index : chunk_index + max_tokens]
-                compiled_entry_chunk = " ".join(compiled_entry_words_chunk)
-
+            # Create heading prefixed entry from each chunk
+            for chunk_index, compiled_entry_chunk in enumerate(chunked_entry_chunks):
                 # Prepend heading to all other chunks, the first chunk already has heading from original entry
-                if chunk_index > 0:
+                if chunk_index > 0 and entry.heading:
                     # Snip heading to avoid crossing max_tokens limit
                     # Keep last 100 characters of heading as entry heading more important than filename
                     snipped_heading = entry.heading[-100:]
-                    compiled_entry_chunk = f"{snipped_heading}.\n{compiled_entry_chunk}"
+                    # Prepend snipped heading
+                    compiled_entry_chunk = f"{snipped_heading}\n{compiled_entry_chunk}"
+
+                # Drop long words instead of having entry truncated to maintain quality of entry processed by models
+                compiled_entry_chunk = TextToEntries.remove_long_words(compiled_entry_chunk, max_word_length)

                 # Clean entry of unwanted characters like \0 character
                 compiled_entry_chunk = TextToEntries.clean_field(compiled_entry_chunk)
@@ -160,7 +189,7 @@ class TextToEntries(ABC):
         new_dates = []
         with timer("Indexed dates from added entries in", logger):
             for added_entry in added_entries:
-                dates_in_entries = zip(self.date_filter.extract_dates(added_entry.raw), repeat(added_entry))
+                dates_in_entries = zip(self.date_filter.extract_dates(added_entry.compiled), repeat(added_entry))
                 dates_to_create = [
                     EntryDates(date=date, entry=added_entry)
                     for date, added_entry in dates_in_entries
@@ -244,11 +273,6 @@ class TextToEntries(ABC):

         return entries_with_ids

-    @staticmethod
-    def convert_text_maps_to_jsonl(entries: List[Entry]) -> str:
-        # Convert each entry to JSON and write to JSONL file
-        return "".join([f"{entry.to_json()}\n" for entry in entries])
-
     @staticmethod
     def clean_field(field: str) -> str:
         return field.replace("\0", "") if not is_none_or_empty(field) else ""
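A quick sketch of the new TextToEntries.remove_long_words helper added above, which replaces the old word-list filtering; the example string is made up:

from khoj.processor.content.text_to_entries import TextToEntries

text = "short " + "x" * 600 + " words survive"
cleaned = TextToEntries.remove_long_words(text, max_word_length=500)
# The 600-character token is dropped along with its trailing whitespace;
# the surrounding words and spacing are preserved.
assert cleaned == "short words survive"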
@@ -489,7 +489,7 @@ async def chat(
     common: CommonQueryParams,
     q: str,
     n: Optional[int] = 5,
-    d: Optional[float] = 0.18,
+    d: Optional[float] = 0.22,
     stream: Optional[bool] = False,
     title: Optional[str] = None,
     conversation_id: Optional[int] = None,
@@ -306,7 +306,7 @@ def test_notes_search(client, search_config: SearchConfig, sample_org_data, defa
     user_query = quote("How to git install application?")

     # Act
-    response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true&max_distance=0.18", headers=headers)
+    response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true&max_distance=0.22", headers=headers)

     # Assert
     assert response.status_code == 200
@@ -325,7 +325,7 @@ def test_notes_search_no_results(client, search_config: SearchConfig, sample_org
     user_query = quote("How to find my goat?")

     # Act
-    response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true&max_distance=0.18", headers=headers)
+    response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true&max_distance=0.22", headers=headers)

     # Assert
     assert response.status_code == 200
@@ -409,7 +409,7 @@ def test_notes_search_requires_parent_context(
     user_query = quote("Install Khoj on Emacs")

     # Act
-    response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true&max_distance=0.18", headers=headers)
+    response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true&max_distance=0.22", headers=headers)

     # Assert
     assert response.status_code == 200
@@ -1,4 +1,3 @@
-import json
 import os
 from pathlib import Path

@@ -7,8 +6,8 @@ from khoj.utils.fs_syncer import get_markdown_files
 from khoj.utils.rawconfig import TextContentConfig


-def test_markdown_file_with_no_headings_to_jsonl(tmp_path):
-    "Convert files with no heading to jsonl."
+def test_extract_markdown_with_no_headings(tmp_path):
+    "Convert markdown file with no heading to entry format."
     # Arrange
     entry = f"""
 - Bullet point 1
@@ -17,30 +16,24 @@ def test_markdown_file_with_no_headings_to_jsonl(tmp_path):
     data = {
         f"{tmp_path}": entry,
     }
-    expected_heading = f"# {tmp_path.stem}"
+    expected_heading = f"# {tmp_path}"

     # Act
     # Extract Entries from specified Markdown files
-    entry_nodes, file_to_entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data)
-
-    # Process Each Entry from All Notes Files
-    jsonl_string = MarkdownToEntries.convert_markdown_maps_to_jsonl(
-        MarkdownToEntries.convert_markdown_entries_to_maps(entry_nodes, file_to_entries)
-    )
-    jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
+    entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3)

     # Assert
-    assert len(jsonl_data) == 1
+    assert len(entries) == 1
     # Ensure raw entry with no headings do not get heading prefix prepended
-    assert not jsonl_data[0]["raw"].startswith("#")
+    assert not entries[0].raw.startswith("#")
     # Ensure compiled entry has filename prepended as top level heading
-    assert expected_heading in jsonl_data[0]["compiled"]
+    assert entries[0].compiled.startswith(expected_heading)
     # Ensure compiled entry also includes the file name
-    assert str(tmp_path) in jsonl_data[0]["compiled"]
+    assert str(tmp_path) in entries[0].compiled


-def test_single_markdown_entry_to_jsonl(tmp_path):
-    "Convert markdown entry from single file to jsonl."
+def test_extract_single_markdown_entry(tmp_path):
+    "Convert markdown from single file to entry format."
     # Arrange
     entry = f"""### Heading
     \t\r
@@ -52,20 +45,14 @@ def test_single_markdown_entry_to_jsonl(tmp_path):

     # Act
     # Extract Entries from specified Markdown files
-    entries, entry_to_file_map = MarkdownToEntries.extract_markdown_entries(markdown_files=data)
-
-    # Process Each Entry from All Notes Files
-    jsonl_string = MarkdownToEntries.convert_markdown_maps_to_jsonl(
-        MarkdownToEntries.convert_markdown_entries_to_maps(entries, entry_to_file_map)
-    )
-    jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
+    entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3)

     # Assert
-    assert len(jsonl_data) == 1
+    assert len(entries) == 1


-def test_multiple_markdown_entries_to_jsonl(tmp_path):
-    "Convert multiple markdown entries from single file to jsonl."
+def test_extract_multiple_markdown_entries(tmp_path):
+    "Convert multiple markdown from single file to entry format."
     # Arrange
     entry = f"""
 ### Heading 1
@@ -81,19 +68,139 @@ def test_multiple_markdown_entries_to_jsonl(tmp_path):

     # Act
     # Extract Entries from specified Markdown files
-    entry_strings, entry_to_file_map = MarkdownToEntries.extract_markdown_entries(markdown_files=data)
-    entries = MarkdownToEntries.convert_markdown_entries_to_maps(entry_strings, entry_to_file_map)
-
-    # Process Each Entry from All Notes Files
-    jsonl_string = MarkdownToEntries.convert_markdown_maps_to_jsonl(entries)
-    jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
+    entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3)

     # Assert
-    assert len(jsonl_data) == 2
+    assert len(entries) == 2
     # Ensure entry compiled strings include the markdown files they originate from
     assert all([tmp_path.stem in entry.compiled for entry in entries])


+def test_extract_entries_with_different_level_headings(tmp_path):
+    "Extract markdown entries with different level headings."
+    # Arrange
+    entry = f"""
+# Heading 1
+## Sub-Heading 1.1
+# Heading 2
+"""
+    data = {
+        f"{tmp_path}": entry,
+    }
+
+    # Act
+    # Extract Entries from specified Markdown files
+    entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3)
+
+    # Assert
+    assert len(entries) == 2
+    assert entries[0].raw == "# Heading 1\n## Sub-Heading 1.1", "Ensure entry includes heading ancestory"
+    assert entries[1].raw == "# Heading 2\n"
+
+
+def test_extract_entries_with_non_incremental_heading_levels(tmp_path):
+    "Extract markdown entries when deeper child level before shallower child level."
+    # Arrange
+    entry = f"""
+# Heading 1
+#### Sub-Heading 1.1
+## Sub-Heading 1.2
+# Heading 2
+"""
+    data = {
+        f"{tmp_path}": entry,
+    }
+
+    # Act
+    # Extract Entries from specified Markdown files
+    entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3)
+
+    # Assert
+    assert len(entries) == 3
+    assert entries[0].raw == "# Heading 1\n#### Sub-Heading 1.1", "Ensure entry includes heading ancestory"
+    assert entries[1].raw == "# Heading 1\n## Sub-Heading 1.2", "Ensure entry includes heading ancestory"
+    assert entries[2].raw == "# Heading 2\n"
+
+
+def test_extract_entries_with_text_before_headings(tmp_path):
+    "Extract markdown entries with some text before any headings."
+    # Arrange
+    entry = f"""
+Text before headings
+# Heading 1
+body line 1
+## Heading 2
+body line 2
+"""
+    data = {
+        f"{tmp_path}": entry,
+    }
+
+    # Act
+    # Extract Entries from specified Markdown files
+    entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3)
+
+    # Assert
+    assert len(entries) == 3
+    assert entries[0].raw == "\nText before headings"
+    assert entries[1].raw == "# Heading 1\nbody line 1"
+    assert entries[2].raw == "# Heading 1\n## Heading 2\nbody line 2\n", "Ensure raw entry includes heading ancestory"
+
+
+def test_parse_markdown_file_into_single_entry_if_small(tmp_path):
+    "Parse markdown file into single entry if it fits within the token limits."
+    # Arrange
+    entry = f"""
+# Heading 1
+body line 1
+## Subheading 1.1
+body line 1.1
+"""
+    data = {
+        f"{tmp_path}": entry,
+    }
+
+    # Act
+    # Extract Entries from specified Markdown files
+    entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=12)
+
+    # Assert
+    assert len(entries) == 1
+    assert entries[0].raw == entry
+
+
+def test_parse_markdown_entry_with_children_as_single_entry_if_small(tmp_path):
+    "Parse markdown entry with child headings as single entry if it fits within the tokens limits."
+    # Arrange
+    entry = f"""
+# Heading 1
+body line 1
+## Subheading 1.1
+body line 1.1
+# Heading 2
+body line 2
+## Subheading 2.1
+longer body line 2.1
+"""
+    data = {
+        f"{tmp_path}": entry,
+    }
+
+    # Act
+    # Extract Entries from specified Markdown files
+    entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=12)
+
+    # Assert
+    assert len(entries) == 3
+    assert (
+        entries[0].raw == "# Heading 1\nbody line 1\n## Subheading 1.1\nbody line 1.1"
+    ), "First entry includes children headings"
+    assert entries[1].raw == "# Heading 2\nbody line 2", "Second entry does not include children headings"
+    assert (
+        entries[2].raw == "# Heading 2\n## Subheading 2.1\nlonger body line 2.1\n"
+    ), "Third entry is second entries child heading"
+
+
 def test_get_markdown_files(tmp_path):
     "Ensure Markdown files specified via input-filter, input-files extracted"
     # Arrange
@@ -131,27 +238,6 @@ def test_get_markdown_files(tmp_path):
     assert set(extracted_org_files.keys()) == expected_files


-def test_extract_entries_with_different_level_headings(tmp_path):
-    "Extract markdown entries with different level headings."
-    # Arrange
-    entry = f"""
-# Heading 1
-## Heading 2
-"""
-    data = {
-        f"{tmp_path}": entry,
-    }
-
-    # Act
-    # Extract Entries from specified Markdown files
-    entries, _ = MarkdownToEntries.extract_markdown_entries(markdown_files=data)
-
-    # Assert
-    assert len(entries) == 2
-    assert entries[0] == "# Heading 1"
-    assert entries[1] == "## Heading 2"
-
-
 # Helper Functions
 def create_file(tmp_path: Path, entry=None, filename="test.md"):
     markdown_file = tmp_path / filename
@@ -56,7 +56,7 @@ def test_index_update_with_user2_inaccessible_user1(client, api_user2: KhojApiUs

     # Assert
     assert update_response.status_code == 200
-    assert len(results) == 5
+    assert len(results) == 3
     for result in results:
         assert result["additional"]["file"] not in source_file_symbol

@@ -470,10 +470,6 @@ async def test_websearch_with_operators(chat_client):
         ["site:reddit.com" in response for response in responses]
     ), "Expected a search query to include site:reddit.com but got: " + str(responses)

-    assert any(
-        ["after:2024/04/01" in response for response in responses]
-    ), "Expected a search query to include after:2024/04/01 but got: " + str(responses)
-

 # ----------------------------------------------------------------------------------------------------
 @pytest.mark.anyio
@ -1,5 +1,5 @@
|
||||||
import json
|
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
|
|
||||||
from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
|
from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
|
||||||
from khoj.processor.content.text_to_entries import TextToEntries
|
from khoj.processor.content.text_to_entries import TextToEntries
|
||||||
|
@ -8,7 +8,7 @@ from khoj.utils.helpers import is_none_or_empty
|
||||||
from khoj.utils.rawconfig import Entry, TextContentConfig
|
from khoj.utils.rawconfig import Entry, TextContentConfig
|
||||||
|
|
||||||
|
|
||||||
def test_configure_heading_entry_to_jsonl(tmp_path):
|
def test_configure_indexing_heading_only_entries(tmp_path):
|
||||||
"""Ensure entries with empty body are ignored, unless explicitly configured to index heading entries.
|
"""Ensure entries with empty body are ignored, unless explicitly configured to index heading entries.
|
||||||
Property drawers not considered Body. Ignore control characters for evaluating if Body empty."""
|
Property drawers not considered Body. Ignore control characters for evaluating if Body empty."""
|
||||||
# Arrange
|
# Arrange
|
||||||
|
@@ -26,24 +26,21 @@ def test_configure_indexing_heading_only_entries(tmp_path):
     for index_heading_entries in [True, False]:
         # Act
         # Extract entries into jsonl from specified Org files
-        jsonl_string = OrgToEntries.convert_org_entries_to_jsonl(
-            OrgToEntries.convert_org_nodes_to_entries(
-                *OrgToEntries.extract_org_entries(org_files=data), index_heading_entries=index_heading_entries
-            )
+        entries = OrgToEntries.extract_org_entries(
+            org_files=data, index_heading_entries=index_heading_entries, max_tokens=3
         )
-        jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]

         # Assert
         if index_heading_entries:
             # Entry with empty body indexed when index_heading_entries set to True
-            assert len(jsonl_data) == 1
+            assert len(entries) == 1
         else:
             # Entry with empty body ignored when index_heading_entries set to False
-            assert is_none_or_empty(jsonl_data)
+            assert is_none_or_empty(entries)


-def test_entry_split_when_exceeds_max_words():
-    "Ensure entries with compiled words exceeding max_words are split."
+def test_entry_split_when_exceeds_max_tokens():
+    "Ensure entries with compiled words exceeding max_tokens are split."
     # Arrange
     tmp_path = "/tmp/test.org"
     entry = f"""*** Heading
@@ -57,29 +54,26 @@ def test_entry_split_when_exceeds_max_tokens():

     # Act
     # Extract Entries from specified Org files
-    entries, entry_to_file_map = OrgToEntries.extract_org_entries(org_files=data)
+    entries = OrgToEntries.extract_org_entries(org_files=data)

-    # Split each entry from specified Org files by max words
-    jsonl_string = OrgToEntries.convert_org_entries_to_jsonl(
-        TextToEntries.split_entries_by_max_tokens(
-            OrgToEntries.convert_org_nodes_to_entries(entries, entry_to_file_map), max_tokens=4
-        )
-    )
-    jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
+    # Split each entry from specified Org files by max tokens
+    entries = TextToEntries.split_entries_by_max_tokens(entries, max_tokens=6)

     # Assert
-    assert len(jsonl_data) == 2
-    # Ensure compiled entries split by max_words start with entry heading (for search context)
-    assert all([entry["compiled"].startswith(expected_heading) for entry in jsonl_data])
+    assert len(entries) == 2
+    # Ensure compiled entries split by max tokens start with entry heading (for search context)
+    assert all([entry.compiled.startswith(expected_heading) for entry in entries])


 def test_entry_split_drops_large_words():
     "Ensure entries drops words larger than specified max word length from compiled version."
     # Arrange
-    entry_text = f"""*** Heading
-\t\r
-Body Line 1
-"""
+    entry_text = f"""First Line
+dog=1\n\r\t
+cat=10
+car=4
+book=2
+"""
     entry = Entry(raw=entry_text, compiled=entry_text)

     # Act
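
Note (illustrative, not part of the diff): the split behaviour exercised above can be driven directly through TextToEntries.split_entries_by_max_tokens, which now takes already-built Entry objects. A small sketch with made-up text; passing both keyword arguments together is an assumption based on how the two tests call them separately:

    from khoj.processor.content.text_to_entries import TextToEntries
    from khoj.utils.rawconfig import Entry

    text = "*** Heading\nbody line one\nbody line two with a few more words\n"
    entry = Entry(raw=text, compiled=text)

    # Split into chunks of at most 6 tokens; words longer than 5 characters are
    # dropped from the compiled text of each chunk, as the tests above check
    chunks = TextToEntries.split_entries_by_max_tokens([entry], max_tokens=6, max_word_length=5)
    for chunk in chunks:
        print(chunk.compiled)
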
@@ -87,11 +81,158 @@ def test_entry_split_drops_large_words():
     processed_entry = TextToEntries.split_entries_by_max_tokens([entry], max_word_length=5)[0]

     # Assert
-    # "Heading" dropped from compiled version because its over the set max word limit
-    assert len(processed_entry.compiled.split()) == len(entry_text.split()) - 1
+    # Ensure words larger than max word length are dropped
+    # Ensure newline characters are considered as word boundaries for splitting words. See #620
+    words_to_keep = ["First", "Line", "dog=1", "car=4"]
+    words_to_drop = ["cat=10", "book=2"]
+    assert all([word for word in words_to_keep if word in processed_entry.compiled])
+    assert not any([word for word in words_to_drop if word in processed_entry.compiled])
+    assert len(processed_entry.compiled.split()) == len(entry_text.split()) - 2


-def test_entry_with_body_to_jsonl(tmp_path):
+def test_parse_org_file_into_single_entry_if_small(tmp_path):
+    "Parse org file into single entry if it fits within the token limits."
+    # Arrange
+    original_entry = f"""
+* Heading 1
+body line 1
+** Subheading 1.1
+body line 1.1
+"""
+    data = {
+        f"{tmp_path}": original_entry,
+    }
+    expected_entry = f"""
+* Heading 1
+body line 1
+
+** Subheading 1.1
+body line 1.1
+
+""".lstrip()
+
+    # Act
+    # Extract Entries from specified Org files
+    extracted_entries = OrgToEntries.extract_org_entries(org_files=data, max_tokens=12)
+    for entry in extracted_entries:
+        entry.raw = clean(entry.raw)
+
+    # Assert
+    assert len(extracted_entries) == 1
+    assert entry.raw == expected_entry
+
+
+def test_parse_org_entry_with_children_as_single_entry_if_small(tmp_path):
+    "Parse org entry with child headings as single entry only if it fits within the tokens limits."
+    # Arrange
+    entry = f"""
+* Heading 1
+body line 1
+** Subheading 1.1
+body line 1.1
+* Heading 2
+body line 2
+** Subheading 2.1
+longer body line 2.1
+"""
+    data = {
+        f"{tmp_path}": entry,
+    }
+    first_expected_entry = f"""
+* Path: {tmp_path}
+** Heading 1.
+body line 1
+
+*** Subheading 1.1.
+body line 1.1
+
+""".lstrip()
+    second_expected_entry = f"""
+* Path: {tmp_path}
+** Heading 2.
+body line 2
+
+""".lstrip()
+    third_expected_entry = f"""
+* Path: {tmp_path} / Heading 2
+** Subheading 2.1.
+longer body line 2.1
+
+""".lstrip()
+
+    # Act
+    # Extract Entries from specified Org files
+    extracted_entries = OrgToEntries.extract_org_entries(org_files=data, max_tokens=12)
+
+    # Assert
+    assert len(extracted_entries) == 3
+    assert extracted_entries[0].compiled == first_expected_entry, "First entry includes children headings"
+    assert extracted_entries[1].compiled == second_expected_entry, "Second entry does not include children headings"
+    assert extracted_entries[2].compiled == third_expected_entry, "Third entry is second entries child heading"
+
+
+def test_separate_sibling_org_entries_if_all_cannot_fit_in_token_limit(tmp_path):
+    "Parse org sibling entries as separate entries only if it fits within the tokens limits."
+    # Arrange
+    entry = f"""
+* Heading 1
+body line 1
+** Subheading 1.1
+body line 1.1
+* Heading 2
+body line 2
+** Subheading 2.1
+body line 2.1
+* Heading 3
+body line 3
+** Subheading 3.1
+body line 3.1
+"""
+    data = {
+        f"{tmp_path}": entry,
+    }
+    first_expected_entry = f"""
+* Path: {tmp_path}
+** Heading 1.
+body line 1
+
+*** Subheading 1.1.
+body line 1.1
+
+""".lstrip()
+    second_expected_entry = f"""
+* Path: {tmp_path}
+** Heading 2.
+body line 2
+
+*** Subheading 2.1.
+body line 2.1
+
+""".lstrip()
+    third_expected_entry = f"""
+* Path: {tmp_path}
+** Heading 3.
+body line 3
+
+*** Subheading 3.1.
+body line 3.1
+
+""".lstrip()
+
+    # Act
+    # Extract Entries from specified Org files
+    # Max tokens = 30 is in the middle of 2 entry (24 tokens) and 3 entry (36 tokens) tokens boundary
+    # Where each sibling entry contains 12 tokens per sibling entry * 3 entries = 36 tokens
+    extracted_entries = OrgToEntries.extract_org_entries(org_files=data, max_tokens=30)
+
+    # Assert
+    assert len(extracted_entries) == 3
+    assert extracted_entries[0].compiled == first_expected_entry, "First entry includes children headings"
+    assert extracted_entries[1].compiled == second_expected_entry, "Second entry includes children headings"
+    assert extracted_entries[2].compiled == third_expected_entry, "Third entry includes children headings"
+
+
+def test_entry_with_body_to_entry(tmp_path):
     "Ensure entries with valid body text are loaded."
     # Arrange
     entry = f"""*** Heading
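
Note (illustrative, not part of the diff): the three new tests above pin down the merge rule for org content. A parent heading and its children are emitted as one entry when they fit within max_tokens, siblings are split apart otherwise, and each compiled entry is prefixed with its file path and heading ancestry. A compact sketch of how a caller sees this, with a made-up file and token budgets:

    from khoj.processor.content.org_mode.org_to_entries import OrgToEntries

    org_files = {"/notes/small.org": "* Heading 1\nbody line 1\n** Subheading 1.1\nbody line 1.1\n"}

    # A generous budget lets the whole file collapse into a single entry,
    # while a tight budget yields one entry per heading with its ancestry prefixed
    merged = OrgToEntries.extract_org_entries(org_files=org_files, max_tokens=64)
    split = OrgToEntries.extract_org_entries(org_files=org_files, max_tokens=3)
    print(len(merged), len(split))
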
@@ -107,19 +248,13 @@ def test_entry_with_body_to_jsonl(tmp_path):

     # Act
     # Extract Entries from specified Org files
-    entries, entry_to_file_map = OrgToEntries.extract_org_entries(org_files=data)
+    entries = OrgToEntries.extract_org_entries(org_files=data, max_tokens=3)

-    # Process Each Entry from All Notes Files
-    jsonl_string = OrgToEntries.convert_org_entries_to_jsonl(
-        OrgToEntries.convert_org_nodes_to_entries(entries, entry_to_file_map)
-    )
-    jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]

     # Assert
-    assert len(jsonl_data) == 1
+    assert len(entries) == 1


-def test_file_with_entry_after_intro_text_to_jsonl(tmp_path):
+def test_file_with_entry_after_intro_text_to_entry(tmp_path):
     "Ensure intro text before any headings is indexed."
     # Arrange
     entry = f"""
@@ -134,18 +269,13 @@ Intro text

     # Act
     # Extract Entries from specified Org files
-    entry_nodes, file_to_entries = OrgToEntries.extract_org_entries(org_files=data)
+    entries = OrgToEntries.extract_org_entries(org_files=data, max_tokens=3)

-    # Process Each Entry from All Notes Files
-    entries = OrgToEntries.convert_org_nodes_to_entries(entry_nodes, file_to_entries)
-    jsonl_string = OrgToEntries.convert_org_entries_to_jsonl(entries)
-    jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]

     # Assert
-    assert len(jsonl_data) == 2
+    assert len(entries) == 2


-def test_file_with_no_headings_to_jsonl(tmp_path):
+def test_file_with_no_headings_to_entry(tmp_path):
     "Ensure files with no heading, only body text are loaded."
     # Arrange
     entry = f"""
@@ -158,15 +288,10 @@ def test_file_with_no_headings_to_jsonl(tmp_path):

     # Act
     # Extract Entries from specified Org files
-    entry_nodes, file_to_entries = OrgToEntries.extract_org_entries(org_files=data)
+    entries = OrgToEntries.extract_org_entries(org_files=data, max_tokens=3)

-    # Process Each Entry from All Notes Files
-    entries = OrgToEntries.convert_org_nodes_to_entries(entry_nodes, file_to_entries)
-    jsonl_string = OrgToEntries.convert_org_entries_to_jsonl(entries)
-    jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]

     # Assert
-    assert len(jsonl_data) == 1
+    assert len(entries) == 1


 def test_get_org_files(tmp_path):
@@ -214,7 +339,8 @@ def test_extract_entries_with_different_level_headings(tmp_path):
     # Arrange
     entry = f"""
 * Heading 1
-** Heading 2
+** Sub-Heading 1.1
+* Heading 2
 """
     data = {
         f"{tmp_path}": entry,
@@ -222,12 +348,14 @@ def test_extract_entries_with_different_level_headings(tmp_path):

     # Act
     # Extract Entries from specified Org files
-    entries, _ = OrgToEntries.extract_org_entries(org_files=data)
+    entries = OrgToEntries.extract_org_entries(org_files=data, index_heading_entries=True, max_tokens=3)
+    for entry in entries:
+        entry.raw = clean(f"{entry.raw}")

     # Assert
     assert len(entries) == 2
-    assert f"{entries[0]}".startswith("* Heading 1")
-    assert f"{entries[1]}".startswith("** Heading 2")
+    assert entries[0].raw == "* Heading 1\n** Sub-Heading 1.1\n", "Ensure entry includes heading ancestory"
+    assert entries[1].raw == "* Heading 2\n"


 # Helper Functions
@@ -237,3 +365,8 @@ def create_file(tmp_path, entry=None, filename="test.org"):
     if entry:
         org_file.write_text(entry)
     return org_file
+
+
+def clean(entry):
+    "Remove properties from entry for easier comparison."
+    return re.sub(r"\n:PROPERTIES:(.*?):END:", "", entry, flags=re.DOTALL)
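
Aside (not part of the diff): the new clean() helper strips org property drawers so raw entry text can be compared without whatever :PROPERTIES: block the parser attaches. A self-contained illustration of the same regex on a made-up drawer:

    import re

    raw = "* Heading 1\n:PROPERTIES:\n:ID: 42\n:END:\nbody line 1\n"
    cleaned = re.sub(r"\n:PROPERTIES:(.*?):END:", "", raw, flags=re.DOTALL)
    assert cleaned == "* Heading 1\nbody line 1\n"
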
@@ -1,4 +1,3 @@
-import json
 import os

 from khoj.processor.content.pdf.pdf_to_entries import PdfToEntries
@@ -15,16 +14,10 @@ def test_single_page_pdf_to_jsonl():
         pdf_bytes = f.read()

     data = {"tests/data/pdf/singlepage.pdf": pdf_bytes}
-    entries, entry_to_file_map = PdfToEntries.extract_pdf_entries(pdf_files=data)
+    entries = PdfToEntries.extract_pdf_entries(pdf_files=data)

-    # Process Each Entry from All Pdf Files
-    jsonl_string = PdfToEntries.convert_pdf_maps_to_jsonl(
-        PdfToEntries.convert_pdf_entries_to_maps(entries, entry_to_file_map)
-    )
-    jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]

     # Assert
-    assert len(jsonl_data) == 1
+    assert len(entries) == 1


 def test_multi_page_pdf_to_jsonl():
@@ -35,16 +28,10 @@ def test_multi_page_pdf_to_jsonl():
         pdf_bytes = f.read()

     data = {"tests/data/pdf/multipage.pdf": pdf_bytes}
-    entries, entry_to_file_map = PdfToEntries.extract_pdf_entries(pdf_files=data)
+    entries = PdfToEntries.extract_pdf_entries(pdf_files=data)

-    # Process Each Entry from All Pdf Files
-    jsonl_string = PdfToEntries.convert_pdf_maps_to_jsonl(
-        PdfToEntries.convert_pdf_entries_to_maps(entries, entry_to_file_map)
-    )
-    jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]

     # Assert
-    assert len(jsonl_data) == 6
+    assert len(entries) == 6


 def test_ocr_page_pdf_to_jsonl():
@@ -55,10 +42,7 @@ def test_ocr_page_pdf_to_jsonl():
         pdf_bytes = f.read()

     data = {"tests/data/pdf/ocr_samples.pdf": pdf_bytes}
-    entries, entry_to_file_map = PdfToEntries.extract_pdf_entries(pdf_files=data)
+    entries = PdfToEntries.extract_pdf_entries(pdf_files=data)

-    # Process Each Entry from All Pdf Files
-    entries = PdfToEntries.convert_pdf_entries_to_maps(entries, entry_to_file_map)
-
     assert len(entries) == 1
     assert "playing on a strip of marsh" in entries[0].raw
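
Note (illustrative, not part of the diff): the PDF extractor now returns Entry objects directly instead of an (entries, entry_to_file_map) pair, so the intermediate jsonl conversion disappears from these tests. A minimal usage sketch matching the call shape above:

    from khoj.processor.content.pdf.pdf_to_entries import PdfToEntries

    # The extractor takes a {path: bytes} map, as in the tests above
    with open("tests/data/pdf/singlepage.pdf", "rb") as f:
        data = {"tests/data/pdf/singlepage.pdf": f.read()}

    entries = PdfToEntries.extract_pdf_entries(pdf_files=data)
    print(len(entries), entries[0].raw[:80])
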
@@ -1,4 +1,3 @@
-import json
 import os
 from pathlib import Path

@@ -11,10 +10,10 @@ from khoj.utils.rawconfig import TextContentConfig
 def test_plaintext_file(tmp_path):
     "Convert files with no heading to jsonl."
     # Arrange
-    entry = f"""
+    raw_entry = f"""
 Hi, I am a plaintext file and I have some plaintext words.
 """
-    plaintextfile = create_file(tmp_path, entry)
+    plaintextfile = create_file(tmp_path, raw_entry)

     filename = plaintextfile.stem

@@ -22,25 +21,21 @@ def test_plaintext_file(tmp_path):
     # Extract Entries from specified plaintext files

     data = {
-        f"{plaintextfile}": entry,
+        f"{plaintextfile}": raw_entry,
     }

-    maps = PlaintextToEntries.convert_plaintext_entries_to_maps(entry_to_file_map=data)
+    entries = PlaintextToEntries.extract_plaintext_entries(entry_to_file_map=data)

     # Convert each entry.file to absolute path to make them JSON serializable
-    for map in maps:
-        map.file = str(Path(map.file).absolute())
+    for entry in entries:
+        entry.file = str(Path(entry.file).absolute())

-    # Process Each Entry from All Notes Files
-    jsonl_string = PlaintextToEntries.convert_entries_to_jsonl(maps)
-    jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]

     # Assert
-    assert len(jsonl_data) == 1
+    assert len(entries) == 1
     # Ensure raw entry with no headings do not get heading prefix prepended
-    assert not jsonl_data[0]["raw"].startswith("#")
+    assert not entries[0].raw.startswith("#")
     # Ensure compiled entry has filename prepended as top level heading
-    assert jsonl_data[0]["compiled"] == f"{filename}\n{entry}"
+    assert entries[0].compiled == f"{filename}\n{raw_entry}"


 def test_get_plaintext_files(tmp_path):
@@ -98,11 +93,11 @@ def test_parse_html_plaintext_file(content_config, default_user: KhojUser):
     extracted_plaintext_files = get_plaintext_files(config=config)

     # Act
-    maps = PlaintextToEntries.convert_plaintext_entries_to_maps(extracted_plaintext_files)
+    entries = PlaintextToEntries.extract_plaintext_entries(extracted_plaintext_files)

     # Assert
-    assert len(maps) == 1
-    assert "<div>" not in maps[0].raw
+    assert len(entries) == 1
+    assert "<div>" not in entries[0].raw


 # Helper Functions
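
Note (illustrative, not part of the diff): the plaintext processor likewise now exposes extract_plaintext_entries() returning Entry objects in place of the older convert_plaintext_entries_to_maps(). A sketch of the call shape these tests rely on; the import path is assumed by analogy with the pdf module and the file content is made up:

    from khoj.processor.content.plaintext.plaintext_to_entries import PlaintextToEntries

    data = {"/notes/todo.txt": "Hi, I am a plaintext file and I have some plaintext words.\n"}
    entries = PlaintextToEntries.extract_plaintext_entries(entry_to_file_map=data)

    assert len(entries) == 1
    # Raw text keeps no heading prefix; compiled text gets the filename prepended as a heading
    print(entries[0].raw)
    print(entries[0].compiled)
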
@@ -57,18 +57,21 @@ def test_get_org_files_with_org_suffixed_dir_doesnt_raise_error(tmp_path, defaul
 # ----------------------------------------------------------------------------------------------------
 @pytest.mark.django_db
 def test_text_search_setup_with_empty_file_creates_no_entries(
-    org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog
+    org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser
 ):
     # Arrange
+    existing_entries = Entry.objects.filter(user=default_user).count()
     data = get_org_files(org_config_with_only_new_file)

     # Act
     # Generate notes embeddings during asymmetric setup
-    with caplog.at_level(logging.INFO):
-        text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
+    text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)

     # Assert
-    assert "Deleted 8 entries. Created 0 new entries for user " in caplog.records[-1].message
+    updated_entries = Entry.objects.filter(user=default_user).count()

+    assert existing_entries == 2
+    assert updated_entries == 0
     verify_embeddings(0, default_user)

@@ -78,6 +81,7 @@ def test_text_indexer_deletes_embedding_before_regenerate(
     content_config: ContentConfig, default_user: KhojUser, caplog
 ):
     # Arrange
+    existing_entries = Entry.objects.filter(user=default_user).count()
     org_config = LocalOrgConfig.objects.filter(user=default_user).first()
     data = get_org_files(org_config)

@@ -87,30 +91,18 @@ def test_text_indexer_deletes_embedding_before_regenerate(
         text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)

     # Assert
+    updated_entries = Entry.objects.filter(user=default_user).count()
+    assert existing_entries == 2
+    assert updated_entries == 2
     assert "Deleting all entries for file type org" in caplog.text
-    assert "Deleted 8 entries. Created 13 new entries for user " in caplog.records[-1].message
+    assert "Deleted 2 entries. Created 2 new entries for user " in caplog.records[-1].message


-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.django_db
-def test_text_search_setup_batch_processes(content_config: ContentConfig, default_user: KhojUser, caplog):
-    # Arrange
-    org_config = LocalOrgConfig.objects.filter(user=default_user).first()
-    data = get_org_files(org_config)
-
-    # Act
-    # Generate notes embeddings during asymmetric setup
-    with caplog.at_level(logging.DEBUG):
-        text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
-
-    # Assert
-    assert "Deleted 8 entries. Created 13 new entries for user " in caplog.records[-1].message
-
-
 # ----------------------------------------------------------------------------------------------------
 @pytest.mark.django_db
 def test_text_index_same_if_content_unchanged(content_config: ContentConfig, default_user: KhojUser, caplog):
     # Arrange
+    existing_entries = Entry.objects.filter(user=default_user)
     org_config = LocalOrgConfig.objects.filter(user=default_user).first()
     data = get_org_files(org_config)

@@ -127,6 +119,10 @@ def test_text_index_same_if_content_unchanged(content_config: ContentConfig, def
     final_logs = caplog.text

     # Assert
+    updated_entries = Entry.objects.filter(user=default_user)
+    for entry in updated_entries:
+        assert entry in existing_entries
+    assert len(existing_entries) == len(updated_entries)
     assert "Deleting all entries for file type org" in initial_logs
     assert "Deleting all entries for file type org" not in final_logs

@@ -192,7 +188,7 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: LocalOrgCon

     # Assert
     assert (
-        "Deleted 0 entries. Created 2 new entries for user " in caplog.records[-1].message
+        "Deleted 0 entries. Created 3 new entries for user " in caplog.records[-1].message
     ), "new entry not split by max tokens"

@@ -250,16 +246,15 @@ conda activate khoj

     # Assert
     assert (
-        "Deleted 0 entries. Created 2 new entries for user " in caplog.records[-1].message
+        "Deleted 0 entries. Created 3 new entries for user " in caplog.records[-1].message
     ), "new entry not split by max tokens"


 # ----------------------------------------------------------------------------------------------------
 @pytest.mark.django_db
-def test_regenerate_index_with_new_entry(
-    content_config: ContentConfig, new_org_file: Path, default_user: KhojUser, caplog
-):
+def test_regenerate_index_with_new_entry(content_config: ContentConfig, new_org_file: Path, default_user: KhojUser):
     # Arrange
+    existing_entries = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
     org_config = LocalOrgConfig.objects.filter(user=default_user).first()
     initial_data = get_org_files(org_config)

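
Note (illustrative, not part of the diff): the text_search tests below stop parsing caplog messages and instead compare the Entry rows in the database before and after indexing. A condensed sketch of that pattern as a helper; the text_search and Entry import paths are assumptions based on how the tests use them:

    from khoj.database.models import Entry
    from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
    from khoj.search_type import text_search


    def reindex_and_snapshot(data: dict, user):
        "Index `data` for `user`, returning compiled entries from before and after the run."
        before = list(Entry.objects.filter(user=user).values_list("compiled", flat=True))
        text_search.setup(OrgToEntries, data, regenerate=True, user=user)
        after = list(Entry.objects.filter(user=user).values_list("compiled", flat=True))
        return before, after
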
@@ -271,28 +266,34 @@ def test_regenerate_index_with_new_entry(
     final_data = get_org_files(org_config)

     # Act
-    with caplog.at_level(logging.INFO):
-        text_search.setup(OrgToEntries, initial_data, regenerate=True, user=default_user)
-    initial_logs = caplog.text
-    caplog.clear()  # Clear logs
+    text_search.setup(OrgToEntries, initial_data, regenerate=True, user=default_user)
+    updated_entries1 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))

     # regenerate notes jsonl, model embeddings and model to include entry from new file
-    with caplog.at_level(logging.INFO):
-        text_search.setup(OrgToEntries, final_data, regenerate=True, user=default_user)
-    final_logs = caplog.text
+    text_search.setup(OrgToEntries, final_data, regenerate=True, user=default_user)
+    updated_entries2 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))

     # Assert
-    assert "Deleted 8 entries. Created 13 new entries for user " in initial_logs
-    assert "Deleted 13 entries. Created 14 new entries for user " in final_logs
-    verify_embeddings(14, default_user)
+    for entry in updated_entries1:
+        assert entry in updated_entries2
+
+    assert not any([new_org_file.name in entry for entry in updated_entries1])
+    assert not any([new_org_file.name in entry for entry in existing_entries])
+    assert any([new_org_file.name in entry for entry in updated_entries2])
+
+    assert any(
+        ["Saw a super cute video of a chihuahua doing the Tango on Youtube" in entry for entry in updated_entries2]
+    )
+    verify_embeddings(3, default_user)


 # ----------------------------------------------------------------------------------------------------
 @pytest.mark.django_db
 def test_update_index_with_duplicate_entries_in_stable_order(
-    org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog
+    org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser
 ):
     # Arrange
+    existing_entries = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
     new_file_to_index = Path(org_config_with_only_new_file.input_files[0])

     # Insert org-mode entries with same compiled form into new org file
@@ -304,30 +305,33 @@ def test_update_index_with_duplicate_entries_in_stable_order(

     # Act
     # generate embeddings, entries, notes model from scratch after adding new org-mode file
-    with caplog.at_level(logging.INFO):
-        text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
-    initial_logs = caplog.text
-    caplog.clear()  # Clear logs
+    text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
+    updated_entries1 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))

     data = get_org_files(org_config_with_only_new_file)

     # update embeddings, entries, notes model with no new changes
-    with caplog.at_level(logging.INFO):
-        text_search.setup(OrgToEntries, data, regenerate=False, user=default_user)
-    final_logs = caplog.text
+    text_search.setup(OrgToEntries, data, regenerate=False, user=default_user)
+    updated_entries2 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))

     # Assert
     # verify only 1 entry added even if there are multiple duplicate entries
-    assert "Deleted 8 entries. Created 1 new entries for user " in initial_logs
-    assert "Deleted 0 entries. Created 0 new entries for user " in final_logs
+    for entry in existing_entries:
+        assert entry not in updated_entries1

+    for entry in updated_entries1:
+        assert entry in updated_entries2
+
+    assert len(existing_entries) == 2
+    assert len(updated_entries1) == len(updated_entries2)
     verify_embeddings(1, default_user)


 # ----------------------------------------------------------------------------------------------------
 @pytest.mark.django_db
-def test_update_index_with_deleted_entry(org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog):
+def test_update_index_with_deleted_entry(org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser):
     # Arrange
+    existing_entries = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
     new_file_to_index = Path(org_config_with_only_new_file.input_files[0])

     # Insert org-mode entries with same compiled form into new org file
@@ -344,33 +348,34 @@ def test_update_index_with_deleted_entry(org_config_with_only_new_file: LocalOrg

     # Act
     # load embeddings, entries, notes model after adding new org file with 2 entries
-    with caplog.at_level(logging.INFO):
-        text_search.setup(OrgToEntries, initial_data, regenerate=True, user=default_user)
-    initial_logs = caplog.text
-    caplog.clear()  # Clear logs
+    text_search.setup(OrgToEntries, initial_data, regenerate=True, user=default_user)
+    updated_entries1 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))

-    with caplog.at_level(logging.INFO):
-        text_search.setup(OrgToEntries, final_data, regenerate=False, user=default_user)
-    final_logs = caplog.text
+    text_search.setup(OrgToEntries, final_data, regenerate=False, user=default_user)
+    updated_entries2 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))

     # Assert
-    # verify only 1 entry added even if there are multiple duplicate entries
-    assert "Deleted 8 entries. Created 2 new entries for user " in initial_logs
-    assert "Deleted 1 entries. Created 0 new entries for user " in final_logs
+    for entry in existing_entries:
+        assert entry not in updated_entries1
+
+    # verify the entry in updated_entries2 is a subset of updated_entries1
+    for entry in updated_entries1:
+        assert entry not in updated_entries2
+
+    for entry in updated_entries2:
+        assert entry in updated_entries1[0]
+
     verify_embeddings(1, default_user)


 # ----------------------------------------------------------------------------------------------------
 @pytest.mark.django_db
-def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file: Path, default_user: KhojUser, caplog):
+def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file: Path, default_user: KhojUser):
     # Arrange
+    existing_entries = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
     org_config = LocalOrgConfig.objects.filter(user=default_user).first()
     data = get_org_files(org_config)
-    with caplog.at_level(logging.INFO):
-        text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
-    initial_logs = caplog.text
-    caplog.clear()  # Clear logs
+    text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)

     # append org-mode entry to first org input file in config
     with open(new_org_file, "w") as f:
@@ -381,14 +386,14 @@ def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file

     # Act
     # update embeddings, entries with the newly added note
-    with caplog.at_level(logging.INFO):
-        text_search.setup(OrgToEntries, data, regenerate=False, user=default_user)
-    final_logs = caplog.text
+    text_search.setup(OrgToEntries, data, regenerate=False, user=default_user)
+    updated_entries1 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))

     # Assert
-    assert "Deleted 8 entries. Created 13 new entries for user " in initial_logs
-    assert "Deleted 0 entries. Created 1 new entries for user " in final_logs
-    verify_embeddings(14, default_user)
+    for entry in existing_entries:
+        assert entry not in updated_entries1
+    assert len(updated_entries1) == len(existing_entries) + 1
+    verify_embeddings(3, default_user)


 # ----------------------------------------------------------------------------------------------------