diff --git a/src/processor/text_to_jsonl.py b/src/processor/text_to_jsonl.py
index 2f5e7e40..0eb60e6c 100644
--- a/src/processor/text_to_jsonl.py
+++ b/src/processor/text_to_jsonl.py
@@ -23,6 +23,19 @@ class TextToJsonl(ABC):
     def hash_func(key: str) -> Callable:
         return lambda entry: hashlib.md5(bytes(getattr(entry, key), encoding='utf-8')).hexdigest()
 
+    @staticmethod
+    def split_entries_by_max_tokens(entries: list[Entry], max_tokens: int = 256) -> list[Entry]:
+        "Split entries if compiled entry length exceeds the max tokens supported by the ML model."
+        chunked_entries: list[Entry] = []
+        for entry in entries:
+            compiled_entry_words = entry.compiled.split()
+            for chunk_index in range(0, len(compiled_entry_words), max_tokens):
+                compiled_entry_words_chunk = compiled_entry_words[chunk_index:chunk_index + max_tokens]
+                compiled_entry_chunk = ' '.join(compiled_entry_words_chunk)
+                entry_chunk = Entry(compiled=compiled_entry_chunk, raw=entry.raw, file=entry.file)
+                chunked_entries.append(entry_chunk)
+        return chunked_entries
+
     def mark_entries_for_update(self, current_entries: list[Entry], previous_entries: list[Entry], key='compiled', logger=None) -> list[tuple[int, Entry]]:
         # Hash all current and previous entries to identify new entries
         start = time.time()
diff --git a/tests/test_org_to_jsonl.py b/tests/test_org_to_jsonl.py
index 2dbedcd0..ee9aae16 100644
--- a/tests/test_org_to_jsonl.py
+++ b/tests/test_org_to_jsonl.py
@@ -3,6 +3,7 @@ import json
 
 # Internal Packages
 from src.processor.org_mode.org_to_jsonl import OrgToJsonl
+from src.processor.text_to_jsonl import TextToJsonl
 from src.utils.helpers import is_none_or_empty
 
 
@@ -35,6 +36,34 @@ def test_configure_heading_entry_to_jsonl(tmp_path):
     assert is_none_or_empty(jsonl_data)
 
 
+def test_entry_split_when_exceeds_max_words(tmp_path):
+    "Ensure entries with compiled words exceeding max_tokens are split."
+    # Arrange
+    entry = '''*** Heading
+    :PROPERTIES:
+    :ID: 42-42-42
+    :END:
+    \t\r
+    Body Line 1
+    '''
+    orgfile = create_file(tmp_path, entry)
+
+    # Act
+    # Extract Entries from specified Org files
+    entries, entry_to_file_map = OrgToJsonl.extract_org_entries(org_files=[orgfile])
+
+    # Split Each Entry from specified Org files by Max Tokens
+    jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(
+        TextToJsonl.split_entries_by_max_tokens(
+            OrgToJsonl.convert_org_nodes_to_entries(entries, entry_to_file_map),
+            max_tokens=2)
+    )
+    jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
+
+    # Assert
+    assert len(jsonl_data) == 2
+
+
 def test_entry_with_body_to_jsonl(tmp_path):
     "Ensure entries with valid body text are loaded."
     # Arrange