diff --git a/src/khoj/processor/content/markdown/markdown_to_entries.py b/src/khoj/processor/content/markdown/markdown_to_entries.py
index c0fea077..986d4911 100644
--- a/src/khoj/processor/content/markdown/markdown_to_entries.py
+++ b/src/khoj/processor/content/markdown/markdown_to_entries.py
@@ -31,13 +31,14 @@ class MarkdownToEntries(TextToEntries):
         else:
             deletion_file_names = None
 
+        max_tokens = 256
         # Extract Entries from specified Markdown files
         with timer("Extract entries from specified Markdown files", logger):
-            current_entries = MarkdownToEntries.extract_markdown_entries(files)
+            current_entries = MarkdownToEntries.extract_markdown_entries(files, max_tokens)
 
         # Split entries by max tokens supported by model
         with timer("Split entries by max token size supported by model", logger):
-            current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
+            current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens)
 
         # Identify, mark and merge any new entries with previous entries
         with timer("Identify new or updated entries", logger):
@@ -55,7 +56,7 @@ class MarkdownToEntries(TextToEntries):
         return num_new_embeddings, num_deleted_embeddings
 
     @staticmethod
-    def extract_markdown_entries(markdown_files) -> List[Entry]:
+    def extract_markdown_entries(markdown_files, max_tokens: int = 256) -> List[Entry]:
         "Extract entries by heading from specified Markdown files"
         entries: List[str] = []
         entry_to_file_map: List[Tuple[str, Path]] = []
@@ -63,7 +64,7 @@ class MarkdownToEntries(TextToEntries):
             try:
                 markdown_content = markdown_files[markdown_file]
                 entries, entry_to_file_map = MarkdownToEntries.process_single_markdown_file(
-                    markdown_content, markdown_file, entries, entry_to_file_map
+                    markdown_content, markdown_file, entries, entry_to_file_map, max_tokens
                 )
             except Exception as e:
                 logger.warning(
@@ -74,8 +75,18 @@ class MarkdownToEntries(TextToEntries):
 
     @staticmethod
     def process_single_markdown_file(
-        markdown_content: str, markdown_file: Path, entries: List[str], entry_to_file_map: List[Tuple[str, Path]]
+        markdown_content: str,
+        markdown_file: Path,
+        entries: List[str],
+        entry_to_file_map: List[Tuple[str, Path]],
+        max_tokens: int = 256,
     ):
+        # Keep a file whole as a single entry when it fits the token budget; only larger files are split by heading
+        if len(TextToEntries.tokenizer(markdown_content)) <= max_tokens:
+            entry_to_file_map.append((markdown_content, markdown_file))
+            entries.append(markdown_content)
+            return entries, entry_to_file_map
+
         headers_to_split_on = [("#", "1"), ("##", "2"), ("###", "3"), ("####", "4"), ("#####", "5"), ("######", "6")]
         reversed_headers_to_split_on = list(reversed(headers_to_split_on))
         markdown_entries_per_file: List[str] = []
diff --git a/tests/test_markdown_to_entries.py b/tests/test_markdown_to_entries.py
index 174c6c4d..18d43791 100644
--- a/tests/test_markdown_to_entries.py
+++ b/tests/test_markdown_to_entries.py
@@ -20,7 +20,7 @@ def test_extract_markdown_with_no_headings(tmp_path):
 
     # Act
     # Extract Entries from specified Markdown files
-    entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data)
+    entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3)
 
     # Assert
     assert len(entries) == 1
@@ -45,7 +45,7 @@ def test_extract_single_markdown_entry(tmp_path):
 
     # Act
     # Extract Entries from specified Markdown files
-    entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data)
+    entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3)
 
     # Assert
     assert len(entries) == 1
@@ -68,7 +68,7 @@ def test_extract_multiple_markdown_entries(tmp_path):
 
     # Act
     # Extract Entries from specified Markdown files
-    entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data)
+    entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3)
 
     # Assert
     assert len(entries) == 2
@@ -127,7 +127,7 @@ def test_extract_entries_with_different_level_headings(tmp_path):
 
     # Act
     # Extract Entries from specified Markdown files
-    entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data)
+    entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3)
 
     # Assert
     assert len(entries) == 3
@@ -161,6 +161,28 @@ body line 2
     assert entries[2].raw == "# Heading 1\n## Heading 2\nbody line 2", "Ensure raw entry includes heading ancestory"
 
 
+def test_parse_markdown_file_into_single_entry_if_small(tmp_path):
+    "Parse markdown file into single entry if it fits within the token limits."
+    # Arrange
+    entry = f"""
+# Heading 1
+body line 1
+## Subheading 1.1
+body line 1.1
+"""
+    data = {
+        f"{tmp_path}": entry,
+    }
+
+    # Act
+    # Extract Entries from specified Markdown files
+    entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=12)
+
+    # Assert
+    assert len(entries) == 1
+    assert entries[0].raw == entry
+
+
 # Helper Functions
 def create_file(tmp_path: Path, entry=None, filename="test.md"):
     markdown_file = tmp_path / filename