mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 15:38:55 +01:00
Update Text Chunking Strategy to Improve Search Context (#645)
## Major - Parse markdown, org parent entries as single entry if fit within max tokens - Parse a file as single entry if it fits with max token limits - Add parent heading ancestry to extracted markdown entries for context - Chunk text in preference order of para, sentence, word, character ## Minor - Create wrapper function to get entries from org, md, pdf & text files - Remove unused Entry to Jsonl converter from text to entry class, tests - Dedupe code by using single func to process an org file into entries Resolves #620
This commit is contained in:
commit
11ce3e2268
15 changed files with 704 additions and 393 deletions
|
@ -1,14 +1,13 @@
|
|||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import urllib3
|
||||
|
||||
from khoj.database.models import Entry as DbEntry
|
||||
from khoj.database.models import KhojUser
|
||||
from khoj.processor.content.text_to_entries import TextToEntries
|
||||
from khoj.utils.constants import empty_escape_sequences
|
||||
from khoj.utils.helpers import timer
|
||||
from khoj.utils.rawconfig import Entry
|
||||
|
||||
|
@ -31,15 +30,14 @@ class MarkdownToEntries(TextToEntries):
|
|||
else:
|
||||
deletion_file_names = None
|
||||
|
||||
max_tokens = 256
|
||||
# Extract Entries from specified Markdown files
|
||||
with timer("Parse entries from Markdown files into dictionaries", logger):
|
||||
current_entries = MarkdownToEntries.convert_markdown_entries_to_maps(
|
||||
*MarkdownToEntries.extract_markdown_entries(files)
|
||||
)
|
||||
with timer("Extract entries from specified Markdown files", logger):
|
||||
current_entries = MarkdownToEntries.extract_markdown_entries(files, max_tokens)
|
||||
|
||||
# Split entries by max tokens supported by model
|
||||
with timer("Split entries by max token size supported by model", logger):
|
||||
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
|
||||
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens)
|
||||
|
||||
# Identify, mark and merge any new entries with previous entries
|
||||
with timer("Identify new or updated entries", logger):
|
||||
|
@ -57,48 +55,84 @@ class MarkdownToEntries(TextToEntries):
|
|||
return num_new_embeddings, num_deleted_embeddings
|
||||
|
||||
@staticmethod
|
||||
def extract_markdown_entries(markdown_files):
|
||||
def extract_markdown_entries(markdown_files, max_tokens=256) -> List[Entry]:
|
||||
"Extract entries by heading from specified Markdown files"
|
||||
|
||||
# Regex to extract Markdown Entries by Heading
|
||||
|
||||
entries = []
|
||||
entry_to_file_map = []
|
||||
entries: List[str] = []
|
||||
entry_to_file_map: List[Tuple[str, str]] = []
|
||||
for markdown_file in markdown_files:
|
||||
try:
|
||||
markdown_content = markdown_files[markdown_file]
|
||||
entries, entry_to_file_map = MarkdownToEntries.process_single_markdown_file(
|
||||
markdown_content, markdown_file, entries, entry_to_file_map
|
||||
markdown_content, markdown_file, entries, entry_to_file_map, max_tokens
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Unable to process file: {markdown_file}. This file will not be indexed.")
|
||||
logger.warning(e, exc_info=True)
|
||||
logger.error(
|
||||
f"Unable to process file: {markdown_file}. This file will not be indexed.\n{e}", exc_info=True
|
||||
)
|
||||
|
||||
return entries, dict(entry_to_file_map)
|
||||
return MarkdownToEntries.convert_markdown_entries_to_maps(entries, dict(entry_to_file_map))
|
||||
|
||||
@staticmethod
|
||||
def process_single_markdown_file(
|
||||
markdown_content: str, markdown_file: Path, entries: List, entry_to_file_map: List
|
||||
):
|
||||
markdown_heading_regex = r"^#"
|
||||
markdown_content: str,
|
||||
markdown_file: str,
|
||||
entries: List[str],
|
||||
entry_to_file_map: List[Tuple[str, str]],
|
||||
max_tokens=256,
|
||||
ancestry: Dict[int, str] = {},
|
||||
) -> Tuple[List[str], List[Tuple[str, str]]]:
|
||||
# Prepend the markdown section's heading ancestry
|
||||
ancestry_string = "\n".join([f"{'#' * key} {ancestry[key]}" for key in sorted(ancestry.keys())])
|
||||
markdown_content_with_ancestry = f"{ancestry_string}{markdown_content}"
|
||||
|
||||
markdown_entries_per_file = []
|
||||
any_headings = re.search(markdown_heading_regex, markdown_content, flags=re.MULTILINE)
|
||||
for entry in re.split(markdown_heading_regex, markdown_content, flags=re.MULTILINE):
|
||||
# Add heading level as the regex split removed it from entries with headings
|
||||
prefix = "#" if entry.startswith("#") else "# " if any_headings else ""
|
||||
stripped_entry = entry.strip(empty_escape_sequences)
|
||||
if stripped_entry != "":
|
||||
markdown_entries_per_file.append(f"{prefix}{stripped_entry}")
|
||||
# If content is small or content has no children headings, save it as a single entry
|
||||
if len(TextToEntries.tokenizer(markdown_content_with_ancestry)) <= max_tokens or not re.search(
|
||||
rf"^#{{{len(ancestry)+1},}}\s", markdown_content, flags=re.MULTILINE
|
||||
):
|
||||
entry_to_file_map += [(markdown_content_with_ancestry, markdown_file)]
|
||||
entries.extend([markdown_content_with_ancestry])
|
||||
return entries, entry_to_file_map
|
||||
|
||||
# Split by next heading level present in the entry
|
||||
next_heading_level = len(ancestry)
|
||||
sections: List[str] = []
|
||||
while len(sections) < 2:
|
||||
next_heading_level += 1
|
||||
sections = re.split(rf"(\n|^)(?=[#]{{{next_heading_level}}} .+\n?)", markdown_content, flags=re.MULTILINE)
|
||||
|
||||
for section in sections:
|
||||
# Skip empty sections
|
||||
if section.strip() == "":
|
||||
continue
|
||||
|
||||
# Extract the section body and (when present) the heading
|
||||
current_ancestry = ancestry.copy()
|
||||
first_line = [line for line in section.split("\n") if line.strip() != ""][0]
|
||||
if re.search(rf"^#{{{next_heading_level}}} ", first_line):
|
||||
# Extract the section body without the heading
|
||||
current_section_body = "\n".join(section.split(first_line)[1:])
|
||||
# Parse the section heading into current section ancestry
|
||||
current_section_title = first_line[next_heading_level:].strip()
|
||||
current_ancestry[next_heading_level] = current_section_title
|
||||
else:
|
||||
current_section_body = section
|
||||
|
||||
# Recurse down children of the current entry
|
||||
MarkdownToEntries.process_single_markdown_file(
|
||||
current_section_body,
|
||||
markdown_file,
|
||||
entries,
|
||||
entry_to_file_map,
|
||||
max_tokens,
|
||||
current_ancestry,
|
||||
)
|
||||
|
||||
entry_to_file_map += zip(markdown_entries_per_file, [markdown_file] * len(markdown_entries_per_file))
|
||||
entries.extend(markdown_entries_per_file)
|
||||
return entries, entry_to_file_map
|
||||
|
||||
@staticmethod
|
||||
def convert_markdown_entries_to_maps(parsed_entries: List[str], entry_to_file_map) -> List[Entry]:
|
||||
"Convert each Markdown entries into a dictionary"
|
||||
entries = []
|
||||
entries: List[Entry] = []
|
||||
for parsed_entry in parsed_entries:
|
||||
raw_filename = entry_to_file_map[parsed_entry]
|
||||
|
||||
|
@ -108,13 +142,12 @@ class MarkdownToEntries(TextToEntries):
|
|||
entry_filename = urllib3.util.parse_url(raw_filename).url
|
||||
else:
|
||||
entry_filename = str(Path(raw_filename))
|
||||
stem = Path(raw_filename).stem
|
||||
|
||||
heading = parsed_entry.splitlines()[0] if re.search("^#+\s", parsed_entry) else ""
|
||||
# Append base filename to compiled entry for context to model
|
||||
# Increment heading level for heading entries and make filename as its top level heading
|
||||
prefix = f"# {stem}\n#" if heading else f"# {stem}\n"
|
||||
compiled_entry = f"{entry_filename}\n{prefix}{parsed_entry}"
|
||||
prefix = f"# {entry_filename}\n#" if heading else f"# {entry_filename}\n"
|
||||
compiled_entry = f"{prefix}{parsed_entry}"
|
||||
entries.append(
|
||||
Entry(
|
||||
compiled=compiled_entry,
|
||||
|
@ -127,8 +160,3 @@ class MarkdownToEntries(TextToEntries):
|
|||
logger.debug(f"Converted {len(parsed_entries)} markdown entries to dictionaries")
|
||||
|
||||
return entries
|
||||
|
||||
@staticmethod
|
||||
def convert_markdown_maps_to_jsonl(entries: List[Entry]):
|
||||
"Convert each Markdown entry to JSON and collate as JSONL"
|
||||
return "".join([f"{entry.to_json()}\n" for entry in entries])
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Iterable, List, Tuple
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from khoj.database.models import Entry as DbEntry
|
||||
from khoj.database.models import KhojUser
|
||||
from khoj.processor.content.org_mode import orgnode
|
||||
from khoj.processor.content.org_mode.orgnode import Orgnode
|
||||
from khoj.processor.content.text_to_entries import TextToEntries
|
||||
from khoj.utils import state
|
||||
from khoj.utils.helpers import timer
|
||||
|
@ -21,9 +23,6 @@ class OrgToEntries(TextToEntries):
|
|||
def process(
|
||||
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
|
||||
) -> Tuple[int, int]:
|
||||
# Extract required fields from config
|
||||
index_heading_entries = False
|
||||
|
||||
if not full_corpus:
|
||||
deletion_file_names = set([file for file in files if files[file] == ""])
|
||||
files_to_process = set(files) - deletion_file_names
|
||||
|
@ -32,14 +31,12 @@ class OrgToEntries(TextToEntries):
|
|||
deletion_file_names = None
|
||||
|
||||
# Extract Entries from specified Org files
|
||||
with timer("Parse entries from org files into OrgNode objects", logger):
|
||||
entry_nodes, file_to_entries = self.extract_org_entries(files)
|
||||
|
||||
with timer("Convert OrgNodes into list of entries", logger):
|
||||
current_entries = self.convert_org_nodes_to_entries(entry_nodes, file_to_entries, index_heading_entries)
|
||||
max_tokens = 256
|
||||
with timer("Extract entries from specified Org files", logger):
|
||||
current_entries = self.extract_org_entries(files, max_tokens=max_tokens)
|
||||
|
||||
with timer("Split entries by max token size supported by model", logger):
|
||||
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
|
||||
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=max_tokens)
|
||||
|
||||
# Identify, mark and merge any new entries with previous entries
|
||||
with timer("Identify new or updated entries", logger):
|
||||
|
@ -57,93 +54,165 @@ class OrgToEntries(TextToEntries):
|
|||
return num_new_embeddings, num_deleted_embeddings
|
||||
|
||||
@staticmethod
|
||||
def extract_org_entries(org_files: dict[str, str]):
|
||||
def extract_org_entries(
|
||||
org_files: dict[str, str], index_heading_entries: bool = False, max_tokens=256
|
||||
) -> List[Entry]:
|
||||
"Extract entries from specified Org files"
|
||||
entries = []
|
||||
entry_to_file_map: List[Tuple[orgnode.Orgnode, str]] = []
|
||||
entries, entry_to_file_map = OrgToEntries.extract_org_nodes(org_files, max_tokens)
|
||||
return OrgToEntries.convert_org_nodes_to_entries(entries, entry_to_file_map, index_heading_entries)
|
||||
|
||||
@staticmethod
|
||||
def extract_org_nodes(org_files: dict[str, str], max_tokens) -> Tuple[List[List[Orgnode]], Dict[Orgnode, str]]:
|
||||
"Extract org nodes from specified org files"
|
||||
entries: List[List[Orgnode]] = []
|
||||
entry_to_file_map: List[Tuple[Orgnode, str]] = []
|
||||
for org_file in org_files:
|
||||
filename = org_file
|
||||
file = org_files[org_file]
|
||||
try:
|
||||
org_file_entries = orgnode.makelist(file, filename)
|
||||
entry_to_file_map += zip(org_file_entries, [org_file] * len(org_file_entries))
|
||||
entries.extend(org_file_entries)
|
||||
org_content = org_files[org_file]
|
||||
entries, entry_to_file_map = OrgToEntries.process_single_org_file(
|
||||
org_content, org_file, entries, entry_to_file_map, max_tokens
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Unable to process file: {org_file}. This file will not be indexed.")
|
||||
logger.warning(e, exc_info=True)
|
||||
logger.error(f"Unable to process file: {org_file}. Skipped indexing it.\nError; {e}", exc_info=True)
|
||||
|
||||
return entries, dict(entry_to_file_map)
|
||||
|
||||
@staticmethod
|
||||
def process_single_org_file(org_content: str, org_file: str, entries: List, entry_to_file_map: List):
|
||||
# Process single org file. The org parser assumes that the file is a single org file and reads it from a buffer. We'll split the raw conetnt of this file by new line to mimic the same behavior.
|
||||
try:
|
||||
org_file_entries = orgnode.makelist(org_content, org_file)
|
||||
entry_to_file_map += zip(org_file_entries, [org_file] * len(org_file_entries))
|
||||
entries.extend(org_file_entries)
|
||||
return entries, entry_to_file_map
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing file: {org_file} with error: {e}", exc_info=True)
|
||||
def process_single_org_file(
|
||||
org_content: str,
|
||||
org_file: str,
|
||||
entries: List[List[Orgnode]],
|
||||
entry_to_file_map: List[Tuple[Orgnode, str]],
|
||||
max_tokens=256,
|
||||
ancestry: Dict[int, str] = {},
|
||||
) -> Tuple[List[List[Orgnode]], List[Tuple[Orgnode, str]]]:
|
||||
"""Parse org_content from org_file into OrgNode entries
|
||||
|
||||
Recurse down org file entries, one heading level at a time,
|
||||
until reach a leaf entry or the current entry tree fits max_tokens.
|
||||
|
||||
Parse recursion terminating entry (trees) into (a list of) OrgNode objects.
|
||||
"""
|
||||
# Prepend the org section's heading ancestry
|
||||
ancestry_string = "\n".join([f"{'*' * key} {ancestry[key]}" for key in sorted(ancestry.keys())])
|
||||
org_content_with_ancestry = f"{ancestry_string}{org_content}"
|
||||
|
||||
# If content is small or content has no children headings, save it as a single entry
|
||||
# Note: This is the terminating condition for this recursive function
|
||||
if len(TextToEntries.tokenizer(org_content_with_ancestry)) <= max_tokens or not re.search(
|
||||
rf"^\*{{{len(ancestry)+1},}}\s", org_content, re.MULTILINE
|
||||
):
|
||||
orgnode_content_with_ancestry = orgnode.makelist(org_content_with_ancestry, org_file)
|
||||
entry_to_file_map += zip(orgnode_content_with_ancestry, [org_file] * len(orgnode_content_with_ancestry))
|
||||
entries.extend([orgnode_content_with_ancestry])
|
||||
return entries, entry_to_file_map
|
||||
|
||||
# Split this entry tree into sections by the next heading level in it
|
||||
# Increment heading level until able to split entry into sections
|
||||
# A successful split will result in at least 2 sections
|
||||
next_heading_level = len(ancestry)
|
||||
sections: List[str] = []
|
||||
while len(sections) < 2:
|
||||
next_heading_level += 1
|
||||
sections = re.split(rf"(\n|^)(?=[*]{{{next_heading_level}}} .+\n?)", org_content, flags=re.MULTILINE)
|
||||
|
||||
# Recurse down each non-empty section after parsing its body, heading and ancestry
|
||||
for section in sections:
|
||||
# Skip empty sections
|
||||
if section.strip() == "":
|
||||
continue
|
||||
|
||||
# Extract the section body and (when present) the heading
|
||||
current_ancestry = ancestry.copy()
|
||||
first_non_empty_line = [line for line in section.split("\n") if line.strip() != ""][0]
|
||||
# If first non-empty line is a heading with expected heading level
|
||||
if re.search(rf"^\*{{{next_heading_level}}}\s", first_non_empty_line):
|
||||
# Extract the section body without the heading
|
||||
current_section_body = "\n".join(section.split(first_non_empty_line)[1:])
|
||||
# Parse the section heading into current section ancestry
|
||||
current_section_title = first_non_empty_line[next_heading_level:].strip()
|
||||
current_ancestry[next_heading_level] = current_section_title
|
||||
# Else process the section as just body text
|
||||
else:
|
||||
current_section_body = section
|
||||
|
||||
# Recurse down children of the current entry
|
||||
OrgToEntries.process_single_org_file(
|
||||
current_section_body,
|
||||
org_file,
|
||||
entries,
|
||||
entry_to_file_map,
|
||||
max_tokens,
|
||||
current_ancestry,
|
||||
)
|
||||
|
||||
return entries, entry_to_file_map
|
||||
|
||||
@staticmethod
|
||||
def convert_org_nodes_to_entries(
|
||||
parsed_entries: List[orgnode.Orgnode], entry_to_file_map, index_heading_entries=False
|
||||
parsed_entries: List[List[Orgnode]],
|
||||
entry_to_file_map: Dict[Orgnode, str],
|
||||
index_heading_entries: bool = False,
|
||||
) -> List[Entry]:
|
||||
"Convert Org-Mode nodes into list of Entry objects"
|
||||
"""
|
||||
Convert OrgNode lists into list of Entry objects
|
||||
|
||||
Each list of OrgNodes is a parsed parent org tree or leaf node.
|
||||
Convert each list of these OrgNodes into a single Entry.
|
||||
"""
|
||||
entries: List[Entry] = []
|
||||
for parsed_entry in parsed_entries:
|
||||
if not parsed_entry.hasBody and not index_heading_entries:
|
||||
# Ignore title notes i.e notes with just headings and empty body
|
||||
continue
|
||||
for entry_group in parsed_entries:
|
||||
entry_heading, entry_compiled, entry_raw = "", "", ""
|
||||
for parsed_entry in entry_group:
|
||||
if not parsed_entry.hasBody and not index_heading_entries:
|
||||
# Ignore title notes i.e notes with just headings and empty body
|
||||
continue
|
||||
|
||||
todo_str = f"{parsed_entry.todo} " if parsed_entry.todo else ""
|
||||
todo_str = f"{parsed_entry.todo} " if parsed_entry.todo else ""
|
||||
|
||||
# Prepend ancestor headings, filename as top heading to entry for context
|
||||
ancestors_trail = " / ".join(parsed_entry.ancestors) or Path(entry_to_file_map[parsed_entry])
|
||||
if parsed_entry.heading:
|
||||
heading = f"* Path: {ancestors_trail}\n** {todo_str}{parsed_entry.heading}."
|
||||
else:
|
||||
heading = f"* Path: {ancestors_trail}."
|
||||
# Set base level to current org-node tree's root heading level
|
||||
if not entry_heading and parsed_entry.level > 0:
|
||||
base_level = parsed_entry.level
|
||||
# Indent entry by 1 heading level as ancestry is prepended as top level heading
|
||||
heading = f"{'*' * (parsed_entry.level-base_level+2)} {todo_str}" if parsed_entry.level > 0 else ""
|
||||
if parsed_entry.heading:
|
||||
heading += f"{parsed_entry.heading}."
|
||||
|
||||
compiled = heading
|
||||
if state.verbose > 2:
|
||||
logger.debug(f"Title: {heading}")
|
||||
# Prepend ancestor headings, filename as top heading to root parent entry for context
|
||||
# Children nodes do not need ancestors trail as root parent node will have it
|
||||
if not entry_heading:
|
||||
ancestors_trail = " / ".join(parsed_entry.ancestors) or Path(entry_to_file_map[parsed_entry])
|
||||
heading = f"* Path: {ancestors_trail}\n{heading}" if heading else f"* Path: {ancestors_trail}."
|
||||
|
||||
if parsed_entry.tags:
|
||||
tags_str = " ".join(parsed_entry.tags)
|
||||
compiled += f"\t {tags_str}."
|
||||
if state.verbose > 2:
|
||||
logger.debug(f"Tags: {tags_str}")
|
||||
compiled = heading
|
||||
|
||||
if parsed_entry.closed:
|
||||
compiled += f'\n Closed on {parsed_entry.closed.strftime("%Y-%m-%d")}.'
|
||||
if state.verbose > 2:
|
||||
logger.debug(f'Closed: {parsed_entry.closed.strftime("%Y-%m-%d")}')
|
||||
if parsed_entry.tags:
|
||||
tags_str = " ".join(parsed_entry.tags)
|
||||
compiled += f"\t {tags_str}."
|
||||
|
||||
if parsed_entry.scheduled:
|
||||
compiled += f'\n Scheduled for {parsed_entry.scheduled.strftime("%Y-%m-%d")}.'
|
||||
if state.verbose > 2:
|
||||
logger.debug(f'Scheduled: {parsed_entry.scheduled.strftime("%Y-%m-%d")}')
|
||||
if parsed_entry.closed:
|
||||
compiled += f'\n Closed on {parsed_entry.closed.strftime("%Y-%m-%d")}.'
|
||||
|
||||
if parsed_entry.hasBody:
|
||||
compiled += f"\n {parsed_entry.body}"
|
||||
if state.verbose > 2:
|
||||
logger.debug(f"Body: {parsed_entry.body}")
|
||||
if parsed_entry.scheduled:
|
||||
compiled += f'\n Scheduled for {parsed_entry.scheduled.strftime("%Y-%m-%d")}.'
|
||||
|
||||
if compiled:
|
||||
if parsed_entry.hasBody:
|
||||
compiled += f"\n {parsed_entry.body}"
|
||||
|
||||
# Add the sub-entry contents to the entry
|
||||
entry_compiled += f"{compiled}"
|
||||
entry_raw += f"{parsed_entry}"
|
||||
if not entry_heading:
|
||||
entry_heading = heading
|
||||
|
||||
if entry_compiled:
|
||||
entries.append(
|
||||
Entry(
|
||||
compiled=compiled,
|
||||
raw=f"{parsed_entry}",
|
||||
heading=f"{heading}",
|
||||
compiled=entry_compiled,
|
||||
raw=entry_raw,
|
||||
heading=f"{entry_heading}",
|
||||
file=f"{entry_to_file_map[parsed_entry]}",
|
||||
)
|
||||
)
|
||||
|
||||
return entries
|
||||
|
||||
@staticmethod
|
||||
def convert_org_entries_to_jsonl(entries: Iterable[Entry]) -> str:
|
||||
"Convert each Org-Mode entry to JSON and collate as JSONL"
|
||||
return "".join([f"{entry_dict.to_json()}\n" for entry_dict in entries])
|
||||
|
|
|
@ -37,7 +37,7 @@ import datetime
|
|||
import re
|
||||
from os.path import relpath
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
indent_regex = re.compile(r"^ *")
|
||||
|
||||
|
@ -58,7 +58,7 @@ def makelist_with_filepath(filename):
|
|||
return makelist(f, filename)
|
||||
|
||||
|
||||
def makelist(file, filename):
|
||||
def makelist(file, filename) -> List["Orgnode"]:
|
||||
"""
|
||||
Read an org-mode file and return a list of Orgnode objects
|
||||
created from this file.
|
||||
|
@ -80,16 +80,16 @@ def makelist(file, filename):
|
|||
} # populated from #+SEQ_TODO line
|
||||
level = ""
|
||||
heading = ""
|
||||
ancestor_headings = []
|
||||
ancestor_headings: List[str] = []
|
||||
bodytext = ""
|
||||
introtext = ""
|
||||
tags = list() # set of all tags in headline
|
||||
closed_date = ""
|
||||
sched_date = ""
|
||||
deadline_date = ""
|
||||
logbook = list()
|
||||
tags: List[str] = list() # set of all tags in headline
|
||||
closed_date: datetime.date = None
|
||||
sched_date: datetime.date = None
|
||||
deadline_date: datetime.date = None
|
||||
logbook: List[Tuple[datetime.datetime, datetime.datetime]] = list()
|
||||
nodelist: List[Orgnode] = list()
|
||||
property_map = dict()
|
||||
property_map: Dict[str, str] = dict()
|
||||
in_properties_drawer = False
|
||||
in_logbook_drawer = False
|
||||
file_title = f"{filename}"
|
||||
|
@ -102,13 +102,13 @@ def makelist(file, filename):
|
|||
thisNode = Orgnode(level, heading, bodytext, tags, ancestor_headings)
|
||||
if closed_date:
|
||||
thisNode.closed = closed_date
|
||||
closed_date = ""
|
||||
closed_date = None
|
||||
if sched_date:
|
||||
thisNode.scheduled = sched_date
|
||||
sched_date = ""
|
||||
sched_date = None
|
||||
if deadline_date:
|
||||
thisNode.deadline = deadline_date
|
||||
deadline_date = ""
|
||||
deadline_date = None
|
||||
if logbook:
|
||||
thisNode.logbook = logbook
|
||||
logbook = list()
|
||||
|
@ -116,7 +116,7 @@ def makelist(file, filename):
|
|||
nodelist.append(thisNode)
|
||||
property_map = {"LINE": f"file:{normalize_filename(filename)}::{ctr}"}
|
||||
previous_level = level
|
||||
previous_heading = heading
|
||||
previous_heading: str = heading
|
||||
level = heading_search.group(1)
|
||||
heading = heading_search.group(2)
|
||||
bodytext = ""
|
||||
|
@ -495,12 +495,13 @@ class Orgnode(object):
|
|||
if self._priority:
|
||||
n = n + "[#" + self._priority + "] "
|
||||
n = n + self._heading
|
||||
n = "%-60s " % n # hack - tags will start in column 62
|
||||
closecolon = ""
|
||||
for t in self._tags:
|
||||
n = n + ":" + t
|
||||
closecolon = ":"
|
||||
n = n + closecolon
|
||||
if self._tags:
|
||||
n = "%-60s " % n # hack - tags will start in column 62
|
||||
closecolon = ""
|
||||
for t in self._tags:
|
||||
n = n + ":" + t
|
||||
closecolon = ":"
|
||||
n = n + closecolon
|
||||
n = n + "\n"
|
||||
|
||||
# Get body indentation from first line of body
|
||||
|
|
|
@ -32,8 +32,8 @@ class PdfToEntries(TextToEntries):
|
|||
deletion_file_names = None
|
||||
|
||||
# Extract Entries from specified Pdf files
|
||||
with timer("Parse entries from PDF files into dictionaries", logger):
|
||||
current_entries = PdfToEntries.convert_pdf_entries_to_maps(*PdfToEntries.extract_pdf_entries(files))
|
||||
with timer("Extract entries from specified PDF files", logger):
|
||||
current_entries = PdfToEntries.extract_pdf_entries(files)
|
||||
|
||||
# Split entries by max tokens supported by model
|
||||
with timer("Split entries by max token size supported by model", logger):
|
||||
|
@ -55,11 +55,11 @@ class PdfToEntries(TextToEntries):
|
|||
return num_new_embeddings, num_deleted_embeddings
|
||||
|
||||
@staticmethod
|
||||
def extract_pdf_entries(pdf_files):
|
||||
def extract_pdf_entries(pdf_files) -> List[Entry]:
|
||||
"""Extract entries by page from specified PDF files"""
|
||||
|
||||
entries = []
|
||||
entry_to_location_map = []
|
||||
entries: List[str] = []
|
||||
entry_to_location_map: List[Tuple[str, str]] = []
|
||||
for pdf_file in pdf_files:
|
||||
try:
|
||||
# Write the PDF file to a temporary file, as it is stored in byte format in the pdf_file object and the PDF Loader expects a file path
|
||||
|
@ -83,7 +83,7 @@ class PdfToEntries(TextToEntries):
|
|||
if os.path.exists(f"{tmp_file}"):
|
||||
os.remove(f"{tmp_file}")
|
||||
|
||||
return entries, dict(entry_to_location_map)
|
||||
return PdfToEntries.convert_pdf_entries_to_maps(entries, dict(entry_to_location_map))
|
||||
|
||||
@staticmethod
|
||||
def convert_pdf_entries_to_maps(parsed_entries: List[str], entry_to_file_map) -> List[Entry]:
|
||||
|
@ -106,8 +106,3 @@ class PdfToEntries(TextToEntries):
|
|||
logger.debug(f"Converted {len(parsed_entries)} PDF entries to dictionaries")
|
||||
|
||||
return entries
|
||||
|
||||
@staticmethod
|
||||
def convert_pdf_maps_to_jsonl(entries: List[Entry]):
|
||||
"Convert each PDF entry to JSON and collate as JSONL"
|
||||
return "".join([f"{entry.to_json()}\n" for entry in entries])
|
||||
|
|
|
@ -42,8 +42,8 @@ class PlaintextToEntries(TextToEntries):
|
|||
logger.warning(e, exc_info=True)
|
||||
|
||||
# Extract Entries from specified plaintext files
|
||||
with timer("Parse entries from plaintext files", logger):
|
||||
current_entries = PlaintextToEntries.convert_plaintext_entries_to_maps(files)
|
||||
with timer("Parse entries from specified Plaintext files", logger):
|
||||
current_entries = PlaintextToEntries.extract_plaintext_entries(files)
|
||||
|
||||
# Split entries by max tokens supported by model
|
||||
with timer("Split entries by max token size supported by model", logger):
|
||||
|
@ -74,7 +74,7 @@ class PlaintextToEntries(TextToEntries):
|
|||
return soup.get_text(strip=True, separator="\n")
|
||||
|
||||
@staticmethod
|
||||
def convert_plaintext_entries_to_maps(entry_to_file_map: dict) -> List[Entry]:
|
||||
def extract_plaintext_entries(entry_to_file_map: dict[str, str]) -> List[Entry]:
|
||||
"Convert each plaintext entries into a dictionary"
|
||||
entries = []
|
||||
for file, entry in entry_to_file_map.items():
|
||||
|
@ -87,8 +87,3 @@ class PlaintextToEntries(TextToEntries):
|
|||
)
|
||||
)
|
||||
return entries
|
||||
|
||||
@staticmethod
|
||||
def convert_entries_to_jsonl(entries: List[Entry]):
|
||||
"Convert each entry to JSON and collate as JSONL"
|
||||
return "".join([f"{entry.to_json()}\n" for entry in entries])
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
import hashlib
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from itertools import repeat
|
||||
from typing import Any, Callable, List, Set, Tuple
|
||||
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from tqdm import tqdm
|
||||
|
||||
from khoj.database.adapters import EntryAdapters, get_user_search_model_or_default
|
||||
|
@ -34,6 +36,27 @@ class TextToEntries(ABC):
|
|||
def hash_func(key: str) -> Callable:
|
||||
return lambda entry: hashlib.md5(bytes(getattr(entry, key), encoding="utf-8")).hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def remove_long_words(text: str, max_word_length: int = 500) -> str:
|
||||
"Remove words longer than max_word_length from text."
|
||||
# Split the string by words, keeping the delimiters
|
||||
splits = re.split(r"(\s+)", text) + [""]
|
||||
words_with_delimiters = list(zip(splits[::2], splits[1::2]))
|
||||
|
||||
# Filter out long words while preserving delimiters in text
|
||||
filtered_text = [
|
||||
f"{word}{delimiter}"
|
||||
for word, delimiter in words_with_delimiters
|
||||
if not word.strip() or len(word.strip()) <= max_word_length
|
||||
]
|
||||
|
||||
return "".join(filtered_text)
|
||||
|
||||
@staticmethod
|
||||
def tokenizer(text: str) -> List[str]:
|
||||
"Tokenize text into words."
|
||||
return text.split()
|
||||
|
||||
@staticmethod
|
||||
def split_entries_by_max_tokens(
|
||||
entries: List[Entry], max_tokens: int = 256, max_word_length: int = 500
|
||||
|
@ -44,24 +67,30 @@ class TextToEntries(ABC):
|
|||
if is_none_or_empty(entry.compiled):
|
||||
continue
|
||||
|
||||
# Split entry into words
|
||||
compiled_entry_words = [word for word in entry.compiled.split(" ") if word != ""]
|
||||
|
||||
# Drop long words instead of having entry truncated to maintain quality of entry processed by models
|
||||
compiled_entry_words = [word for word in compiled_entry_words if len(word) <= max_word_length]
|
||||
# Split entry into chunks of max_tokens
|
||||
# Use chunking preference order: paragraphs > sentences > words > characters
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=max_tokens,
|
||||
separators=["\n\n", "\n", "!", "?", ".", " ", "\t", ""],
|
||||
keep_separator=True,
|
||||
length_function=lambda chunk: len(TextToEntries.tokenizer(chunk)),
|
||||
chunk_overlap=0,
|
||||
)
|
||||
chunked_entry_chunks = text_splitter.split_text(entry.compiled)
|
||||
corpus_id = uuid.uuid4()
|
||||
|
||||
# Split entry into chunks of max tokens
|
||||
for chunk_index in range(0, len(compiled_entry_words), max_tokens):
|
||||
compiled_entry_words_chunk = compiled_entry_words[chunk_index : chunk_index + max_tokens]
|
||||
compiled_entry_chunk = " ".join(compiled_entry_words_chunk)
|
||||
|
||||
# Create heading prefixed entry from each chunk
|
||||
for chunk_index, compiled_entry_chunk in enumerate(chunked_entry_chunks):
|
||||
# Prepend heading to all other chunks, the first chunk already has heading from original entry
|
||||
if chunk_index > 0:
|
||||
if chunk_index > 0 and entry.heading:
|
||||
# Snip heading to avoid crossing max_tokens limit
|
||||
# Keep last 100 characters of heading as entry heading more important than filename
|
||||
snipped_heading = entry.heading[-100:]
|
||||
compiled_entry_chunk = f"{snipped_heading}.\n{compiled_entry_chunk}"
|
||||
# Prepend snipped heading
|
||||
compiled_entry_chunk = f"{snipped_heading}\n{compiled_entry_chunk}"
|
||||
|
||||
# Drop long words instead of having entry truncated to maintain quality of entry processed by models
|
||||
compiled_entry_chunk = TextToEntries.remove_long_words(compiled_entry_chunk, max_word_length)
|
||||
|
||||
# Clean entry of unwanted characters like \0 character
|
||||
compiled_entry_chunk = TextToEntries.clean_field(compiled_entry_chunk)
|
||||
|
@ -160,7 +189,7 @@ class TextToEntries(ABC):
|
|||
new_dates = []
|
||||
with timer("Indexed dates from added entries in", logger):
|
||||
for added_entry in added_entries:
|
||||
dates_in_entries = zip(self.date_filter.extract_dates(added_entry.raw), repeat(added_entry))
|
||||
dates_in_entries = zip(self.date_filter.extract_dates(added_entry.compiled), repeat(added_entry))
|
||||
dates_to_create = [
|
||||
EntryDates(date=date, entry=added_entry)
|
||||
for date, added_entry in dates_in_entries
|
||||
|
@ -244,11 +273,6 @@ class TextToEntries(ABC):
|
|||
|
||||
return entries_with_ids
|
||||
|
||||
@staticmethod
|
||||
def convert_text_maps_to_jsonl(entries: List[Entry]) -> str:
|
||||
# Convert each entry to JSON and write to JSONL file
|
||||
return "".join([f"{entry.to_json()}\n" for entry in entries])
|
||||
|
||||
@staticmethod
|
||||
def clean_field(field: str) -> str:
|
||||
return field.replace("\0", "") if not is_none_or_empty(field) else ""
|
||||
|
|
|
@ -489,7 +489,7 @@ async def chat(
|
|||
common: CommonQueryParams,
|
||||
q: str,
|
||||
n: Optional[int] = 5,
|
||||
d: Optional[float] = 0.18,
|
||||
d: Optional[float] = 0.22,
|
||||
stream: Optional[bool] = False,
|
||||
title: Optional[str] = None,
|
||||
conversation_id: Optional[int] = None,
|
||||
|
|
|
@ -306,7 +306,7 @@ def test_notes_search(client, search_config: SearchConfig, sample_org_data, defa
|
|||
user_query = quote("How to git install application?")
|
||||
|
||||
# Act
|
||||
response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true&max_distance=0.18", headers=headers)
|
||||
response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true&max_distance=0.22", headers=headers)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
|
@ -325,7 +325,7 @@ def test_notes_search_no_results(client, search_config: SearchConfig, sample_org
|
|||
user_query = quote("How to find my goat?")
|
||||
|
||||
# Act
|
||||
response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true&max_distance=0.18", headers=headers)
|
||||
response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true&max_distance=0.22", headers=headers)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
|
@ -409,7 +409,7 @@ def test_notes_search_requires_parent_context(
|
|||
user_query = quote("Install Khoj on Emacs")
|
||||
|
||||
# Act
|
||||
response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true&max_distance=0.18", headers=headers)
|
||||
response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true&max_distance=0.22", headers=headers)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
|
@ -7,8 +6,8 @@ from khoj.utils.fs_syncer import get_markdown_files
|
|||
from khoj.utils.rawconfig import TextContentConfig
|
||||
|
||||
|
||||
def test_markdown_file_with_no_headings_to_jsonl(tmp_path):
|
||||
"Convert files with no heading to jsonl."
|
||||
def test_extract_markdown_with_no_headings(tmp_path):
|
||||
"Convert markdown file with no heading to entry format."
|
||||
# Arrange
|
||||
entry = f"""
|
||||
- Bullet point 1
|
||||
|
@ -17,30 +16,24 @@ def test_markdown_file_with_no_headings_to_jsonl(tmp_path):
|
|||
data = {
|
||||
f"{tmp_path}": entry,
|
||||
}
|
||||
expected_heading = f"# {tmp_path.stem}"
|
||||
expected_heading = f"# {tmp_path}"
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Markdown files
|
||||
entry_nodes, file_to_entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data)
|
||||
|
||||
# Process Each Entry from All Notes Files
|
||||
jsonl_string = MarkdownToEntries.convert_markdown_maps_to_jsonl(
|
||||
MarkdownToEntries.convert_markdown_entries_to_maps(entry_nodes, file_to_entries)
|
||||
)
|
||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3)
|
||||
|
||||
# Assert
|
||||
assert len(jsonl_data) == 1
|
||||
assert len(entries) == 1
|
||||
# Ensure raw entry with no headings do not get heading prefix prepended
|
||||
assert not jsonl_data[0]["raw"].startswith("#")
|
||||
assert not entries[0].raw.startswith("#")
|
||||
# Ensure compiled entry has filename prepended as top level heading
|
||||
assert expected_heading in jsonl_data[0]["compiled"]
|
||||
assert entries[0].compiled.startswith(expected_heading)
|
||||
# Ensure compiled entry also includes the file name
|
||||
assert str(tmp_path) in jsonl_data[0]["compiled"]
|
||||
assert str(tmp_path) in entries[0].compiled
|
||||
|
||||
|
||||
def test_single_markdown_entry_to_jsonl(tmp_path):
|
||||
"Convert markdown entry from single file to jsonl."
|
||||
def test_extract_single_markdown_entry(tmp_path):
|
||||
"Convert markdown from single file to entry format."
|
||||
# Arrange
|
||||
entry = f"""### Heading
|
||||
\t\r
|
||||
|
@ -52,20 +45,14 @@ def test_single_markdown_entry_to_jsonl(tmp_path):
|
|||
|
||||
# Act
|
||||
# Extract Entries from specified Markdown files
|
||||
entries, entry_to_file_map = MarkdownToEntries.extract_markdown_entries(markdown_files=data)
|
||||
|
||||
# Process Each Entry from All Notes Files
|
||||
jsonl_string = MarkdownToEntries.convert_markdown_maps_to_jsonl(
|
||||
MarkdownToEntries.convert_markdown_entries_to_maps(entries, entry_to_file_map)
|
||||
)
|
||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3)
|
||||
|
||||
# Assert
|
||||
assert len(jsonl_data) == 1
|
||||
assert len(entries) == 1
|
||||
|
||||
|
||||
def test_multiple_markdown_entries_to_jsonl(tmp_path):
|
||||
"Convert multiple markdown entries from single file to jsonl."
|
||||
def test_extract_multiple_markdown_entries(tmp_path):
|
||||
"Convert multiple markdown from single file to entry format."
|
||||
# Arrange
|
||||
entry = f"""
|
||||
### Heading 1
|
||||
|
@ -81,19 +68,139 @@ def test_multiple_markdown_entries_to_jsonl(tmp_path):
|
|||
|
||||
# Act
|
||||
# Extract Entries from specified Markdown files
|
||||
entry_strings, entry_to_file_map = MarkdownToEntries.extract_markdown_entries(markdown_files=data)
|
||||
entries = MarkdownToEntries.convert_markdown_entries_to_maps(entry_strings, entry_to_file_map)
|
||||
|
||||
# Process Each Entry from All Notes Files
|
||||
jsonl_string = MarkdownToEntries.convert_markdown_maps_to_jsonl(entries)
|
||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3)
|
||||
|
||||
# Assert
|
||||
assert len(jsonl_data) == 2
|
||||
assert len(entries) == 2
|
||||
# Ensure entry compiled strings include the markdown files they originate from
|
||||
assert all([tmp_path.stem in entry.compiled for entry in entries])
|
||||
|
||||
|
||||
def test_extract_entries_with_different_level_headings(tmp_path):
|
||||
"Extract markdown entries with different level headings."
|
||||
# Arrange
|
||||
entry = f"""
|
||||
# Heading 1
|
||||
## Sub-Heading 1.1
|
||||
# Heading 2
|
||||
"""
|
||||
data = {
|
||||
f"{tmp_path}": entry,
|
||||
}
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Markdown files
|
||||
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3)
|
||||
|
||||
# Assert
|
||||
assert len(entries) == 2
|
||||
assert entries[0].raw == "# Heading 1\n## Sub-Heading 1.1", "Ensure entry includes heading ancestory"
|
||||
assert entries[1].raw == "# Heading 2\n"
|
||||
|
||||
|
||||
def test_extract_entries_with_non_incremental_heading_levels(tmp_path):
|
||||
"Extract markdown entries when deeper child level before shallower child level."
|
||||
# Arrange
|
||||
entry = f"""
|
||||
# Heading 1
|
||||
#### Sub-Heading 1.1
|
||||
## Sub-Heading 1.2
|
||||
# Heading 2
|
||||
"""
|
||||
data = {
|
||||
f"{tmp_path}": entry,
|
||||
}
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Markdown files
|
||||
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3)
|
||||
|
||||
# Assert
|
||||
assert len(entries) == 3
|
||||
assert entries[0].raw == "# Heading 1\n#### Sub-Heading 1.1", "Ensure entry includes heading ancestory"
|
||||
assert entries[1].raw == "# Heading 1\n## Sub-Heading 1.2", "Ensure entry includes heading ancestory"
|
||||
assert entries[2].raw == "# Heading 2\n"
|
||||
|
||||
|
||||
def test_extract_entries_with_text_before_headings(tmp_path):
|
||||
"Extract markdown entries with some text before any headings."
|
||||
# Arrange
|
||||
entry = f"""
|
||||
Text before headings
|
||||
# Heading 1
|
||||
body line 1
|
||||
## Heading 2
|
||||
body line 2
|
||||
"""
|
||||
data = {
|
||||
f"{tmp_path}": entry,
|
||||
}
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Markdown files
|
||||
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3)
|
||||
|
||||
# Assert
|
||||
assert len(entries) == 3
|
||||
assert entries[0].raw == "\nText before headings"
|
||||
assert entries[1].raw == "# Heading 1\nbody line 1"
|
||||
assert entries[2].raw == "# Heading 1\n## Heading 2\nbody line 2\n", "Ensure raw entry includes heading ancestory"
|
||||
|
||||
|
||||
def test_parse_markdown_file_into_single_entry_if_small(tmp_path):
|
||||
"Parse markdown file into single entry if it fits within the token limits."
|
||||
# Arrange
|
||||
entry = f"""
|
||||
# Heading 1
|
||||
body line 1
|
||||
## Subheading 1.1
|
||||
body line 1.1
|
||||
"""
|
||||
data = {
|
||||
f"{tmp_path}": entry,
|
||||
}
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Markdown files
|
||||
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=12)
|
||||
|
||||
# Assert
|
||||
assert len(entries) == 1
|
||||
assert entries[0].raw == entry
|
||||
|
||||
|
||||
def test_parse_markdown_entry_with_children_as_single_entry_if_small(tmp_path):
|
||||
"Parse markdown entry with child headings as single entry if it fits within the tokens limits."
|
||||
# Arrange
|
||||
entry = f"""
|
||||
# Heading 1
|
||||
body line 1
|
||||
## Subheading 1.1
|
||||
body line 1.1
|
||||
# Heading 2
|
||||
body line 2
|
||||
## Subheading 2.1
|
||||
longer body line 2.1
|
||||
"""
|
||||
data = {
|
||||
f"{tmp_path}": entry,
|
||||
}
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Markdown files
|
||||
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=12)
|
||||
|
||||
# Assert
|
||||
assert len(entries) == 3
|
||||
assert (
|
||||
entries[0].raw == "# Heading 1\nbody line 1\n## Subheading 1.1\nbody line 1.1"
|
||||
), "First entry includes children headings"
|
||||
assert entries[1].raw == "# Heading 2\nbody line 2", "Second entry does not include children headings"
|
||||
assert (
|
||||
entries[2].raw == "# Heading 2\n## Subheading 2.1\nlonger body line 2.1\n"
|
||||
), "Third entry is second entries child heading"
|
||||
|
||||
|
||||
def test_get_markdown_files(tmp_path):
|
||||
"Ensure Markdown files specified via input-filter, input-files extracted"
|
||||
# Arrange
|
||||
|
@ -131,27 +238,6 @@ def test_get_markdown_files(tmp_path):
|
|||
assert set(extracted_org_files.keys()) == expected_files
|
||||
|
||||
|
||||
def test_extract_entries_with_different_level_headings(tmp_path):
|
||||
"Extract markdown entries with different level headings."
|
||||
# Arrange
|
||||
entry = f"""
|
||||
# Heading 1
|
||||
## Heading 2
|
||||
"""
|
||||
data = {
|
||||
f"{tmp_path}": entry,
|
||||
}
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Markdown files
|
||||
entries, _ = MarkdownToEntries.extract_markdown_entries(markdown_files=data)
|
||||
|
||||
# Assert
|
||||
assert len(entries) == 2
|
||||
assert entries[0] == "# Heading 1"
|
||||
assert entries[1] == "## Heading 2"
|
||||
|
||||
|
||||
# Helper Functions
|
||||
def create_file(tmp_path: Path, entry=None, filename="test.md"):
|
||||
markdown_file = tmp_path / filename
|
||||
|
|
|
@ -56,7 +56,7 @@ def test_index_update_with_user2_inaccessible_user1(client, api_user2: KhojApiUs
|
|||
|
||||
# Assert
|
||||
assert update_response.status_code == 200
|
||||
assert len(results) == 5
|
||||
assert len(results) == 3
|
||||
for result in results:
|
||||
assert result["additional"]["file"] not in source_file_symbol
|
||||
|
||||
|
|
|
@ -470,10 +470,6 @@ async def test_websearch_with_operators(chat_client):
|
|||
["site:reddit.com" in response for response in responses]
|
||||
), "Expected a search query to include site:reddit.com but got: " + str(responses)
|
||||
|
||||
assert any(
|
||||
["after:2024/04/01" in response for response in responses]
|
||||
), "Expected a search query to include after:2024/04/01 but got: " + str(responses)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.anyio
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import json
|
||||
import os
|
||||
import re
|
||||
|
||||
from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
|
||||
from khoj.processor.content.text_to_entries import TextToEntries
|
||||
|
@ -8,7 +8,7 @@ from khoj.utils.helpers import is_none_or_empty
|
|||
from khoj.utils.rawconfig import Entry, TextContentConfig
|
||||
|
||||
|
||||
def test_configure_heading_entry_to_jsonl(tmp_path):
|
||||
def test_configure_indexing_heading_only_entries(tmp_path):
|
||||
"""Ensure entries with empty body are ignored, unless explicitly configured to index heading entries.
|
||||
Property drawers not considered Body. Ignore control characters for evaluating if Body empty."""
|
||||
# Arrange
|
||||
|
@ -26,24 +26,21 @@ def test_configure_heading_entry_to_jsonl(tmp_path):
|
|||
for index_heading_entries in [True, False]:
|
||||
# Act
|
||||
# Extract entries into jsonl from specified Org files
|
||||
jsonl_string = OrgToEntries.convert_org_entries_to_jsonl(
|
||||
OrgToEntries.convert_org_nodes_to_entries(
|
||||
*OrgToEntries.extract_org_entries(org_files=data), index_heading_entries=index_heading_entries
|
||||
)
|
||||
entries = OrgToEntries.extract_org_entries(
|
||||
org_files=data, index_heading_entries=index_heading_entries, max_tokens=3
|
||||
)
|
||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||
|
||||
# Assert
|
||||
if index_heading_entries:
|
||||
# Entry with empty body indexed when index_heading_entries set to True
|
||||
assert len(jsonl_data) == 1
|
||||
assert len(entries) == 1
|
||||
else:
|
||||
# Entry with empty body ignored when index_heading_entries set to False
|
||||
assert is_none_or_empty(jsonl_data)
|
||||
assert is_none_or_empty(entries)
|
||||
|
||||
|
||||
def test_entry_split_when_exceeds_max_words():
|
||||
"Ensure entries with compiled words exceeding max_words are split."
|
||||
def test_entry_split_when_exceeds_max_tokens():
|
||||
"Ensure entries with compiled words exceeding max_tokens are split."
|
||||
# Arrange
|
||||
tmp_path = "/tmp/test.org"
|
||||
entry = f"""*** Heading
|
||||
|
@ -57,29 +54,26 @@ def test_entry_split_when_exceeds_max_words():
|
|||
|
||||
# Act
|
||||
# Extract Entries from specified Org files
|
||||
entries, entry_to_file_map = OrgToEntries.extract_org_entries(org_files=data)
|
||||
entries = OrgToEntries.extract_org_entries(org_files=data)
|
||||
|
||||
# Split each entry from specified Org files by max words
|
||||
jsonl_string = OrgToEntries.convert_org_entries_to_jsonl(
|
||||
TextToEntries.split_entries_by_max_tokens(
|
||||
OrgToEntries.convert_org_nodes_to_entries(entries, entry_to_file_map), max_tokens=4
|
||||
)
|
||||
)
|
||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||
# Split each entry from specified Org files by max tokens
|
||||
entries = TextToEntries.split_entries_by_max_tokens(entries, max_tokens=6)
|
||||
|
||||
# Assert
|
||||
assert len(jsonl_data) == 2
|
||||
# Ensure compiled entries split by max_words start with entry heading (for search context)
|
||||
assert all([entry["compiled"].startswith(expected_heading) for entry in jsonl_data])
|
||||
assert len(entries) == 2
|
||||
# Ensure compiled entries split by max tokens start with entry heading (for search context)
|
||||
assert all([entry.compiled.startswith(expected_heading) for entry in entries])
|
||||
|
||||
|
||||
def test_entry_split_drops_large_words():
|
||||
"Ensure entries drops words larger than specified max word length from compiled version."
|
||||
# Arrange
|
||||
entry_text = f"""*** Heading
|
||||
\t\r
|
||||
Body Line 1
|
||||
"""
|
||||
entry_text = f"""First Line
|
||||
dog=1\n\r\t
|
||||
cat=10
|
||||
car=4
|
||||
book=2
|
||||
"""
|
||||
entry = Entry(raw=entry_text, compiled=entry_text)
|
||||
|
||||
# Act
|
||||
|
@ -87,11 +81,158 @@ def test_entry_split_drops_large_words():
|
|||
processed_entry = TextToEntries.split_entries_by_max_tokens([entry], max_word_length=5)[0]
|
||||
|
||||
# Assert
|
||||
# "Heading" dropped from compiled version because its over the set max word limit
|
||||
assert len(processed_entry.compiled.split()) == len(entry_text.split()) - 1
|
||||
# Ensure words larger than max word length are dropped
|
||||
# Ensure newline characters are considered as word boundaries for splitting words. See #620
|
||||
words_to_keep = ["First", "Line", "dog=1", "car=4"]
|
||||
words_to_drop = ["cat=10", "book=2"]
|
||||
assert all([word for word in words_to_keep if word in processed_entry.compiled])
|
||||
assert not any([word for word in words_to_drop if word in processed_entry.compiled])
|
||||
assert len(processed_entry.compiled.split()) == len(entry_text.split()) - 2
|
||||
|
||||
|
||||
def test_entry_with_body_to_jsonl(tmp_path):
|
||||
def test_parse_org_file_into_single_entry_if_small(tmp_path):
|
||||
"Parse org file into single entry if it fits within the token limits."
|
||||
# Arrange
|
||||
original_entry = f"""
|
||||
* Heading 1
|
||||
body line 1
|
||||
** Subheading 1.1
|
||||
body line 1.1
|
||||
"""
|
||||
data = {
|
||||
f"{tmp_path}": original_entry,
|
||||
}
|
||||
expected_entry = f"""
|
||||
* Heading 1
|
||||
body line 1
|
||||
|
||||
** Subheading 1.1
|
||||
body line 1.1
|
||||
|
||||
""".lstrip()
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Org files
|
||||
extracted_entries = OrgToEntries.extract_org_entries(org_files=data, max_tokens=12)
|
||||
for entry in extracted_entries:
|
||||
entry.raw = clean(entry.raw)
|
||||
|
||||
# Assert
|
||||
assert len(extracted_entries) == 1
|
||||
assert entry.raw == expected_entry
|
||||
|
||||
|
||||
def test_parse_org_entry_with_children_as_single_entry_if_small(tmp_path):
|
||||
"Parse org entry with child headings as single entry only if it fits within the tokens limits."
|
||||
# Arrange
|
||||
entry = f"""
|
||||
* Heading 1
|
||||
body line 1
|
||||
** Subheading 1.1
|
||||
body line 1.1
|
||||
* Heading 2
|
||||
body line 2
|
||||
** Subheading 2.1
|
||||
longer body line 2.1
|
||||
"""
|
||||
data = {
|
||||
f"{tmp_path}": entry,
|
||||
}
|
||||
first_expected_entry = f"""
|
||||
* Path: {tmp_path}
|
||||
** Heading 1.
|
||||
body line 1
|
||||
|
||||
*** Subheading 1.1.
|
||||
body line 1.1
|
||||
|
||||
""".lstrip()
|
||||
second_expected_entry = f"""
|
||||
* Path: {tmp_path}
|
||||
** Heading 2.
|
||||
body line 2
|
||||
|
||||
""".lstrip()
|
||||
third_expected_entry = f"""
|
||||
* Path: {tmp_path} / Heading 2
|
||||
** Subheading 2.1.
|
||||
longer body line 2.1
|
||||
|
||||
""".lstrip()
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Org files
|
||||
extracted_entries = OrgToEntries.extract_org_entries(org_files=data, max_tokens=12)
|
||||
|
||||
# Assert
|
||||
assert len(extracted_entries) == 3
|
||||
assert extracted_entries[0].compiled == first_expected_entry, "First entry includes children headings"
|
||||
assert extracted_entries[1].compiled == second_expected_entry, "Second entry does not include children headings"
|
||||
assert extracted_entries[2].compiled == third_expected_entry, "Third entry is second entries child heading"
|
||||
|
||||
|
||||
def test_separate_sibling_org_entries_if_all_cannot_fit_in_token_limit(tmp_path):
|
||||
"Parse org sibling entries as separate entries only if it fits within the tokens limits."
|
||||
# Arrange
|
||||
entry = f"""
|
||||
* Heading 1
|
||||
body line 1
|
||||
** Subheading 1.1
|
||||
body line 1.1
|
||||
* Heading 2
|
||||
body line 2
|
||||
** Subheading 2.1
|
||||
body line 2.1
|
||||
* Heading 3
|
||||
body line 3
|
||||
** Subheading 3.1
|
||||
body line 3.1
|
||||
"""
|
||||
data = {
|
||||
f"{tmp_path}": entry,
|
||||
}
|
||||
first_expected_entry = f"""
|
||||
* Path: {tmp_path}
|
||||
** Heading 1.
|
||||
body line 1
|
||||
|
||||
*** Subheading 1.1.
|
||||
body line 1.1
|
||||
|
||||
""".lstrip()
|
||||
second_expected_entry = f"""
|
||||
* Path: {tmp_path}
|
||||
** Heading 2.
|
||||
body line 2
|
||||
|
||||
*** Subheading 2.1.
|
||||
body line 2.1
|
||||
|
||||
""".lstrip()
|
||||
third_expected_entry = f"""
|
||||
* Path: {tmp_path}
|
||||
** Heading 3.
|
||||
body line 3
|
||||
|
||||
*** Subheading 3.1.
|
||||
body line 3.1
|
||||
|
||||
""".lstrip()
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Org files
|
||||
# Max tokens = 30 is in the middle of 2 entry (24 tokens) and 3 entry (36 tokens) tokens boundary
|
||||
# Where each sibling entry contains 12 tokens per sibling entry * 3 entries = 36 tokens
|
||||
extracted_entries = OrgToEntries.extract_org_entries(org_files=data, max_tokens=30)
|
||||
|
||||
# Assert
|
||||
assert len(extracted_entries) == 3
|
||||
assert extracted_entries[0].compiled == first_expected_entry, "First entry includes children headings"
|
||||
assert extracted_entries[1].compiled == second_expected_entry, "Second entry includes children headings"
|
||||
assert extracted_entries[2].compiled == third_expected_entry, "Third entry includes children headings"
|
||||
|
||||
|
||||
def test_entry_with_body_to_entry(tmp_path):
|
||||
"Ensure entries with valid body text are loaded."
|
||||
# Arrange
|
||||
entry = f"""*** Heading
|
||||
|
@ -107,19 +248,13 @@ def test_entry_with_body_to_jsonl(tmp_path):
|
|||
|
||||
# Act
|
||||
# Extract Entries from specified Org files
|
||||
entries, entry_to_file_map = OrgToEntries.extract_org_entries(org_files=data)
|
||||
|
||||
# Process Each Entry from All Notes Files
|
||||
jsonl_string = OrgToEntries.convert_org_entries_to_jsonl(
|
||||
OrgToEntries.convert_org_nodes_to_entries(entries, entry_to_file_map)
|
||||
)
|
||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||
entries = OrgToEntries.extract_org_entries(org_files=data, max_tokens=3)
|
||||
|
||||
# Assert
|
||||
assert len(jsonl_data) == 1
|
||||
assert len(entries) == 1
|
||||
|
||||
|
||||
def test_file_with_entry_after_intro_text_to_jsonl(tmp_path):
|
||||
def test_file_with_entry_after_intro_text_to_entry(tmp_path):
|
||||
"Ensure intro text before any headings is indexed."
|
||||
# Arrange
|
||||
entry = f"""
|
||||
|
@ -134,18 +269,13 @@ Intro text
|
|||
|
||||
# Act
|
||||
# Extract Entries from specified Org files
|
||||
entry_nodes, file_to_entries = OrgToEntries.extract_org_entries(org_files=data)
|
||||
|
||||
# Process Each Entry from All Notes Files
|
||||
entries = OrgToEntries.convert_org_nodes_to_entries(entry_nodes, file_to_entries)
|
||||
jsonl_string = OrgToEntries.convert_org_entries_to_jsonl(entries)
|
||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||
entries = OrgToEntries.extract_org_entries(org_files=data, max_tokens=3)
|
||||
|
||||
# Assert
|
||||
assert len(jsonl_data) == 2
|
||||
assert len(entries) == 2
|
||||
|
||||
|
||||
def test_file_with_no_headings_to_jsonl(tmp_path):
|
||||
def test_file_with_no_headings_to_entry(tmp_path):
|
||||
"Ensure files with no heading, only body text are loaded."
|
||||
# Arrange
|
||||
entry = f"""
|
||||
|
@ -158,15 +288,10 @@ def test_file_with_no_headings_to_jsonl(tmp_path):
|
|||
|
||||
# Act
|
||||
# Extract Entries from specified Org files
|
||||
entry_nodes, file_to_entries = OrgToEntries.extract_org_entries(org_files=data)
|
||||
|
||||
# Process Each Entry from All Notes Files
|
||||
entries = OrgToEntries.convert_org_nodes_to_entries(entry_nodes, file_to_entries)
|
||||
jsonl_string = OrgToEntries.convert_org_entries_to_jsonl(entries)
|
||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||
entries = OrgToEntries.extract_org_entries(org_files=data, max_tokens=3)
|
||||
|
||||
# Assert
|
||||
assert len(jsonl_data) == 1
|
||||
assert len(entries) == 1
|
||||
|
||||
|
||||
def test_get_org_files(tmp_path):
|
||||
|
@ -214,7 +339,8 @@ def test_extract_entries_with_different_level_headings(tmp_path):
|
|||
# Arrange
|
||||
entry = f"""
|
||||
* Heading 1
|
||||
** Heading 2
|
||||
** Sub-Heading 1.1
|
||||
* Heading 2
|
||||
"""
|
||||
data = {
|
||||
f"{tmp_path}": entry,
|
||||
|
@ -222,12 +348,14 @@ def test_extract_entries_with_different_level_headings(tmp_path):
|
|||
|
||||
# Act
|
||||
# Extract Entries from specified Org files
|
||||
entries, _ = OrgToEntries.extract_org_entries(org_files=data)
|
||||
entries = OrgToEntries.extract_org_entries(org_files=data, index_heading_entries=True, max_tokens=3)
|
||||
for entry in entries:
|
||||
entry.raw = clean(f"{entry.raw}")
|
||||
|
||||
# Assert
|
||||
assert len(entries) == 2
|
||||
assert f"{entries[0]}".startswith("* Heading 1")
|
||||
assert f"{entries[1]}".startswith("** Heading 2")
|
||||
assert entries[0].raw == "* Heading 1\n** Sub-Heading 1.1\n", "Ensure entry includes heading ancestory"
|
||||
assert entries[1].raw == "* Heading 2\n"
|
||||
|
||||
|
||||
# Helper Functions
|
||||
|
@ -237,3 +365,8 @@ def create_file(tmp_path, entry=None, filename="test.org"):
|
|||
if entry:
|
||||
org_file.write_text(entry)
|
||||
return org_file
|
||||
|
||||
|
||||
def clean(entry):
|
||||
"Remove properties from entry for easier comparison."
|
||||
return re.sub(r"\n:PROPERTIES:(.*?):END:", "", entry, flags=re.DOTALL)
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
import json
|
||||
import os
|
||||
|
||||
from khoj.processor.content.pdf.pdf_to_entries import PdfToEntries
|
||||
|
@ -15,16 +14,10 @@ def test_single_page_pdf_to_jsonl():
|
|||
pdf_bytes = f.read()
|
||||
|
||||
data = {"tests/data/pdf/singlepage.pdf": pdf_bytes}
|
||||
entries, entry_to_file_map = PdfToEntries.extract_pdf_entries(pdf_files=data)
|
||||
|
||||
# Process Each Entry from All Pdf Files
|
||||
jsonl_string = PdfToEntries.convert_pdf_maps_to_jsonl(
|
||||
PdfToEntries.convert_pdf_entries_to_maps(entries, entry_to_file_map)
|
||||
)
|
||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||
entries = PdfToEntries.extract_pdf_entries(pdf_files=data)
|
||||
|
||||
# Assert
|
||||
assert len(jsonl_data) == 1
|
||||
assert len(entries) == 1
|
||||
|
||||
|
||||
def test_multi_page_pdf_to_jsonl():
|
||||
|
@ -35,16 +28,10 @@ def test_multi_page_pdf_to_jsonl():
|
|||
pdf_bytes = f.read()
|
||||
|
||||
data = {"tests/data/pdf/multipage.pdf": pdf_bytes}
|
||||
entries, entry_to_file_map = PdfToEntries.extract_pdf_entries(pdf_files=data)
|
||||
|
||||
# Process Each Entry from All Pdf Files
|
||||
jsonl_string = PdfToEntries.convert_pdf_maps_to_jsonl(
|
||||
PdfToEntries.convert_pdf_entries_to_maps(entries, entry_to_file_map)
|
||||
)
|
||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||
entries = PdfToEntries.extract_pdf_entries(pdf_files=data)
|
||||
|
||||
# Assert
|
||||
assert len(jsonl_data) == 6
|
||||
assert len(entries) == 6
|
||||
|
||||
|
||||
def test_ocr_page_pdf_to_jsonl():
|
||||
|
@ -55,10 +42,7 @@ def test_ocr_page_pdf_to_jsonl():
|
|||
pdf_bytes = f.read()
|
||||
|
||||
data = {"tests/data/pdf/ocr_samples.pdf": pdf_bytes}
|
||||
entries, entry_to_file_map = PdfToEntries.extract_pdf_entries(pdf_files=data)
|
||||
|
||||
# Process Each Entry from All Pdf Files
|
||||
entries = PdfToEntries.convert_pdf_entries_to_maps(entries, entry_to_file_map)
|
||||
entries = PdfToEntries.extract_pdf_entries(pdf_files=data)
|
||||
|
||||
assert len(entries) == 1
|
||||
assert "playing on a strip of marsh" in entries[0].raw
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
|
@ -11,10 +10,10 @@ from khoj.utils.rawconfig import TextContentConfig
|
|||
def test_plaintext_file(tmp_path):
|
||||
"Convert files with no heading to jsonl."
|
||||
# Arrange
|
||||
entry = f"""
|
||||
raw_entry = f"""
|
||||
Hi, I am a plaintext file and I have some plaintext words.
|
||||
"""
|
||||
plaintextfile = create_file(tmp_path, entry)
|
||||
plaintextfile = create_file(tmp_path, raw_entry)
|
||||
|
||||
filename = plaintextfile.stem
|
||||
|
||||
|
@ -22,25 +21,21 @@ def test_plaintext_file(tmp_path):
|
|||
# Extract Entries from specified plaintext files
|
||||
|
||||
data = {
|
||||
f"{plaintextfile}": entry,
|
||||
f"{plaintextfile}": raw_entry,
|
||||
}
|
||||
|
||||
maps = PlaintextToEntries.convert_plaintext_entries_to_maps(entry_to_file_map=data)
|
||||
entries = PlaintextToEntries.extract_plaintext_entries(entry_to_file_map=data)
|
||||
|
||||
# Convert each entry.file to absolute path to make them JSON serializable
|
||||
for map in maps:
|
||||
map.file = str(Path(map.file).absolute())
|
||||
|
||||
# Process Each Entry from All Notes Files
|
||||
jsonl_string = PlaintextToEntries.convert_entries_to_jsonl(maps)
|
||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||
for entry in entries:
|
||||
entry.file = str(Path(entry.file).absolute())
|
||||
|
||||
# Assert
|
||||
assert len(jsonl_data) == 1
|
||||
assert len(entries) == 1
|
||||
# Ensure raw entry with no headings do not get heading prefix prepended
|
||||
assert not jsonl_data[0]["raw"].startswith("#")
|
||||
assert not entries[0].raw.startswith("#")
|
||||
# Ensure compiled entry has filename prepended as top level heading
|
||||
assert jsonl_data[0]["compiled"] == f"{filename}\n{entry}"
|
||||
assert entries[0].compiled == f"{filename}\n{raw_entry}"
|
||||
|
||||
|
||||
def test_get_plaintext_files(tmp_path):
|
||||
|
@ -98,11 +93,11 @@ def test_parse_html_plaintext_file(content_config, default_user: KhojUser):
|
|||
extracted_plaintext_files = get_plaintext_files(config=config)
|
||||
|
||||
# Act
|
||||
maps = PlaintextToEntries.convert_plaintext_entries_to_maps(extracted_plaintext_files)
|
||||
entries = PlaintextToEntries.extract_plaintext_entries(extracted_plaintext_files)
|
||||
|
||||
# Assert
|
||||
assert len(maps) == 1
|
||||
assert "<div>" not in maps[0].raw
|
||||
assert len(entries) == 1
|
||||
assert "<div>" not in entries[0].raw
|
||||
|
||||
|
||||
# Helper Functions
|
||||
|
|
|
@ -57,18 +57,21 @@ def test_get_org_files_with_org_suffixed_dir_doesnt_raise_error(tmp_path, defaul
|
|||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db
|
||||
def test_text_search_setup_with_empty_file_creates_no_entries(
|
||||
org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog
|
||||
org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser
|
||||
):
|
||||
# Arrange
|
||||
existing_entries = Entry.objects.filter(user=default_user).count()
|
||||
data = get_org_files(org_config_with_only_new_file)
|
||||
|
||||
# Act
|
||||
# Generate notes embeddings during asymmetric setup
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
|
||||
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
|
||||
|
||||
# Assert
|
||||
assert "Deleted 8 entries. Created 0 new entries for user " in caplog.records[-1].message
|
||||
updated_entries = Entry.objects.filter(user=default_user).count()
|
||||
|
||||
assert existing_entries == 2
|
||||
assert updated_entries == 0
|
||||
verify_embeddings(0, default_user)
|
||||
|
||||
|
||||
|
@ -78,6 +81,7 @@ def test_text_indexer_deletes_embedding_before_regenerate(
|
|||
content_config: ContentConfig, default_user: KhojUser, caplog
|
||||
):
|
||||
# Arrange
|
||||
existing_entries = Entry.objects.filter(user=default_user).count()
|
||||
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
|
||||
data = get_org_files(org_config)
|
||||
|
||||
|
@ -87,30 +91,18 @@ def test_text_indexer_deletes_embedding_before_regenerate(
|
|||
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
|
||||
|
||||
# Assert
|
||||
updated_entries = Entry.objects.filter(user=default_user).count()
|
||||
assert existing_entries == 2
|
||||
assert updated_entries == 2
|
||||
assert "Deleting all entries for file type org" in caplog.text
|
||||
assert "Deleted 8 entries. Created 13 new entries for user " in caplog.records[-1].message
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db
|
||||
def test_text_search_setup_batch_processes(content_config: ContentConfig, default_user: KhojUser, caplog):
|
||||
# Arrange
|
||||
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
|
||||
data = get_org_files(org_config)
|
||||
|
||||
# Act
|
||||
# Generate notes embeddings during asymmetric setup
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
|
||||
|
||||
# Assert
|
||||
assert "Deleted 8 entries. Created 13 new entries for user " in caplog.records[-1].message
|
||||
assert "Deleted 2 entries. Created 2 new entries for user " in caplog.records[-1].message
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db
|
||||
def test_text_index_same_if_content_unchanged(content_config: ContentConfig, default_user: KhojUser, caplog):
|
||||
# Arrange
|
||||
existing_entries = Entry.objects.filter(user=default_user)
|
||||
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
|
||||
data = get_org_files(org_config)
|
||||
|
||||
|
@ -127,6 +119,10 @@ def test_text_index_same_if_content_unchanged(content_config: ContentConfig, def
|
|||
final_logs = caplog.text
|
||||
|
||||
# Assert
|
||||
updated_entries = Entry.objects.filter(user=default_user)
|
||||
for entry in updated_entries:
|
||||
assert entry in existing_entries
|
||||
assert len(existing_entries) == len(updated_entries)
|
||||
assert "Deleting all entries for file type org" in initial_logs
|
||||
assert "Deleting all entries for file type org" not in final_logs
|
||||
|
||||
|
@ -192,7 +188,7 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: LocalOrgCon
|
|||
|
||||
# Assert
|
||||
assert (
|
||||
"Deleted 0 entries. Created 2 new entries for user " in caplog.records[-1].message
|
||||
"Deleted 0 entries. Created 3 new entries for user " in caplog.records[-1].message
|
||||
), "new entry not split by max tokens"
|
||||
|
||||
|
||||
|
@ -250,16 +246,15 @@ conda activate khoj
|
|||
|
||||
# Assert
|
||||
assert (
|
||||
"Deleted 0 entries. Created 2 new entries for user " in caplog.records[-1].message
|
||||
"Deleted 0 entries. Created 3 new entries for user " in caplog.records[-1].message
|
||||
), "new entry not split by max tokens"
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db
|
||||
def test_regenerate_index_with_new_entry(
|
||||
content_config: ContentConfig, new_org_file: Path, default_user: KhojUser, caplog
|
||||
):
|
||||
def test_regenerate_index_with_new_entry(content_config: ContentConfig, new_org_file: Path, default_user: KhojUser):
|
||||
# Arrange
|
||||
existing_entries = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
|
||||
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
|
||||
initial_data = get_org_files(org_config)
|
||||
|
||||
|
@ -271,28 +266,34 @@ def test_regenerate_index_with_new_entry(
|
|||
final_data = get_org_files(org_config)
|
||||
|
||||
# Act
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToEntries, initial_data, regenerate=True, user=default_user)
|
||||
initial_logs = caplog.text
|
||||
caplog.clear() # Clear logs
|
||||
text_search.setup(OrgToEntries, initial_data, regenerate=True, user=default_user)
|
||||
updated_entries1 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
|
||||
|
||||
# regenerate notes jsonl, model embeddings and model to include entry from new file
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToEntries, final_data, regenerate=True, user=default_user)
|
||||
final_logs = caplog.text
|
||||
text_search.setup(OrgToEntries, final_data, regenerate=True, user=default_user)
|
||||
updated_entries2 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
|
||||
|
||||
# Assert
|
||||
assert "Deleted 8 entries. Created 13 new entries for user " in initial_logs
|
||||
assert "Deleted 13 entries. Created 14 new entries for user " in final_logs
|
||||
verify_embeddings(14, default_user)
|
||||
for entry in updated_entries1:
|
||||
assert entry in updated_entries2
|
||||
|
||||
assert not any([new_org_file.name in entry for entry in updated_entries1])
|
||||
assert not any([new_org_file.name in entry for entry in existing_entries])
|
||||
assert any([new_org_file.name in entry for entry in updated_entries2])
|
||||
|
||||
assert any(
|
||||
["Saw a super cute video of a chihuahua doing the Tango on Youtube" in entry for entry in updated_entries2]
|
||||
)
|
||||
verify_embeddings(3, default_user)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db
|
||||
def test_update_index_with_duplicate_entries_in_stable_order(
|
||||
org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog
|
||||
org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser
|
||||
):
|
||||
# Arrange
|
||||
existing_entries = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
|
||||
new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
|
||||
|
||||
# Insert org-mode entries with same compiled form into new org file
|
||||
|
@ -304,30 +305,33 @@ def test_update_index_with_duplicate_entries_in_stable_order(
|
|||
|
||||
# Act
|
||||
# generate embeddings, entries, notes model from scratch after adding new org-mode file
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
|
||||
initial_logs = caplog.text
|
||||
caplog.clear() # Clear logs
|
||||
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
|
||||
updated_entries1 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
|
||||
|
||||
data = get_org_files(org_config_with_only_new_file)
|
||||
|
||||
# update embeddings, entries, notes model with no new changes
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToEntries, data, regenerate=False, user=default_user)
|
||||
final_logs = caplog.text
|
||||
text_search.setup(OrgToEntries, data, regenerate=False, user=default_user)
|
||||
updated_entries2 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
|
||||
|
||||
# Assert
|
||||
# verify only 1 entry added even if there are multiple duplicate entries
|
||||
assert "Deleted 8 entries. Created 1 new entries for user " in initial_logs
|
||||
assert "Deleted 0 entries. Created 0 new entries for user " in final_logs
|
||||
for entry in existing_entries:
|
||||
assert entry not in updated_entries1
|
||||
|
||||
for entry in updated_entries1:
|
||||
assert entry in updated_entries2
|
||||
|
||||
assert len(existing_entries) == 2
|
||||
assert len(updated_entries1) == len(updated_entries2)
|
||||
verify_embeddings(1, default_user)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db
|
||||
def test_update_index_with_deleted_entry(org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog):
|
||||
def test_update_index_with_deleted_entry(org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser):
|
||||
# Arrange
|
||||
existing_entries = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
|
||||
new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
|
||||
|
||||
# Insert org-mode entries with same compiled form into new org file
|
||||
|
@ -344,33 +348,34 @@ def test_update_index_with_deleted_entry(org_config_with_only_new_file: LocalOrg
|
|||
|
||||
# Act
|
||||
# load embeddings, entries, notes model after adding new org file with 2 entries
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToEntries, initial_data, regenerate=True, user=default_user)
|
||||
initial_logs = caplog.text
|
||||
caplog.clear() # Clear logs
|
||||
text_search.setup(OrgToEntries, initial_data, regenerate=True, user=default_user)
|
||||
updated_entries1 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
|
||||
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToEntries, final_data, regenerate=False, user=default_user)
|
||||
final_logs = caplog.text
|
||||
text_search.setup(OrgToEntries, final_data, regenerate=False, user=default_user)
|
||||
updated_entries2 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
|
||||
|
||||
# Assert
|
||||
# verify only 1 entry added even if there are multiple duplicate entries
|
||||
assert "Deleted 8 entries. Created 2 new entries for user " in initial_logs
|
||||
assert "Deleted 1 entries. Created 0 new entries for user " in final_logs
|
||||
for entry in existing_entries:
|
||||
assert entry not in updated_entries1
|
||||
|
||||
# verify the entry in updated_entries2 is a subset of updated_entries1
|
||||
for entry in updated_entries1:
|
||||
assert entry not in updated_entries2
|
||||
|
||||
for entry in updated_entries2:
|
||||
assert entry in updated_entries1[0]
|
||||
|
||||
verify_embeddings(1, default_user)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db
|
||||
def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file: Path, default_user: KhojUser, caplog):
|
||||
def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file: Path, default_user: KhojUser):
|
||||
# Arrange
|
||||
existing_entries = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
|
||||
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
|
||||
data = get_org_files(org_config)
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
|
||||
initial_logs = caplog.text
|
||||
caplog.clear() # Clear logs
|
||||
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
|
||||
|
||||
# append org-mode entry to first org input file in config
|
||||
with open(new_org_file, "w") as f:
|
||||
|
@ -381,14 +386,14 @@ def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file
|
|||
|
||||
# Act
|
||||
# update embeddings, entries with the newly added note
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToEntries, data, regenerate=False, user=default_user)
|
||||
final_logs = caplog.text
|
||||
text_search.setup(OrgToEntries, data, regenerate=False, user=default_user)
|
||||
updated_entries1 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
|
||||
|
||||
# Assert
|
||||
assert "Deleted 8 entries. Created 13 new entries for user " in initial_logs
|
||||
assert "Deleted 0 entries. Created 1 new entries for user " in final_logs
|
||||
verify_embeddings(14, default_user)
|
||||
for entry in existing_entries:
|
||||
assert entry not in updated_entries1
|
||||
assert len(updated_entries1) == len(existing_entries) + 1
|
||||
verify_embeddings(3, default_user)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
|
|
Loading…
Reference in a new issue