Create wrapper function to get entries from org, md, pdf & text files

- Convert extract_org_entries function to actually extract org entries
  Previously it was extracting intermediary org-node objects instead
  Now it extracts the org-node objects from files and converts them
  into entries
- Create separate, new function to extract_org_nodes from files
- Similarly create wrapper funcs for md, pdf, plaintext to entries

- Update org, md, pdf, plaintext to entries tests to use the new
  simplified wrapper function to extract org entries
This commit is contained in:
Debanjum Singh Solanky 2024-02-09 16:04:41 +05:30
parent f01a12b1d2
commit 28105ee027
8 changed files with 71 additions and 94 deletions

View file

@ -32,10 +32,8 @@ class MarkdownToEntries(TextToEntries):
deletion_file_names = None
# Extract Entries from specified Markdown files
with timer("Parse entries from Markdown files into dictionaries", logger):
current_entries = MarkdownToEntries.convert_markdown_entries_to_maps(
*MarkdownToEntries.extract_markdown_entries(files)
)
with timer("Extract entries from specified Markdown files", logger):
current_entries = MarkdownToEntries.extract_markdown_entries(files)
# Split entries by max tokens supported by model
with timer("Split entries by max token size supported by model", logger):
@ -57,13 +55,10 @@ class MarkdownToEntries(TextToEntries):
return num_new_embeddings, num_deleted_embeddings
@staticmethod
def extract_markdown_entries(markdown_files):
def extract_markdown_entries(markdown_files) -> List[Entry]:
"Extract entries by heading from specified Markdown files"
# Regex to extract Markdown Entries by Heading
entries = []
entry_to_file_map = []
entries: List[str] = []
entry_to_file_map: List[Tuple[str, Path]] = []
for markdown_file in markdown_files:
try:
markdown_content = markdown_files[markdown_file]
@ -71,18 +66,19 @@ class MarkdownToEntries(TextToEntries):
markdown_content, markdown_file, entries, entry_to_file_map
)
except Exception as e:
logger.warning(f"Unable to process file: {markdown_file}. This file will not be indexed.")
logger.warning(e, exc_info=True)
logger.warning(
f"Unable to process file: {markdown_file}. This file will not be indexed.\n{e}", exc_info=True
)
return entries, dict(entry_to_file_map)
return MarkdownToEntries.convert_markdown_entries_to_maps(entries, dict(entry_to_file_map))
@staticmethod
def process_single_markdown_file(
markdown_content: str, markdown_file: Path, entries: List, entry_to_file_map: List
markdown_content: str, markdown_file: Path, entries: List[str], entry_to_file_map: List[Tuple[str, Path]]
):
markdown_heading_regex = r"^#"
markdown_entries_per_file = []
markdown_entries_per_file: List[str] = []
any_headings = re.search(markdown_heading_regex, markdown_content, flags=re.MULTILINE)
for entry in re.split(markdown_heading_regex, markdown_content, flags=re.MULTILINE):
# Add heading level as the regex split removed it from entries with headings
@ -98,7 +94,7 @@ class MarkdownToEntries(TextToEntries):
@staticmethod
def convert_markdown_entries_to_maps(parsed_entries: List[str], entry_to_file_map) -> List[Entry]:
"Convert each Markdown entries into a dictionary"
entries = []
entries: List[Entry] = []
for parsed_entry in parsed_entries:
raw_filename = entry_to_file_map[parsed_entry]

View file

