Reduce max embedding chunk size to fit token limit of standard BERT variants

Reduce the embedding model's max prompt size from 256 words to 128 words. A word
can expand to 3-4 tokens, so 128 words * 4 tokens/word = 512 tokens, which keeps
each chunk within the token limit when splitting text for embedding.
Debanjum Singh Solanky 2024-07-07 02:12:37 +05:30
parent 9e31ebff93
commit ed693afd68
10 changed files with 15 additions and 15 deletions
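As a rough sanity check on the arithmetic in the commit message, the token count of a 128 word chunk can be measured directly with a standard BERT tokenizer. A minimal sketch, assuming the Hugging Face transformers package is installed; "bert-base-uncased" is only an illustrative stand-in, not necessarily the embedding model used here:

from transformers import AutoTokenizer

# Illustrative tokenizer for a standard BERT variant with a 512 token limit.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Build a 128 word chunk from a longer-than-average word to stress subword splitting.
chunk = " ".join(["tokenization"] * 128)
num_tokens = len(tokenizer.encode(chunk))  # includes the [CLS] and [SEP] special tokens

print(f"{num_tokens} tokens for a 128 word chunk (model limit: 512)")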

@@ -36,7 +36,7 @@ class DocxToEntries(TextToEntries):
         # Split entries by max tokens supported by model
         with timer("Split entries by max token size supported by model", logger):
-            current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
+            current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=128)
         # Identify, mark and merge any new entries with previous entries
         with timer("Identify new or updated entries", logger):

@@ -95,7 +95,7 @@ class GithubToEntries(TextToEntries):
             )
         with timer(f"Split entries by max token size supported by model {repo_shorthand}", logger):
-            current_entries = TextToEntries.split_entries_by_max_tokens(current_entries, max_tokens=256)
+            current_entries = TextToEntries.split_entries_by_max_tokens(current_entries, max_tokens=128)
         return current_entries

@@ -37,7 +37,7 @@ class ImageToEntries(TextToEntries):
         # Split entries by max tokens supported by model
         with timer("Split entries by max token size supported by model", logger):
-            current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
+            current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=128)
         # Identify, mark and merge any new entries with previous entries
         with timer("Identify new or updated entries", logger):

@@ -30,7 +30,7 @@ class MarkdownToEntries(TextToEntries):
         else:
             deletion_file_names = None
-        max_tokens = 256
+        max_tokens = 128
         # Extract Entries from specified Markdown files
         with timer("Extract entries from specified Markdown files", logger):
             file_to_text_map, current_entries = MarkdownToEntries.extract_markdown_entries(files, max_tokens)
@@ -56,7 +56,7 @@ class MarkdownToEntries(TextToEntries):
         return num_new_embeddings, num_deleted_embeddings
     @staticmethod
-    def extract_markdown_entries(markdown_files, max_tokens=256) -> Tuple[Dict, List[Entry]]:
+    def extract_markdown_entries(markdown_files, max_tokens=128) -> Tuple[Dict, List[Entry]]:
         "Extract entries by heading from specified Markdown files"
         entries: List[str] = []
         entry_to_file_map: List[Tuple[str, str]] = []
@@ -81,7 +81,7 @@ class MarkdownToEntries(TextToEntries):
         markdown_file: str,
         entries: List[str],
         entry_to_file_map: List[Tuple[str, str]],
-        max_tokens=256,
+        max_tokens=128,
         ancestry: Dict[int, str] = {},
     ) -> Tuple[List[str], List[Tuple[str, str]]]:
         # Prepend the markdown section's heading ancestry

@@ -112,7 +112,7 @@ class NotionToEntries(TextToEntries):
                 page_entries = self.process_page(p_or_d)
                 current_entries.extend(page_entries)
-        current_entries = TextToEntries.split_entries_by_max_tokens(current_entries, max_tokens=256)
+        current_entries = TextToEntries.split_entries_by_max_tokens(current_entries, max_tokens=128)
         return self.update_entries_with_ids(current_entries, user=user)

@@ -31,7 +31,7 @@ class OrgToEntries(TextToEntries):
             deletion_file_names = None
         # Extract Entries from specified Org files
-        max_tokens = 256
+        max_tokens = 128
         with timer("Extract entries from specified Org files", logger):
             file_to_text_map, current_entries = self.extract_org_entries(files, max_tokens=max_tokens)
@@ -56,7 +56,7 @@ class OrgToEntries(TextToEntries):
     @staticmethod
     def extract_org_entries(
-        org_files: dict[str, str], index_heading_entries: bool = False, max_tokens=256
+        org_files: dict[str, str], index_heading_entries: bool = False, max_tokens=128
     ) -> Tuple[Dict, List[Entry]]:
         "Extract entries from specified Org files"
         file_to_text_map, entries, entry_to_file_map = OrgToEntries.extract_org_nodes(org_files, max_tokens)
@@ -90,7 +90,7 @@ class OrgToEntries(TextToEntries):
         org_file: str,
         entries: List[List[Orgnode]],
         entry_to_file_map: List[Tuple[Orgnode, str]],
-        max_tokens=256,
+        max_tokens=128,
         ancestry: Dict[int, str] = {},
     ) -> Tuple[List[List[Orgnode]], List[Tuple[Orgnode, str]]]:
         """Parse org_content from org_file into OrgNode entries

@@ -39,7 +39,7 @@ class PdfToEntries(TextToEntries):
         # Split entries by max tokens supported by model
         with timer("Split entries by max token size supported by model", logger):
-            current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
+            current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=128)
         # Identify, mark and merge any new entries with previous entries
         with timer("Identify new or updated entries", logger):

@@ -36,7 +36,7 @@ class PlaintextToEntries(TextToEntries):
         # Split entries by max tokens supported by model
         with timer("Split entries by max token size supported by model", logger):
-            current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256, raw_is_compiled=True)
+            current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=128, raw_is_compiled=True)
         # Identify, mark and merge any new entries with previous entries
         with timer("Identify new or updated entries", logger):

@@ -63,7 +63,7 @@ class TextToEntries(ABC):
     @staticmethod
     def split_entries_by_max_tokens(
-        entries: List[Entry], max_tokens: int = 256, max_word_length: int = 500, raw_is_compiled: bool = False
+        entries: List[Entry], max_tokens: int = 128, max_word_length: int = 500, raw_is_compiled: bool = False
     ) -> List[Entry]:
         "Split entries if compiled entry length exceeds the max tokens supported by the ML model."
         chunked_entries: List[Entry] = []
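
As the tests below indicate, split_entries_by_max_tokens treats max_tokens as a word budget, with words acting as a cheap proxy for tokens. A minimal sketch of that chunking idea, not the actual implementation (which also enforces max_word_length and keeps raw and compiled text in sync):

from typing import List

def chunk_by_words(text: str, max_tokens: int = 128) -> List[str]:
    # Split text into consecutive chunks of at most max_tokens words each.
    words = text.split()
    return [" ".join(words[i : i + max_tokens]) for i in range(0, len(words), max_tokens)]

With the 3-4 tokens per word worst case from the commit message, each such 128 word chunk should stay within the 512 token limit of standard BERT variants.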

@@ -172,7 +172,7 @@ async def test_text_search(search_config: SearchConfig):
 def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog):
     # Arrange
     # Insert org-mode entry with size exceeding max token limit to new org file
-    max_tokens = 256
+    max_tokens = 128
     new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
     with open(new_file_to_index, "w") as f:
         f.write(f"* Entry more than {max_tokens} words\n")
@@ -224,7 +224,7 @@ conda activate khoj
         user=default_user,
     )
-    max_tokens = 256
+    max_tokens = 128
     new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
     with open(new_file_to_index, "w") as f:
         f.write(f"* Entry more than {max_tokens} words\n")