Mirror of https://github.com/khoj-ai/khoj.git (synced 2024-11-23 15:38:55 +01:00)
Reduce max embedding chunk size to fit token limit of standard bert variants
Reduce the embedding model's max prompt size from 256 to 128 words. A word is usually 3-4 tokens, so 128 words × 4 ≈ 512 tokens, which keeps each chunk within the 512-token limit of standard BERT variants.
commit ed693afd68
parent 9e31ebff93
10 changed files with 15 additions and 15 deletions
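The word-to-token arithmetic behind the 128-word cap can be sanity-checked against the tokenizer of a typical BERT-variant embedding model. The snippet below is a minimal sketch assuming the transformers library, with sentence-transformers/all-MiniLM-L6-v2 as a stand-in model (not necessarily the embedding model khoj ships); it just counts the tokens a 128-word chunk produces and compares that with the model's context window.

from transformers import AutoTokenizer

# Stand-in BERT-variant embedding model; khoj's actual embedding model may differ.
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
print(tokenizer.model_max_length)  # typically 512 for standard BERT variants

# A 128-word chunk built from a longish word; rare or long words split into several wordpieces.
chunk = " ".join(["characterization"] * 128)
token_count = len(tokenizer.encode(chunk))  # includes [CLS] and [SEP] special tokens
print(token_count)  # the 128-word cap targets <= ~512 tokens under the 3-4 tokens/word estimate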
@@ -36,7 +36,7 @@ class DocxToEntries(TextToEntries):

         # Split entries by max tokens supported by model
         with timer("Split entries by max token size supported by model", logger):
-            current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
+            current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=128)

         # Identify, mark and merge any new entries with previous entries
         with timer("Identify new or updated entries", logger):
@@ -95,7 +95,7 @@ class GithubToEntries(TextToEntries):
             )

         with timer(f"Split entries by max token size supported by model {repo_shorthand}", logger):
-            current_entries = TextToEntries.split_entries_by_max_tokens(current_entries, max_tokens=256)
+            current_entries = TextToEntries.split_entries_by_max_tokens(current_entries, max_tokens=128)

         return current_entries

@@ -37,7 +37,7 @@ class ImageToEntries(TextToEntries):

         # Split entries by max tokens supported by model
         with timer("Split entries by max token size supported by model", logger):
-            current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
+            current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=128)

         # Identify, mark and merge any new entries with previous entries
         with timer("Identify new or updated entries", logger):
@@ -30,7 +30,7 @@ class MarkdownToEntries(TextToEntries):
         else:
             deletion_file_names = None

-        max_tokens = 256
+        max_tokens = 128
         # Extract Entries from specified Markdown files
         with timer("Extract entries from specified Markdown files", logger):
             file_to_text_map, current_entries = MarkdownToEntries.extract_markdown_entries(files, max_tokens)
@@ -56,7 +56,7 @@ class MarkdownToEntries(TextToEntries):
         return num_new_embeddings, num_deleted_embeddings

     @staticmethod
-    def extract_markdown_entries(markdown_files, max_tokens=256) -> Tuple[Dict, List[Entry]]:
+    def extract_markdown_entries(markdown_files, max_tokens=128) -> Tuple[Dict, List[Entry]]:
         "Extract entries by heading from specified Markdown files"
         entries: List[str] = []
         entry_to_file_map: List[Tuple[str, str]] = []
@@ -81,7 +81,7 @@ class MarkdownToEntries(TextToEntries):
         markdown_file: str,
         entries: List[str],
         entry_to_file_map: List[Tuple[str, str]],
-        max_tokens=256,
+        max_tokens=128,
         ancestry: Dict[int, str] = {},
     ) -> Tuple[List[str], List[Tuple[str, str]]]:
         # Prepend the markdown section's heading ancestry
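The ancestry parameter in this signature maps heading depth to heading text so that each split section keeps the context of the headings above it. The snippet below is a generic illustration of that idea, not khoj's actual code; the function name and breadcrumb formatting are made up for the example.

from typing import Dict


def compile_with_ancestry(section_body: str, ancestry: Dict[int, str]) -> str:
    # Join the headings above this section, shallowest level first, then append the body
    breadcrumb = " / ".join(ancestry[level] for level in sorted(ancestry))
    return f"{breadcrumb}\n{section_body}" if breadcrumb else section_body


print(compile_with_ancestry("Some section text...", {1: "Projects", 2: "Khoj", 3: "Indexing"}))
# Projects / Khoj / Indexing
# Some section text...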
@@ -112,7 +112,7 @@ class NotionToEntries(TextToEntries):
            page_entries = self.process_page(p_or_d)
            current_entries.extend(page_entries)

-        current_entries = TextToEntries.split_entries_by_max_tokens(current_entries, max_tokens=256)
+        current_entries = TextToEntries.split_entries_by_max_tokens(current_entries, max_tokens=128)

         return self.update_entries_with_ids(current_entries, user=user)

@@ -31,7 +31,7 @@ class OrgToEntries(TextToEntries):
             deletion_file_names = None

         # Extract Entries from specified Org files
-        max_tokens = 256
+        max_tokens = 128
         with timer("Extract entries from specified Org files", logger):
             file_to_text_map, current_entries = self.extract_org_entries(files, max_tokens=max_tokens)

@@ -56,7 +56,7 @@ class OrgToEntries(TextToEntries):

     @staticmethod
     def extract_org_entries(
-        org_files: dict[str, str], index_heading_entries: bool = False, max_tokens=256
+        org_files: dict[str, str], index_heading_entries: bool = False, max_tokens=128
     ) -> Tuple[Dict, List[Entry]]:
         "Extract entries from specified Org files"
         file_to_text_map, entries, entry_to_file_map = OrgToEntries.extract_org_nodes(org_files, max_tokens)
@@ -90,7 +90,7 @@ class OrgToEntries(TextToEntries):
         org_file: str,
         entries: List[List[Orgnode]],
         entry_to_file_map: List[Tuple[Orgnode, str]],
-        max_tokens=256,
+        max_tokens=128,
         ancestry: Dict[int, str] = {},
     ) -> Tuple[List[List[Orgnode]], List[Tuple[Orgnode, str]]]:
         """Parse org_content from org_file into OrgNode entries
@@ -39,7 +39,7 @@ class PdfToEntries(TextToEntries):

         # Split entries by max tokens supported by model
         with timer("Split entries by max token size supported by model", logger):
-            current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
+            current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=128)

         # Identify, mark and merge any new entries with previous entries
         with timer("Identify new or updated entries", logger):
@@ -36,7 +36,7 @@ class PlaintextToEntries(TextToEntries):

         # Split entries by max tokens supported by model
         with timer("Split entries by max token size supported by model", logger):
-            current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256, raw_is_compiled=True)
+            current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=128, raw_is_compiled=True)

         # Identify, mark and merge any new entries with previous entries
         with timer("Identify new or updated entries", logger):
@@ -63,7 +63,7 @@ class TextToEntries(ABC):

     @staticmethod
     def split_entries_by_max_tokens(
-        entries: List[Entry], max_tokens: int = 256, max_word_length: int = 500, raw_is_compiled: bool = False
+        entries: List[Entry], max_tokens: int = 128, max_word_length: int = 500, raw_is_compiled: bool = False
     ) -> List[Entry]:
         "Split entries if compiled entry length exceeds the max tokens supported by the ML model."
         chunked_entries: List[Entry] = []
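Per the commit message, split_entries_by_max_tokens expresses its cap in words rather than model tokens, which is why the 128-word limit is paired with the model's 512-token budget. The snippet below is a simplified sketch of that word-based chunking idea, not khoj's actual implementation; entries are reduced to plain strings and the helper name is made up.

from typing import List


def split_text_by_max_tokens(compiled: str, max_tokens: int = 128, max_word_length: int = 500) -> List[str]:
    # Drop pathologically long "words" (e.g. base64 blobs) that would blow up the token count
    words = [word for word in compiled.split() if len(word) <= max_word_length]
    # Emit consecutive chunks of at most max_tokens words each
    return [" ".join(words[i : i + max_tokens]) for i in range(0, len(words), max_tokens)]


chunks = split_text_by_max_tokens("word " * 300)
print([len(chunk.split()) for chunk in chunks])  # [128, 128, 44]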
@@ -172,7 +172,7 @@ async def test_text_search(search_config: SearchConfig):
 def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog):
     # Arrange
     # Insert org-mode entry with size exceeding max token limit to new org file
-    max_tokens = 256
+    max_tokens = 128
     new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
     with open(new_file_to_index, "w") as f:
         f.write(f"* Entry more than {max_tokens} words\n")
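For context, this test writes an org entry whose body exceeds max_tokens words and checks that it gets split into multiple chunks at indexing time. A rough sketch of that fixture setup is shown below; the file path and generated body text are illustrative stand-ins, not the test's actual content.

from pathlib import Path

max_tokens = 128
# Illustrative path; the real test uses the org file from its fixture config.
new_file_to_index = Path("/tmp/test_entry_chunking.org")
with open(new_file_to_index, "w") as f:
    f.write(f"* Entry more than {max_tokens} words\n")
    # A body with more than max_tokens words forces the entry to be split into chunks
    f.write(" ".join(f"word{i}" for i in range(max_tokens + 1)) + "\n")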
@@ -224,7 +224,7 @@ conda activate khoj
         user=default_user,
     )

-    max_tokens = 256
+    max_tokens = 128
     new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
     with open(new_file_to_index, "w") as f:
         f.write(f"* Entry more than {max_tokens} words\n")