Dedupe code by using single func to process an org file into entries

Add type hints to orgnode and org-to-entries packages
Debanjum Singh Solanky 2024-02-11 00:34:04 +05:30
parent db2581459f
commit 44eab74888
3 changed files with 42 additions and 39 deletions
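For orientation, here is a minimal usage sketch of the refactored entry point, based on the signatures in the diff below; the module path, file name, and org content are assumptions made only for illustration.

```python
from khoj.processor.content.org_mode.org_to_entries import OrgToEntries  # assumed module path

# Hypothetical org content keyed by file path (both invented for this sketch).
org_files = {
    "notes/todo.org": "* TODO Write release notes\nDraft the summary of parser fixes.\n",
}

# extract_org_entries parses every file into Orgnode objects and converts them
# to a flat List[Entry]; index_heading_entries controls whether heading-only
# nodes are indexed as well.
entries = OrgToEntries.extract_org_entries(org_files, index_heading_entries=True)
for entry in entries:
    print(entry.raw)
```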

@@ -1,10 +1,11 @@
import logging
from pathlib import Path
from typing import Iterable, List, Tuple
from typing import Dict, List, Tuple
from khoj.database.models import Entry as DbEntry
from khoj.database.models import KhojUser
from khoj.processor.content.org_mode import orgnode
from khoj.processor.content.org_mode.orgnode import Orgnode
from khoj.processor.content.text_to_entries import TextToEntries
from khoj.utils import state
from khoj.utils.helpers import timer
@@ -51,7 +52,7 @@ class OrgToEntries(TextToEntries):
return num_new_embeddings, num_deleted_embeddings
@staticmethod
def extract_org_entries(org_files: dict[str, str], index_heading_entries: bool = False):
def extract_org_entries(org_files: dict[str, str], index_heading_entries: bool = False) -> List[Entry]:
"Extract entries from specified Org files"
with timer("Parse entries from org files into OrgNode objects", logger):
entry_nodes, file_to_entries = OrgToEntries.extract_org_nodes(org_files)
@@ -60,35 +61,35 @@ class OrgToEntries(TextToEntries):
return OrgToEntries.convert_org_nodes_to_entries(entry_nodes, file_to_entries, index_heading_entries)
@staticmethod
def extract_org_nodes(org_files: dict[str, str]):
def extract_org_nodes(org_files: dict[str, str]) -> Tuple[List[Orgnode], Dict[Orgnode, str]]:
"Extract org nodes from specified org files"
entry_nodes = []
entry_to_file_map: List[Tuple[orgnode.Orgnode, str]] = []
entry_nodes: List[Orgnode] = []
entry_to_file_map: List[Tuple[Orgnode, str]] = []
for org_file in org_files:
filename = org_file
file = org_files[org_file]
try:
org_file_entries = orgnode.makelist(file, filename)
entry_to_file_map += zip(org_file_entries, [org_file] * len(org_file_entries))
entry_nodes.extend(org_file_entries)
except Exception as e:
logger.warning(f"Unable to process file: {org_file}. This file will not be indexed.")
logger.warning(e, exc_info=True)
org_content = org_files[org_file]
entry_nodes, entry_to_file_map = OrgToEntries.process_single_org_file(
org_content, org_file, entry_nodes, entry_to_file_map
)
return entry_nodes, dict(entry_to_file_map)
@staticmethod
def process_single_org_file(org_content: str, org_file: str, entries: List, entry_to_file_map: List):
def process_single_org_file(
org_content: str,
org_file: str,
entries: List[Orgnode],
entry_to_file_map: List[Tuple[Orgnode, str]],
) -> Tuple[List[Orgnode], List[Tuple[Orgnode, str]]]:
# Process a single org file. The org parser assumes the input is one org file
# read from a buffer; splitting this file's raw content by newline mimics that behavior.
try:
org_file_entries = orgnode.makelist(org_content, org_file)
entry_to_file_map += zip(org_file_entries, [org_file] * len(org_file_entries))
entries.extend(org_file_entries)
return entries, entry_to_file_map
except Exception as e:
logger.error(f"Error processing file: {org_file} with error: {e}", exc_info=True)
return entries, entry_to_file_map
logger.error(f"Unable to process file: {org_file}. Skipped indexing it.\nError; {e}", exc_info=True)
return entries, entry_to_file_map
@staticmethod
def convert_org_nodes_to_entries(
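As a side note for readers skimming the refactor, the helper above leans on a small accumulate-and-zip pattern; the sketch below restates it with plain strings standing in for Orgnode objects (the file names and node labels are placeholders).

```python
# Plain-data illustration of how process_single_org_file accumulates results:
# each parsed node is paired with the file it came from, and extract_org_nodes
# later collapses the pairs into a node -> file lookup via dict().
entries: list[str] = []
entry_to_file_map: list[tuple[str, str]] = []

parsed = {"a.org": ["node-1", "node-2"], "b.org": ["node-3"]}
for org_file, file_nodes in parsed.items():
    entry_to_file_map += zip(file_nodes, [org_file] * len(file_nodes))
    entries.extend(file_nodes)

file_of = dict(entry_to_file_map)
assert entries == ["node-1", "node-2", "node-3"]
assert file_of["node-3"] == "b.org"
```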

@@ -37,7 +37,7 @@ import datetime
import re
from os.path import relpath
from pathlib import Path
from typing import List
from typing import Dict, List, Tuple
indent_regex = re.compile(r"^ *")
@@ -58,7 +58,7 @@ def makelist_with_filepath(filename):
return makelist(f, filename)
def makelist(file, filename):
def makelist(file, filename) -> List["Orgnode"]:
"""
Read an org-mode file and return a list of Orgnode objects
created from this file.
@@ -80,16 +80,16 @@ def makelist(file, filename):
} # populated from #+SEQ_TODO line
level = ""
heading = ""
ancestor_headings = []
ancestor_headings: List[str] = []
bodytext = ""
introtext = ""
tags = list() # set of all tags in headline
closed_date = ""
sched_date = ""
deadline_date = ""
logbook = list()
tags: List[str] = list() # set of all tags in headline
closed_date: datetime.date = None
sched_date: datetime.date = None
deadline_date: datetime.date = None
logbook: List[Tuple[datetime.datetime, datetime.datetime]] = list()
nodelist: List[Orgnode] = list()
property_map = dict()
property_map: Dict[str, str] = dict()
in_properties_drawer = False
in_logbook_drawer = False
file_title = f"{filename}"
@@ -102,13 +102,13 @@ def makelist(file, filename):
thisNode = Orgnode(level, heading, bodytext, tags, ancestor_headings)
if closed_date:
thisNode.closed = closed_date
closed_date = ""
closed_date = None
if sched_date:
thisNode.scheduled = sched_date
sched_date = ""
sched_date = None
if deadline_date:
thisNode.deadline = deadline_date
deadline_date = ""
deadline_date = None
if logbook:
thisNode.logbook = logbook
logbook = list()
@@ -116,7 +116,7 @@ def makelist(file, filename):
nodelist.append(thisNode)
property_map = {"LINE": f"file:{normalize_filename(filename)}::{ctr}"}
previous_level = level
previous_heading = heading
previous_heading: str = heading
level = heading_search.group(1)
heading = heading_search.group(2)
bodytext = ""

@@ -37,8 +37,8 @@ def test_configure_heading_entry_to_jsonl(tmp_path):
assert is_none_or_empty(entries)
def test_entry_split_when_exceeds_max_words():
"Ensure entries with compiled words exceeding max_words are split."
def test_entry_split_when_exceeds_max_tokens():
"Ensure entries with compiled words exceeding max_tokens are split."
# Arrange
tmp_path = "/tmp/test.org"
entry = f"""*** Heading
@@ -81,7 +81,7 @@ def test_entry_split_drops_large_words():
assert len(processed_entry.compiled.split()) == len(entry_text.split()) - 1
def test_entry_with_body_to_jsonl(tmp_path):
def test_entry_with_body_to_entry(tmp_path):
"Ensure entries with valid body text are loaded."
# Arrange
entry = f"""*** Heading
@@ -97,13 +97,13 @@ def test_entry_with_body_to_jsonl(tmp_path):
# Act
# Extract Entries from specified Org files
entries = OrgToEntries.extract_org_entries(org_files=data)
entries = OrgToEntries.extract_org_entries(org_files=data, max_tokens=3)
# Assert
assert len(entries) == 1
def test_file_with_entry_after_intro_text_to_jsonl(tmp_path):
def test_file_with_entry_after_intro_text_to_entry(tmp_path):
"Ensure intro text before any headings is indexed."
# Arrange
entry = f"""
@@ -188,7 +188,8 @@ def test_extract_entries_with_different_level_headings(tmp_path):
# Arrange
entry = f"""
* Heading 1
** Heading 2
** Sub-Heading 1.1
* Heading 2
"""
data = {
f"{tmp_path}": entry,
@@ -199,9 +200,10 @@ def test_extract_entries_with_different_level_headings(tmp_path):
entries = OrgToEntries.extract_org_entries(org_files=data, index_heading_entries=True)
# Assert
assert len(entries) == 2
assert len(entries) == 3
assert f"{entries[0].raw}".startswith("* Heading 1")
assert f"{entries[1].raw}".startswith("** Heading 2")
assert f"{entries[1].raw}".startswith("** Sub-Heading 1.1")
assert f"{entries[2].raw}".startswith("* Heading 2")
# Helper Functions
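For completeness, a hypothetical contrast of the two indexing modes these tests exercise, reusing the test module's OrgToEntries import. The assertion for the default mode is an assumption drawn from test_configure_heading_entry_to_jsonl above, which expects heading-only entries to be dropped.

```python
org = """
* Heading 1
** Sub-Heading 1.1
* Heading 2
"""
data = {"/tmp/test.org": org}

# With heading indexing enabled, each of the three body-less headings
# becomes its own entry, matching the updated assertions above.
entries = OrgToEntries.extract_org_entries(org_files=data, index_heading_entries=True)
assert len(entries) == 3
assert entries[1].raw.startswith("** Sub-Heading 1.1")

# With the default (index_heading_entries=False), heading-only entries are
# skipped entirely (assumed from test_configure_heading_entry_to_jsonl).
assert not OrgToEntries.extract_org_entries(org_files=data)
```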