diff --git a/src/khoj/processor/content/plaintext/plaintext_to_entries.py b/src/khoj/processor/content/plaintext/plaintext_to_entries.py
index 45ac5047..c14bc359 100644
--- a/src/khoj/processor/content/plaintext/plaintext_to_entries.py
+++ b/src/khoj/processor/content/plaintext/plaintext_to_entries.py
@@ -1,7 +1,9 @@
 import logging
+import re
 from pathlib import Path
-from typing import List, Tuple
+from typing import Dict, List, Tuple
 
+import urllib3
 from bs4 import BeautifulSoup
 
 from khoj.database.models import Entry as DbEntry
@@ -28,21 +30,8 @@ class PlaintextToEntries(TextToEntries):
         else:
             deletion_file_names = None
 
-        with timer("Scrub plaintext files and extract text", logger):
-            for file in files:
-                try:
-                    plaintext_content = files[file]
-                    if file.endswith(("html", "htm", "xml")):
-                        plaintext_content = PlaintextToEntries.extract_html_content(
-                            plaintext_content, file.split(".")[-1]
-                        )
-                    files[file] = plaintext_content
-                except Exception as e:
-                    logger.warning(f"Unable to read file: {file} as plaintext. Skipping file.")
-                    logger.warning(e, exc_info=True)
-
         # Extract Entries from specified plaintext files
-        with timer("Parse entries from specified Plaintext files", logger):
+        with timer("Extract entries from specified Plaintext files", logger):
             current_entries = PlaintextToEntries.extract_plaintext_entries(files)
 
         # Split entries by max tokens supported by model
@@ -74,16 +63,57 @@ class PlaintextToEntries(TextToEntries):
         return soup.get_text(strip=True, separator="\n")
 
     @staticmethod
-    def extract_plaintext_entries(entry_to_file_map: dict[str, str]) -> List[Entry]:
-        "Convert each plaintext entries into a dictionary"
-        entries = []
-        for file, entry in entry_to_file_map.items():
+    def extract_plaintext_entries(text_files: Dict[str, str]) -> List[Entry]:
+        entries: List[str] = []
+        entry_to_file_map: List[Tuple[str, str]] = []
+        for text_file in text_files:
+            try:
+                text_content = text_files[text_file]
+                entries, entry_to_file_map = PlaintextToEntries.process_single_plaintext_file(
+                    text_content, text_file, entries, entry_to_file_map
+                )
+            except Exception as e:
+                logger.warning(f"Unable to read file: {text_file} as plaintext. Skipping file.")
+                logger.warning(e, exc_info=True)
+
+        # Convert the extracted file contents into entries
+        return PlaintextToEntries.convert_text_files_to_entries(entries, dict(entry_to_file_map))
+
+    @staticmethod
+    def process_single_plaintext_file(
+        text_content: str,
+        text_file: str,
+        entries: List[str],
+        entry_to_file_map: List[Tuple[str, str]],
+    ) -> Tuple[List[str], List[Tuple[str, str]]]:
+        if text_file.endswith(("html", "htm", "xml")):
+            text_content = PlaintextToEntries.extract_html_content(text_content, text_file.split(".")[-1])
+        entry_to_file_map.append((text_content, text_file))
+        entries.append(text_content)
+        return entries, entry_to_file_map
+
+    @staticmethod
+    def convert_text_files_to_entries(parsed_entries: List[str], entry_to_file_map: dict[str, str]) -> List[Entry]:
+        "Convert each plaintext file into an entry"
+        entries: List[Entry] = []
+        for parsed_entry in parsed_entries:
+            raw_filename = entry_to_file_map[parsed_entry]
+            # If raw_filename is a URL, normalize it. Otherwise, use the file path as-is.
+            if isinstance(raw_filename, str) and re.search(r"^https?://", raw_filename):
+                # Escape the URL to avoid issues with special characters
+                entry_filename = urllib3.util.parse_url(raw_filename).url
+            else:
+                entry_filename = raw_filename
+
+            # Prepend the filename to the compiled entry to add context for the model
             entries.append(
                 Entry(
-                    raw=entry,
-                    file=file,
-                    compiled=f"{Path(file).stem}\n{entry}",
-                    heading=Path(file).stem,
+                    raw=parsed_entry,
+                    file=f"{entry_filename}",
+                    compiled=f"{entry_filename}\n{parsed_entry}",
+                    heading=entry_filename,
                 )
             )
+
+        logger.debug(f"Converted {len(parsed_entries)} plaintext files to entries")
         return entries
diff --git a/tests/test_client.py b/tests/test_client.py
index fb73ca3c..bb565794 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -1,6 +1,5 @@
 # Standard Modules
 import os
-from io import BytesIO
 from urllib.parse import quote
 
 import pytest
diff --git a/tests/test_plaintext_to_entries.py b/tests/test_plaintext_to_entries.py
index 41d841fc..ba908997 100644
--- a/tests/test_plaintext_to_entries.py
+++ b/tests/test_plaintext_to_entries.py
@@ -15,8 +15,6 @@ def test_plaintext_file(tmp_path):
     """
     plaintextfile = create_file(tmp_path, raw_entry)
 
-    filename = plaintextfile.stem
-
     # Act
 
     # Extract Entries from specified plaintext files
@@ -24,7 +22,7 @@
     f"{plaintextfile}": raw_entry,
     }
 
-    entries = PlaintextToEntries.extract_plaintext_entries(entry_to_file_map=data)
+    entries = PlaintextToEntries.extract_plaintext_entries(data)
 
     # Convert each entry.file to absolute path to make them JSON serializable
     for entry in entries:
@@ -35,7 +33,7 @@
     # Ensure raw entries with no headings do not get a heading prefix prepended
     assert not entries[0].raw.startswith("#")
     # Ensure compiled entry has filename prepended as top level heading
-    assert entries[0].compiled == f"{filename}\n{raw_entry}"
+    assert entries[0].compiled == f"{plaintextfile}\n{raw_entry}"
 
 
 def test_get_plaintext_files(tmp_path):
@@ -100,6 +98,35 @@ def test_parse_html_plaintext_file(content_config, default_user: KhojUser):
     assert "<div>" not in entries[0].raw
" not in entries[0].raw +def test_large_plaintext_file_split_into_multiple_entries(tmp_path): + "Convert files with no heading to jsonl." + # Arrange + max_tokens = 256 + normal_entry = " ".join([f"{number}" for number in range(max_tokens - 1)]) + large_entry = " ".join([f"{number}" for number in range(max_tokens)]) + + normal_plaintextfile = create_file(tmp_path, normal_entry) + large_plaintextfile = create_file(tmp_path, large_entry) + + normal_data = {f"{normal_plaintextfile}": normal_entry} + large_data = {f"{large_plaintextfile}": large_entry} + + # Act + # Extract Entries from specified plaintext files + normal_entries = PlaintextToEntries.split_entries_by_max_tokens( + PlaintextToEntries.extract_plaintext_entries(normal_data), + max_tokens=max_tokens, + raw_is_compiled=True, + ) + large_entries = PlaintextToEntries.split_entries_by_max_tokens( + PlaintextToEntries.extract_plaintext_entries(large_data), max_tokens=max_tokens, raw_is_compiled=True + ) + + # Assert + assert len(normal_entries) == 1 + assert len(large_entries) == 2 + + # Helper Functions def create_file(tmp_path: Path, entry=None, filename="test.md"): file_ = tmp_path / filename