mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-27 09:25:06 +01:00
Chunk text in preference order of para, sentence, word, character
- Previous simplistic chunking strategy of splitting text by space didn't capture notes with newlines, no spaces. For e.g in #620 - New strategy will try chunk the text at more natural points like paragraph, sentence, word first. If none of those work it'll split at character to fit within max token limit - Drop long words while preserving original delimiters Resolves #620
This commit is contained in:
parent
a627f56a64
commit
86575b2946
3 changed files with 46 additions and 17 deletions
|
@ -1,10 +1,12 @@
|
|||
import hashlib
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from itertools import repeat
|
||||
from typing import Any, Callable, List, Set, Tuple
|
||||
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from tqdm import tqdm
|
||||
|
||||
from khoj.database.adapters import EntryAdapters, get_user_search_model_or_default
|
||||
|
@ -34,6 +36,27 @@ class TextToEntries(ABC):
|
|||
def hash_func(key: str) -> Callable:
|
||||
return lambda entry: hashlib.md5(bytes(getattr(entry, key), encoding="utf-8")).hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def remove_long_words(text: str, max_word_length: int = 500) -> str:
|
||||
"Remove words longer than max_word_length from text."
|
||||
# Split the string by words, keeping the delimiters
|
||||
splits = re.split(r"(\s+)", text) + [""]
|
||||
words_with_delimiters = list(zip(splits[::2], splits[1::2]))
|
||||
|
||||
# Filter out long words while preserving delimiters in text
|
||||
filtered_text = [
|
||||
f"{word}{delimiter}"
|
||||
for word, delimiter in words_with_delimiters
|
||||
if not word.strip() or len(word.strip()) <= max_word_length
|
||||
]
|
||||
|
||||
return "".join(filtered_text)
|
||||
|
||||
@staticmethod
|
||||
def tokenizer(text: str) -> List[str]:
|
||||
"Tokenize text into words."
|
||||
return text.split()
|
||||
|
||||
@staticmethod
|
||||
def split_entries_by_max_tokens(
|
||||
entries: List[Entry], max_tokens: int = 256, max_word_length: int = 500
|
||||
|
@ -44,24 +67,30 @@ class TextToEntries(ABC):
|
|||
if is_none_or_empty(entry.compiled):
|
||||
continue
|
||||
|
||||
# Split entry into words
|
||||
compiled_entry_words = [word for word in entry.compiled.split(" ") if word != ""]
|
||||
|
||||
# Drop long words instead of having entry truncated to maintain quality of entry processed by models
|
||||
compiled_entry_words = [word for word in compiled_entry_words if len(word) <= max_word_length]
|
||||
# Split entry into chunks of max_tokens
|
||||
# Use chunking preference order: paragraphs > sentences > words > characters
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=max_tokens,
|
||||
separators=["\n\n", "\n", "!", "?", ".", " ", "\t", ""],
|
||||
keep_separator=True,
|
||||
length_function=lambda chunk: len(TextToEntries.tokenizer(chunk)),
|
||||
chunk_overlap=0,
|
||||
)
|
||||
chunked_entry_chunks = text_splitter.split_text(entry.compiled)
|
||||
corpus_id = uuid.uuid4()
|
||||
|
||||
# Split entry into chunks of max tokens
|
||||
for chunk_index in range(0, len(compiled_entry_words), max_tokens):
|
||||
compiled_entry_words_chunk = compiled_entry_words[chunk_index : chunk_index + max_tokens]
|
||||
compiled_entry_chunk = " ".join(compiled_entry_words_chunk)
|
||||
|
||||
# Create heading prefixed entry from each chunk
|
||||
for chunk_index, compiled_entry_chunk in enumerate(chunked_entry_chunks):
|
||||
# Prepend heading to all other chunks, the first chunk already has heading from original entry
|
||||
if chunk_index > 0:
|
||||
if chunk_index > 0 and entry.heading:
|
||||
# Snip heading to avoid crossing max_tokens limit
|
||||
# Keep last 100 characters of heading as entry heading more important than filename
|
||||
snipped_heading = entry.heading[-100:]
|
||||
compiled_entry_chunk = f"{snipped_heading}.\n{compiled_entry_chunk}"
|
||||
# Prepend snipped heading
|
||||
compiled_entry_chunk = f"{snipped_heading}\n{compiled_entry_chunk}"
|
||||
|
||||
# Drop long words instead of having entry truncated to maintain quality of entry processed by models
|
||||
compiled_entry_chunk = TextToEntries.remove_long_words(compiled_entry_chunk, max_word_length)
|
||||
|
||||
# Clean entry of unwanted characters like \0 character
|
||||
compiled_entry_chunk = TextToEntries.clean_field(compiled_entry_chunk)
|
||||
|
|
|
@ -54,12 +54,12 @@ def test_entry_split_when_exceeds_max_words():
|
|||
# Extract Entries from specified Org files
|
||||
entries = OrgToEntries.extract_org_entries(org_files=data)
|
||||
|
||||
# Split each entry from specified Org files by max words
|
||||
entries = TextToEntries.split_entries_by_max_tokens(entries, max_tokens=4)
|
||||
# Split each entry from specified Org files by max tokens
|
||||
entries = TextToEntries.split_entries_by_max_tokens(entries, max_tokens=6)
|
||||
|
||||
# Assert
|
||||
assert len(entries) == 2
|
||||
# Ensure compiled entries split by max_words start with entry heading (for search context)
|
||||
# Ensure compiled entries split by max tokens start with entry heading (for search context)
|
||||
assert all([entry.compiled.startswith(expected_heading) for entry in entries])
|
||||
|
||||
|
||||
|
|
|
@ -192,7 +192,7 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: LocalOrgCon
|
|||
|
||||
# Assert
|
||||
assert (
|
||||
"Deleted 0 entries. Created 2 new entries for user " in caplog.records[-1].message
|
||||
"Deleted 0 entries. Created 3 new entries for user " in caplog.records[-1].message
|
||||
), "new entry not split by max tokens"
|
||||
|
||||
|
||||
|
@ -250,7 +250,7 @@ conda activate khoj
|
|||
|
||||
# Assert
|
||||
assert (
|
||||
"Deleted 0 entries. Created 2 new entries for user " in caplog.records[-1].message
|
||||
"Deleted 0 entries. Created 3 new entries for user " in caplog.records[-1].message
|
||||
), "new entry not split by max tokens"
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue