Reduce max embedding chunk size to fit token limit of standard BERT variants

Reduce the embedding model's max prompt size from 256 to 128 words. A word can
span up to 3-4 subword tokens, so 128 words * 4 tokens/word = 512 tokens should
be the upper limit when splitting text into chunks.
Debanjum Singh Solanky 2024-07-07 02:12:37 +05:30
parent 9e31ebff93
commit ed693afd68
10 changed files with 15 additions and 15 deletions
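
The word-to-token arithmetic in the commit message can be checked directly against a tokenizer. A minimal sketch, assuming the HuggingFace transformers package and the generic bert-base-uncased tokenizer (the tokenizer of the embedding model actually configured in Khoj may differ):

# Count how many BERT subword tokens a 128-word chunk produces.
# Assumes `pip install transformers`; "bert-base-uncased" is a stand-in for
# whichever BERT-variant tokenizer the embedding model actually uses.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

chunk = " ".join(128 * ["unconventional"])  # a 128-word chunk of longer words
token_count = len(tokenizer.encode(chunk))  # includes [CLS] and [SEP]
print(f"{token_count} tokens for 128 words (standard BERT limit: 512)")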


@@ -36,7 +36,7 @@ class DocxToEntries(TextToEntries):
 # Split entries by max tokens supported by model
 with timer("Split entries by max token size supported by model", logger):
-current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
+current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=128)
 # Identify, mark and merge any new entries with previous entries
 with timer("Identify new or updated entries", logger):


@@ -95,7 +95,7 @@ class GithubToEntries(TextToEntries):
 )
 with timer(f"Split entries by max token size supported by model {repo_shorthand}", logger):
-current_entries = TextToEntries.split_entries_by_max_tokens(current_entries, max_tokens=256)
+current_entries = TextToEntries.split_entries_by_max_tokens(current_entries, max_tokens=128)
 return current_entries


@@ -37,7 +37,7 @@ class ImageToEntries(TextToEntries):
 # Split entries by max tokens supported by model
 with timer("Split entries by max token size supported by model", logger):
-current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
+current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=128)
 # Identify, mark and merge any new entries with previous entries
 with timer("Identify new or updated entries", logger):


@@ -30,7 +30,7 @@ class MarkdownToEntries(TextToEntries):
 else:
 deletion_file_names = None
-max_tokens = 256
+max_tokens = 128
 # Extract Entries from specified Markdown files
 with timer("Extract entries from specified Markdown files", logger):
 file_to_text_map, current_entries = MarkdownToEntries.extract_markdown_entries(files, max_tokens)
@@ -56,7 +56,7 @@ class MarkdownToEntries(TextToEntries):
 return num_new_embeddings, num_deleted_embeddings
 @staticmethod
-def extract_markdown_entries(markdown_files, max_tokens=256) -> Tuple[Dict, List[Entry]]:
+def extract_markdown_entries(markdown_files, max_tokens=128) -> Tuple[Dict, List[Entry]]:
 "Extract entries by heading from specified Markdown files"
 entries: List[str] = []
 entry_to_file_map: List[Tuple[str, str]] = []
@@ -81,7 +81,7 @@ class MarkdownToEntries(TextToEntries):
 markdown_file: str,
 entries: List[str],
 entry_to_file_map: List[Tuple[str, str]],
-max_tokens=256,
+max_tokens=128,
 ancestry: Dict[int, str] = {},
 ) -> Tuple[List[str], List[Tuple[str, str]]]:
 # Prepend the markdown section's heading ancestry
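
The ancestry parameter and the comment above suggest each chunk is compiled with its parent headings prepended. A rough sketch of that idea, using the Dict[int, str] heading-level-to-title mapping from the signature (illustrative only, not the project's actual parsing code; the example values are hypothetical):

# Illustrative only: prepend a section's heading ancestry so a chunk keeps
# the context of its parent headings. `ancestry` maps heading level -> title.
ancestry = {1: "Projects", 2: "Khoj"}  # hypothetical example values
section_body = "Notes under the level-2 heading..."
ancestry_line = " / ".join(ancestry[level] for level in sorted(ancestry))
compiled_entry = f"{ancestry_line}\n{section_body}"
print(compiled_entry)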


@@ -112,7 +112,7 @@ class NotionToEntries(TextToEntries):
 page_entries = self.process_page(p_or_d)
 current_entries.extend(page_entries)
-current_entries = TextToEntries.split_entries_by_max_tokens(current_entries, max_tokens=256)
+current_entries = TextToEntries.split_entries_by_max_tokens(current_entries, max_tokens=128)
 return self.update_entries_with_ids(current_entries, user=user)


@@ -31,7 +31,7 @@ class OrgToEntries(TextToEntries):
 deletion_file_names = None
 # Extract Entries from specified Org files
-max_tokens = 256
+max_tokens = 128
 with timer("Extract entries from specified Org files", logger):
 file_to_text_map, current_entries = self.extract_org_entries(files, max_tokens=max_tokens)
@@ -56,7 +56,7 @@ class OrgToEntries(TextToEntries):
 @staticmethod
 def extract_org_entries(
-org_files: dict[str, str], index_heading_entries: bool = False, max_tokens=256
+org_files: dict[str, str], index_heading_entries: bool = False, max_tokens=128
 ) -> Tuple[Dict, List[Entry]]:
 "Extract entries from specified Org files"
 file_to_text_map, entries, entry_to_file_map = OrgToEntries.extract_org_nodes(org_files, max_tokens)
@@ -90,7 +90,7 @@ class OrgToEntries(TextToEntries):
 org_file: str,
 entries: List[List[Orgnode]],
 entry_to_file_map: List[Tuple[Orgnode, str]],
-max_tokens=256,
+max_tokens=128,
 ancestry: Dict[int, str] = {},
 ) -> Tuple[List[List[Orgnode]], List[Tuple[Orgnode, str]]]:
 """Parse org_content from org_file into OrgNode entries


@@ -39,7 +39,7 @@ class PdfToEntries(TextToEntries):
 # Split entries by max tokens supported by model
 with timer("Split entries by max token size supported by model", logger):
-current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
+current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=128)
 # Identify, mark and merge any new entries with previous entries
 with timer("Identify new or updated entries", logger):


@@ -36,7 +36,7 @@ class PlaintextToEntries(TextToEntries):
 # Split entries by max tokens supported by model
 with timer("Split entries by max token size supported by model", logger):
-current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256, raw_is_compiled=True)
+current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=128, raw_is_compiled=True)
 # Identify, mark and merge any new entries with previous entries
 with timer("Identify new or updated entries", logger):


@@ -63,7 +63,7 @@ class TextToEntries(ABC):
 @staticmethod
 def split_entries_by_max_tokens(
-entries: List[Entry], max_tokens: int = 256, max_word_length: int = 500, raw_is_compiled: bool = False
+entries: List[Entry], max_tokens: int = 128, max_word_length: int = 500, raw_is_compiled: bool = False
 ) -> List[Entry]:
 "Split entries if compiled entry length exceeds the max tokens supported by the ML model."
 chunked_entries: List[Entry] = []
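
The hunk above only changes the default; for context, a minimal sketch of what a word-based splitter with this signature might do internally (hypothetical code, not the project's actual implementation):

# Hypothetical helper: split text into chunks of at most `max_tokens` words,
# mirroring the word-count based chunking described in the commit message.
from typing import List

def split_text_by_max_tokens(text: str, max_tokens: int = 128) -> List[str]:
    words = text.split()
    # Emit consecutive groups of at most max_tokens words each.
    return [" ".join(words[i : i + max_tokens]) for i in range(0, len(words), max_tokens)]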


@@ -172,7 +172,7 @@ async def test_text_search(search_config: SearchConfig):
 def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog):
 # Arrange
 # Insert org-mode entry with size exceeding max token limit to new org file
-max_tokens = 256
+max_tokens = 128
 new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
 with open(new_file_to_index, "w") as f:
 f.write(f"* Entry more than {max_tokens} words\n")
@@ -224,7 +224,7 @@ conda activate khoj
 user=default_user,
 )
-max_tokens = 256
+max_tokens = 128
 new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
 with open(new_file_to_index, "w") as f:
 f.write(f"* Entry more than {max_tokens} words\n")