Chunk text in preference order of para, sentence, word, character

- Previous simplistic chunking strategy of splitting text by spaces
  didn't handle notes that contain newlines but no spaces, e.g. the note in #620

- New strategy tries to chunk the text at more natural boundaries first:
  paragraph, then sentence, then word. If none of those fit, it falls back
  to splitting at the character level to stay within the max token limit
  (see the sketch below)

- Drop overly long words while preserving the original delimiters
  (also sketched below)
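
A minimal sketch of the preference-order chunking, wired up the same way as the
RecursiveCharacterTextSplitter config in the diff below; the tiny chunk_size and
the sample text are made up for illustration:

    from langchain.text_splitter import RecursiveCharacterTextSplitter

    # Prefer splitting at paragraph breaks, then lines, sentences, words and
    # finally single characters. Chunk size is measured in whitespace-delimited
    # words rather than characters.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=6,
        chunk_overlap=0,
        separators=["\n\n", "\n", "!", "?", ".", " ", "\t", ""],
        keep_separator=True,
        length_function=lambda chunk: len(chunk.split()),
    )

    text = "Short intro paragraph.\n\nA second, longer paragraph that overflows the tiny chunk size used here."
    chunks = text_splitter.split_text(text)
    # The paragraph break is tried first, so the intro lands in its own chunk;
    # the longer paragraph is split further at sentence/word boundaries only as needed.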
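And a small sketch of the long-word handling; it re-implements the
remove_long_words helper from the diff below purely to show the expected behaviour:

    import re

    def remove_long_words(text: str, max_word_length: int = 500) -> str:
        # Split into (word, delimiter) pairs so the original whitespace survives
        splits = re.split(r"(\s+)", text) + [""]
        pairs = zip(splits[::2], splits[1::2])
        # Drop words over the limit; keep each surviving word with its delimiter
        return "".join(
            f"{word}{delim}"
            for word, delim in pairs
            if not word.strip() or len(word.strip()) <= max_word_length
        )

    text = "short " + "x" * 600 + " tail\nnext line"
    # The 600-character word is dropped (along with its trailing space), while
    # the whitespace around surviving words, including the newline, is kept as-is
    assert remove_long_words(text) == "short tail\nnext line"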

Resolves #620
Debanjum Singh Solanky 2024-01-29 05:03:29 +05:30
parent a627f56a64
commit 86575b2946
3 changed files with 46 additions and 17 deletions


@@ -1,10 +1,12 @@
 import hashlib
 import logging
+import re
 import uuid
 from abc import ABC, abstractmethod
 from itertools import repeat
 from typing import Any, Callable, List, Set, Tuple

+from langchain.text_splitter import RecursiveCharacterTextSplitter
 from tqdm import tqdm

 from khoj.database.adapters import EntryAdapters, get_user_search_model_or_default
@@ -34,6 +36,27 @@ class TextToEntries(ABC):
     def hash_func(key: str) -> Callable:
         return lambda entry: hashlib.md5(bytes(getattr(entry, key), encoding="utf-8")).hexdigest()

+    @staticmethod
+    def remove_long_words(text: str, max_word_length: int = 500) -> str:
+        "Remove words longer than max_word_length from text."
+        # Split the string by words, keeping the delimiters
+        splits = re.split(r"(\s+)", text) + [""]
+        words_with_delimiters = list(zip(splits[::2], splits[1::2]))
+
+        # Filter out long words while preserving delimiters in text
+        filtered_text = [
+            f"{word}{delimiter}"
+            for word, delimiter in words_with_delimiters
+            if not word.strip() or len(word.strip()) <= max_word_length
+        ]
+
+        return "".join(filtered_text)
+
+    @staticmethod
+    def tokenizer(text: str) -> List[str]:
+        "Tokenize text into words."
+        return text.split()
+
     @staticmethod
     def split_entries_by_max_tokens(
         entries: List[Entry], max_tokens: int = 256, max_word_length: int = 500
@@ -44,24 +67,30 @@ class TextToEntries(ABC):
             if is_none_or_empty(entry.compiled):
                 continue

-            # Split entry into words
-            compiled_entry_words = [word for word in entry.compiled.split(" ") if word != ""]
-            # Drop long words instead of having entry truncated to maintain quality of entry processed by models
-            compiled_entry_words = [word for word in compiled_entry_words if len(word) <= max_word_length]
+            # Split entry into chunks of max_tokens
+            # Use chunking preference order: paragraphs > sentences > words > characters
+            text_splitter = RecursiveCharacterTextSplitter(
+                chunk_size=max_tokens,
+                separators=["\n\n", "\n", "!", "?", ".", " ", "\t", ""],
+                keep_separator=True,
+                length_function=lambda chunk: len(TextToEntries.tokenizer(chunk)),
+                chunk_overlap=0,
+            )
+            chunked_entry_chunks = text_splitter.split_text(entry.compiled)
             corpus_id = uuid.uuid4()

-            # Split entry into chunks of max tokens
-            for chunk_index in range(0, len(compiled_entry_words), max_tokens):
-                compiled_entry_words_chunk = compiled_entry_words[chunk_index : chunk_index + max_tokens]
-                compiled_entry_chunk = " ".join(compiled_entry_words_chunk)
+            # Create heading prefixed entry from each chunk
+            for chunk_index, compiled_entry_chunk in enumerate(chunked_entry_chunks):
                 # Prepend heading to all other chunks, the first chunk already has heading from original entry
-                if chunk_index > 0:
+                if chunk_index > 0 and entry.heading:
                     # Snip heading to avoid crossing max_tokens limit
                     # Keep last 100 characters of heading as entry heading more important than filename
                     snipped_heading = entry.heading[-100:]
-                    compiled_entry_chunk = f"{snipped_heading}.\n{compiled_entry_chunk}"
+                    # Prepend snipped heading
+                    compiled_entry_chunk = f"{snipped_heading}\n{compiled_entry_chunk}"
+
+                # Drop long words instead of having entry truncated to maintain quality of entry processed by models
+                compiled_entry_chunk = TextToEntries.remove_long_words(compiled_entry_chunk, max_word_length)

                 # Clean entry of unwanted characters like \0 character
                 compiled_entry_chunk = TextToEntries.clean_field(compiled_entry_chunk)


@@ -54,12 +54,12 @@ def test_entry_split_when_exceeds_max_words():
     # Extract Entries from specified Org files
     entries = OrgToEntries.extract_org_entries(org_files=data)

-    # Split each entry from specified Org files by max words
-    entries = TextToEntries.split_entries_by_max_tokens(entries, max_tokens=4)
+    # Split each entry from specified Org files by max tokens
+    entries = TextToEntries.split_entries_by_max_tokens(entries, max_tokens=6)

     # Assert
     assert len(entries) == 2

-    # Ensure compiled entries split by max_words start with entry heading (for search context)
+    # Ensure compiled entries split by max tokens start with entry heading (for search context)
     assert all([entry.compiled.startswith(expected_heading) for entry in entries])


@@ -192,7 +192,7 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: LocalOrgCon
     # Assert
     assert (
-        "Deleted 0 entries. Created 2 new entries for user " in caplog.records[-1].message
+        "Deleted 0 entries. Created 3 new entries for user " in caplog.records[-1].message
     ), "new entry not split by max tokens"
@@ -250,7 +250,7 @@ conda activate khoj
     # Assert
     assert (
-        "Deleted 0 entries. Created 2 new entries for user " in caplog.records[-1].message
+        "Deleted 0 entries. Created 3 new entries for user " in caplog.records[-1].message
     ), "new entry not split by max tokens"