diff --git a/src/khoj/processor/content/text_to_entries.py b/src/khoj/processor/content/text_to_entries.py
index 8ebc6604..8b7df3e9 100644
--- a/src/khoj/processor/content/text_to_entries.py
+++ b/src/khoj/processor/content/text_to_entries.py
@@ -1,10 +1,12 @@
 import hashlib
 import logging
+import re
 import uuid
 from abc import ABC, abstractmethod
 from itertools import repeat
 from typing import Any, Callable, List, Set, Tuple
 
+from langchain.text_splitter import RecursiveCharacterTextSplitter
 from tqdm import tqdm
 
 from khoj.database.adapters import EntryAdapters, get_user_search_model_or_default
@@ -34,6 +36,27 @@ class TextToEntries(ABC):
     def hash_func(key: str) -> Callable:
         return lambda entry: hashlib.md5(bytes(getattr(entry, key), encoding="utf-8")).hexdigest()
 
+    @staticmethod
+    def remove_long_words(text: str, max_word_length: int = 500) -> str:
+        "Remove words longer than max_word_length from text."
+        # Split the string by words, keeping the delimiters
+        splits = re.split(r"(\s+)", text) + [""]
+        words_with_delimiters = list(zip(splits[::2], splits[1::2]))
+
+        # Filter out long words while preserving delimiters in text
+        filtered_text = [
+            f"{word}{delimiter}"
+            for word, delimiter in words_with_delimiters
+            if not word.strip() or len(word.strip()) <= max_word_length
+        ]
+
+        return "".join(filtered_text)
+
+    @staticmethod
+    def tokenizer(text: str) -> List[str]:
+        "Tokenize text into words."
+        return text.split()
+
     @staticmethod
     def split_entries_by_max_tokens(
         entries: List[Entry], max_tokens: int = 256, max_word_length: int = 500
@@ -44,24 +67,30 @@ class TextToEntries(ABC):
             if is_none_or_empty(entry.compiled):
                 continue
 
-            # Split entry into words
-            compiled_entry_words = [word for word in entry.compiled.split(" ") if word != ""]
-
-            # Drop long words instead of having entry truncated to maintain quality of entry processed by models
-            compiled_entry_words = [word for word in compiled_entry_words if len(word) <= max_word_length]
+            # Split entry into chunks of max_tokens
+            # Use chunking preference order: paragraphs > sentences > words > characters
+            text_splitter = RecursiveCharacterTextSplitter(
+                chunk_size=max_tokens,
+                separators=["\n\n", "\n", "!", "?", ".", " ", "\t", ""],
+                keep_separator=True,
+                length_function=lambda chunk: len(TextToEntries.tokenizer(chunk)),
+                chunk_overlap=0,
+            )
+            chunked_entry_chunks = text_splitter.split_text(entry.compiled)
             corpus_id = uuid.uuid4()
 
-            # Split entry into chunks of max tokens
-            for chunk_index in range(0, len(compiled_entry_words), max_tokens):
-                compiled_entry_words_chunk = compiled_entry_words[chunk_index : chunk_index + max_tokens]
-                compiled_entry_chunk = " ".join(compiled_entry_words_chunk)
-
+            # Create heading prefixed entry from each chunk
+            for chunk_index, compiled_entry_chunk in enumerate(chunked_entry_chunks):
                 # Prepend heading to all other chunks, the first chunk already has heading from original entry
-                if chunk_index > 0:
+                if chunk_index > 0 and entry.heading:
                     # Snip heading to avoid crossing max_tokens limit
                     # Keep last 100 characters of heading as entry heading more important than filename
                     snipped_heading = entry.heading[-100:]
-                    compiled_entry_chunk = f"{snipped_heading}.\n{compiled_entry_chunk}"
+                    # Prepend snipped heading
+                    compiled_entry_chunk = f"{snipped_heading}\n{compiled_entry_chunk}"
+
+                # Drop long words instead of having entry truncated to maintain quality of entry processed by models
+                compiled_entry_chunk = TextToEntries.remove_long_words(compiled_entry_chunk, max_word_length)
 
                 # Clean entry of unwanted characters like \0 character
                 compiled_entry_chunk = TextToEntries.clean_field(compiled_entry_chunk)
diff --git a/tests/test_org_to_entries.py b/tests/test_org_to_entries.py
index 66371e5c..b97e62e9 100644
--- a/tests/test_org_to_entries.py
+++ b/tests/test_org_to_entries.py
@@ -54,12 +54,12 @@ def test_entry_split_when_exceeds_max_words():
     # Extract Entries from specified Org files
     entries = OrgToEntries.extract_org_entries(org_files=data)
 
-    # Split each entry from specified Org files by max words
-    entries = TextToEntries.split_entries_by_max_tokens(entries, max_tokens=4)
+    # Split each entry from specified Org files by max tokens
+    entries = TextToEntries.split_entries_by_max_tokens(entries, max_tokens=6)
 
     # Assert
     assert len(entries) == 2
-    # Ensure compiled entries split by max_words start with entry heading (for search context)
+    # Ensure compiled entries split by max tokens start with entry heading (for search context)
     assert all([entry.compiled.startswith(expected_heading) for entry in entries])
 
 
diff --git a/tests/test_text_search.py b/tests/test_text_search.py
index 791ce91b..9a41430f 100644
--- a/tests/test_text_search.py
+++ b/tests/test_text_search.py
@@ -192,7 +192,7 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: LocalOrgCon
 
     # Assert
     assert (
-        "Deleted 0 entries. Created 2 new entries for user " in caplog.records[-1].message
+        "Deleted 0 entries. Created 3 new entries for user " in caplog.records[-1].message
     ), "new entry not split by max tokens"
 
 
@@ -250,7 +250,7 @@ conda activate khoj
 
     # Assert
     assert (
-        "Deleted 0 entries. Created 2 new entries for user " in caplog.records[-1].message
+        "Deleted 0 entries. Created 3 new entries for user " in caplog.records[-1].message
    ), "new entry not split by max tokens"
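
Reviewer note: the chunking change above is easiest to inspect with the splitter in isolation. The snippet below is a minimal standalone sketch of the same configuration, not code from this repository: the sample text, the chunk_size of 6 tokens, and the module-level tokenizer helper are illustrative stand-ins, and it assumes a langchain version that exposes RecursiveCharacterTextSplitter with the constructor arguments used in the patch.

from langchain.text_splitter import RecursiveCharacterTextSplitter

# Hypothetical stand-in for TextToEntries.tokenizer: count whitespace-separated words
def tokenizer(text: str) -> list:
    return text.split()

# Same chunking preference order as the patch: paragraphs > lines > sentences > words > characters
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=6,  # illustrative max tokens per chunk
    chunk_overlap=0,
    separators=["\n\n", "\n", "!", "?", ".", " ", "\t", ""],
    keep_separator=True,
    length_function=lambda chunk: len(tokenizer(chunk)),
)

sample = "First paragraph about indexing notes.\n\nSecond paragraph. It has two short sentences."
for chunk in text_splitter.split_text(sample):
    print(repr(chunk))

Because chunk length is measured in whitespace tokens rather than characters, each emitted chunk stays within the token budget while breaking at the most natural available separator, which is consistent with the updated test expectations above (3 entries instead of 2).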