@ -21,9 +21,6 @@ class OrgToEntries(TextToEntries):
def process(
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
) -> Tuple[int, int]:
# Extract required fields from config
index_heading_entries = False
if not full_corpus:
deletion_file_names = set([file for file in files if files[file] == ""])
files_to_process = set(files) - deletion_file_names
@ -32,11 +29,8 @@ class OrgToEntries(TextToEntries):
deletion_file_names = None
# Extract Entries from specified Org files
with timer("Parse entries from org files into OrgNode objects", logger):
entry_nodes, file_to_entries = self.extract_org_entries(files)
with timer("Convert OrgNodes into list of entries", logger):
current_entries = self.convert_org_nodes_to_entries(entry_nodes, file_to_entries, index_heading_entries)
with timer("Extract entries from specified Org files", logger):
current_entries = self.extract_org_entries(files)
with timer("Split entries by max token size supported by model", logger):
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
@ -57,9 +51,18 @@ class OrgToEntries(TextToEntries):
return num_new_embeddings, num_deleted_embeddings
@staticmethod
def extract_org_entries(org_files: dict[str, str]):
def extract_org_entries(org_files: dict[str, str], index_heading_entries: bool = False):
"Extract entries from specified Org files"
entries = []
with timer("Parse entries from org files into OrgNode objects", logger):
entry_nodes, file_to_entries = OrgToEntries.extract_org_nodes(org_files)
with timer("Convert OrgNodes into list of entries", logger):
return OrgToEntries.convert_org_nodes_to_entries(entry_nodes, file_to_entries, index_heading_entries)
@staticmethod
def extract_org_nodes(org_files: dict[str, str]):
"Extract org nodes from specified org files"
entry_nodes = []
entry_to_file_map: List[Tuple[orgnode.Orgnode, str]] = []
for org_file in org_files:
filename = org_file
@ -67,16 +70,17 @@ class OrgToEntries(TextToEntries):
try:
org_file_entries = orgnode.makelist(file, filename)
entry_to_file_map += zip(org_file_entries, [org_file] * len(org_file_entries))
entries.extend(org_file_entries)
entry_nodes.extend(org_file_entries)
except Exception as e:
logger.warning(f"Unable to process file: {org_file}. This file will not be indexed.")
logger.warning(e, exc_info=True)
return entries, dict(entry_to_file_map)
return entry_nodes, dict(entry_to_file_map)
@staticmethod
def process_single_org_file(org_content: str, org_file: str, entries: List, entry_to_file_map: List):
# Process single org file. The org parser assumes that the file is a single org file and reads it from a buffer. We'll split the raw conetnt of this file by new line to mimic the same behavior.
# Process single org file. The org parser assumes that the file is a single org file and reads it from a buffer.
# We'll split the raw content of this file by new line to mimic the same behavior.
try:
org_file_entries = orgnode.makelist(org_content, org_file)
entry_to_file_map += zip(org_file_entries, [org_file] * len(org_file_entries))

View file

@ -32,8 +32,8 @@ class PdfToEntries(TextToEntries):
deletion_file_names = None
# Extract Entries from specified Pdf files
with timer("Parse entries from PDF files into dictionaries", logger):
current_entries = PdfToEntries.convert_pdf_entries_to_maps(*PdfToEntries.extract_pdf_entries(files))
with timer("Extract entries from specified PDF files", logger):
current_entries = PdfToEntries.extract_pdf_entries(files)
# Split entries by max tokens supported by model
with timer("Split entries by max token size supported by model", logger):
@ -55,11 +55,11 @@ class PdfToEntries(TextToEntries):
return num_new_embeddings, num_deleted_embeddings
@staticmethod
def extract_pdf_entries(pdf_files):
def extract_pdf_entries(pdf_files) -> List[Entry]:
"""Extract entries by page from specified PDF files"""
entries = []
entry_to_location_map = []
entries: List[str] = []
entry_to_location_map: List[Tuple[str, str]] = []
for pdf_file in pdf_files:
try:
# Write the PDF file to a temporary file, as it is stored in byte format in the pdf_file object and the PDF Loader expects a file path
@ -83,7 +83,7 @@ class PdfToEntries(TextToEntries):
if os.path.exists(f"{tmp_file}"):
os.remove(f"{tmp_file}")
return entries, dict(entry_to_location_map)
return PdfToEntries.convert_pdf_entries_to_maps(entries, dict(entry_to_location_map))
@staticmethod
def convert_pdf_entries_to_maps(parsed_entries: List[str], entry_to_file_map) -> List[Entry]:

View file

@ -42,8 +42,8 @@ class PlaintextToEntries(TextToEntries):
logger.warning(e, exc_info=True)
# Extract Entries from specified plaintext files
with timer("Parse entries from plaintext files", logger):
current_entries = PlaintextToEntries.convert_plaintext_entries_to_maps(files)
with timer("Parse entries from specified Plaintext files", logger):
current_entries = PlaintextToEntries.extract_plaintext_entries(files)
# Split entries by max tokens supported by model
with timer("Split entries by max token size supported by model", logger):
@ -74,7 +74,7 @@ class PlaintextToEntries(TextToEntries):
return soup.get_text(strip=True, separator="\n")
@staticmethod
def convert_plaintext_entries_to_maps(entry_to_file_map: dict) -> List[Entry]:
def extract_plaintext_entries(entry_to_file_map: dict[str, str]) -> List[Entry]:
"Convert each plaintext entries into a dictionary"
entries = []
for file, entry in entry_to_file_map.items():

View file

@ -21,12 +21,10 @@ def test_markdown_file_with_no_headings_to_jsonl(tmp_path):
# Act
# Extract Entries from specified Markdown files
entry_nodes, file_to_entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data)
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data)
# Process Each Entry from All Notes Files
jsonl_string = MarkdownToEntries.convert_markdown_maps_to_jsonl(
MarkdownToEntries.convert_markdown_entries_to_maps(entry_nodes, file_to_entries)
)
jsonl_string = MarkdownToEntries.convert_markdown_maps_to_jsonl(entries)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
# Assert
@ -52,12 +50,10 @@ def test_single_markdown_entry_to_jsonl(tmp_path):
# Act
# Extract Entries from specified Markdown files
entries, entry_to_file_map = MarkdownToEntries.extract_markdown_entries(markdown_files=data)
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data)
# Process Each Entry from All Notes Files
jsonl_string = MarkdownToEntries.convert_markdown_maps_to_jsonl(
MarkdownToEntries.convert_markdown_entries_to_maps(entries, entry_to_file_map)
)
jsonl_string = MarkdownToEntries.convert_markdown_maps_to_jsonl(entries)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
# Assert
@ -81,8 +77,7 @@ def test_multiple_markdown_entries_to_jsonl(tmp_path):
# Act
# Extract Entries from specified Markdown files
entry_strings, entry_to_file_map = MarkdownToEntries.extract_markdown_entries(markdown_files=data)
entries = MarkdownToEntries.convert_markdown_entries_to_maps(entry_strings, entry_to_file_map)
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data)
# Process Each Entry from All Notes Files
jsonl_string = MarkdownToEntries.convert_markdown_maps_to_jsonl(entries)
@ -144,12 +139,12 @@ def test_extract_entries_with_different_level_headings(tmp_path):
# Act
# Extract Entries from specified Markdown files
entries, _ = MarkdownToEntries.extract_markdown_entries(markdown_files=data)
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data)
# Assert
assert len(entries) == 2
assert entries[0] == "# Heading 1"
assert entries[1] == "## Heading 2"
assert entries[0].raw == "# Heading 1"
assert entries[1].raw == "## Heading 2"
# Helper Functions

View file

@ -27,9 +27,7 @@ def test_configure_heading_entry_to_jsonl(tmp_path):
# Act
# Extract entries into jsonl from specified Org files
jsonl_string = OrgToEntries.convert_org_entries_to_jsonl(
OrgToEntries.convert_org_nodes_to_entries(
*OrgToEntries.extract_org_entries(org_files=data), index_heading_entries=index_heading_entries
)
OrgToEntries.extract_org_entries(org_files=data, index_heading_entries=index_heading_entries)
)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
@ -57,13 +55,11 @@ def test_entry_split_when_exceeds_max_words():
# Act
# Extract Entries from specified Org files
entries, entry_to_file_map = OrgToEntries.extract_org_entries(org_files=data)
entries = OrgToEntries.extract_org_entries(org_files=data)
# Split each entry from specified Org files by max words
jsonl_string = OrgToEntries.convert_org_entries_to_jsonl(
TextToEntries.split_entries_by_max_tokens(
OrgToEntries.convert_org_nodes_to_entries(entries, entry_to_file_map), max_tokens=4
)
TextToEntries.split_entries_by_max_tokens(entries, max_tokens=4)
)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
@ -107,12 +103,7 @@ def test_entry_with_body_to_jsonl(tmp_path):
# Act
# Extract Entries from specified Org files
entries, entry_to_file_map = OrgToEntries.extract_org_entries(org_files=data)
# Process Each Entry from All Notes Files
jsonl_string = OrgToEntries.convert_org_entries_to_jsonl(
OrgToEntries.convert_org_nodes_to_entries(entries, entry_to_file_map)
)
jsonl_string = OrgToEntries.convert_org_entries_to_jsonl(OrgToEntries.extract_org_entries(org_files=data))
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
# Assert
@ -134,10 +125,9 @@ Intro text
# Act
# Extract Entries from specified Org files
entry_nodes, file_to_entries = OrgToEntries.extract_org_entries(org_files=data)
entries = OrgToEntries.extract_org_entries(org_files=data)
# Process Each Entry from All Notes Files
entries = OrgToEntries.convert_org_nodes_to_entries(entry_nodes, file_to_entries)
jsonl_string = OrgToEntries.convert_org_entries_to_jsonl(entries)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
@ -158,10 +148,9 @@ def test_file_with_no_headings_to_jsonl(tmp_path):
# Act
# Extract Entries from specified Org files
entry_nodes, file_to_entries = OrgToEntries.extract_org_entries(org_files=data)
entries = OrgToEntries.extract_org_entries(org_files=data)
# Process Each Entry from All Notes Files
entries = OrgToEntries.convert_org_nodes_to_entries(entry_nodes, file_to_entries)
jsonl_string = OrgToEntries.convert_org_entries_to_jsonl(entries)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
@ -222,12 +211,12 @@ def test_extract_entries_with_different_level_headings(tmp_path):
# Act
# Extract Entries from specified Org files
entries, _ = OrgToEntries.extract_org_entries(org_files=data)
entries = OrgToEntries.extract_org_entries(org_files=data, index_heading_entries=True)
# Assert
assert len(entries) == 2
assert f"{entries[0]}".startswith("* Heading 1")
assert f"{entries[1]}".startswith("** Heading 2")
assert f"{entries[0].raw}".startswith("* Heading 1")
assert f"{entries[1].raw}".startswith("** Heading 2")
# Helper Functions

View file

@ -15,12 +15,10 @@ def test_single_page_pdf_to_jsonl():
pdf_bytes = f.read()
data = {"tests/data/pdf/singlepage.pdf": pdf_bytes}
entries, entry_to_file_map = PdfToEntries.extract_pdf_entries(pdf_files=data)
entries = PdfToEntries.extract_pdf_entries(pdf_files=data)
# Process Each Entry from All Pdf Files
jsonl_string = PdfToEntries.convert_pdf_maps_to_jsonl(
PdfToEntries.convert_pdf_entries_to_maps(entries, entry_to_file_map)
)
jsonl_string = PdfToEntries.convert_pdf_maps_to_jsonl(entries)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
# Assert
@ -35,12 +33,10 @@ def test_multi_page_pdf_to_jsonl():
pdf_bytes = f.read()
data = {"tests/data/pdf/multipage.pdf": pdf_bytes}
entries, entry_to_file_map = PdfToEntries.extract_pdf_entries(pdf_files=data)
entries = PdfToEntries.extract_pdf_entries(pdf_files=data)
# Process Each Entry from All Pdf Files
jsonl_string = PdfToEntries.convert_pdf_maps_to_jsonl(
PdfToEntries.convert_pdf_entries_to_maps(entries, entry_to_file_map)
)
jsonl_string = PdfToEntries.convert_pdf_maps_to_jsonl(entries)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
# Assert
@ -55,10 +51,7 @@ def test_ocr_page_pdf_to_jsonl():
pdf_bytes = f.read()
data = {"tests/data/pdf/ocr_samples.pdf": pdf_bytes}
entries, entry_to_file_map = PdfToEntries.extract_pdf_entries(pdf_files=data)
# Process Each Entry from All Pdf Files
entries = PdfToEntries.convert_pdf_entries_to_maps(entries, entry_to_file_map)
entries = PdfToEntries.extract_pdf_entries(pdf_files=data)
assert len(entries) == 1
assert "playing on a strip of marsh" in entries[0].raw

View file

@ -11,10 +11,10 @@ from khoj.utils.rawconfig import TextContentConfig
def test_plaintext_file(tmp_path):
"Convert files with no heading to jsonl."
# Arrange
entry = f"""
raw_entry = f"""
Hi, I am a plaintext file and I have some plaintext words.
"""
plaintextfile = create_file(tmp_path, entry)
plaintextfile = create_file(tmp_path, raw_entry)
filename = plaintextfile.stem
@ -22,17 +22,17 @@ def test_plaintext_file(tmp_path):
# Extract Entries from specified plaintext files
data = {
f"{plaintextfile}": entry,
f"{plaintextfile}": raw_entry,
}
maps = PlaintextToEntries.convert_plaintext_entries_to_maps(entry_to_file_map=data)
entries = PlaintextToEntries.extract_plaintext_entries(entry_to_file_map=data)
# Convert each entry.file to absolute path to make them JSON serializable
for map in maps:
map.file = str(Path(map.file).absolute())
for entry in entries:
entry.file = str(Path(entry.file).absolute())
# Process Each Entry from All Notes Files
jsonl_string = PlaintextToEntries.convert_entries_to_jsonl(maps)
jsonl_string = PlaintextToEntries.convert_entries_to_jsonl(entries)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
# Assert
@ -40,7 +40,7 @@ def test_plaintext_file(tmp_path):
# Ensure raw entry with no headings do not get heading prefix prepended
assert not jsonl_data[0]["raw"].startswith("#")
# Ensure compiled entry has filename prepended as top level heading
assert jsonl_data[0]["compiled"] == f"{filename}\n{entry}"
assert jsonl_data[0]["compiled"] == f"{filename}\n{raw_entry}"
def test_get_plaintext_files(tmp_path):
@ -98,11 +98,11 @@ def test_parse_html_plaintext_file(content_config, default_user: KhojUser):
extracted_plaintext_files = get_plaintext_files(config=config)
# Act
maps = PlaintextToEntries.convert_plaintext_entries_to_maps(extracted_plaintext_files)
entries = PlaintextToEntries.extract_plaintext_entries(extracted_plaintext_files)
# Assert
assert len(maps) == 1
assert "<div>" not in maps[0].raw
assert len(entries) == 1
assert "<div>" not in entries[0].raw
# Helper Functions