Mirror of https://github.com/khoj-ai/khoj.git (synced 2024-11-27 09:25:06 +01:00)
Reduce max embedding chunk size to fit the token limit of standard BERT variants
Reduce the embedding model's max prompt size from 256 to 128 words. A word is usually 3-4 tokens, so 128 * 4 = 512 tokens should be the upper limit when splitting text into chunks.
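As a rough sanity check of that arithmetic (illustrative only, not part of the commit; the 4-tokens-per-word figure is the conservative ratio assumed in the commit message):

# Illustrative sanity check of the chunk-size arithmetic above (not part of the commit).
MAX_WORDS_PER_CHUNK = 128  # new word budget per chunk introduced by this commit
TOKENS_PER_WORD = 4        # conservative upper bound assumed in the commit message
BERT_TOKEN_LIMIT = 512     # context window of standard BERT-style embedding models

assert MAX_WORDS_PER_CHUNK * TOKENS_PER_WORD <= BERT_TOKEN_LIMIT  # 128 * 4 = 512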
Parent: 9e31ebff93
Commit: ed693afd68
10 changed files with 15 additions and 15 deletions
@@ -36,7 +36,7 @@ class DocxToEntries(TextToEntries):
 
         # Split entries by max tokens supported by model
         with timer("Split entries by max token size supported by model", logger):
-            current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
+            current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=128)
 
         # Identify, mark and merge any new entries with previous entries
         with timer("Identify new or updated entries", logger):
@@ -95,7 +95,7 @@ class GithubToEntries(TextToEntries):
             )
 
         with timer(f"Split entries by max token size supported by model {repo_shorthand}", logger):
-            current_entries = TextToEntries.split_entries_by_max_tokens(current_entries, max_tokens=256)
+            current_entries = TextToEntries.split_entries_by_max_tokens(current_entries, max_tokens=128)
 
         return current_entries
 
@@ -37,7 +37,7 @@ class ImageToEntries(TextToEntries):
 
         # Split entries by max tokens supported by model
         with timer("Split entries by max token size supported by model", logger):
-            current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
+            current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=128)
 
         # Identify, mark and merge any new entries with previous entries
         with timer("Identify new or updated entries", logger):
@@ -30,7 +30,7 @@ class MarkdownToEntries(TextToEntries):
         else:
             deletion_file_names = None
 
-        max_tokens = 256
+        max_tokens = 128
         # Extract Entries from specified Markdown files
         with timer("Extract entries from specified Markdown files", logger):
             file_to_text_map, current_entries = MarkdownToEntries.extract_markdown_entries(files, max_tokens)
@@ -56,7 +56,7 @@ class MarkdownToEntries(TextToEntries):
         return num_new_embeddings, num_deleted_embeddings
 
     @staticmethod
-    def extract_markdown_entries(markdown_files, max_tokens=256) -> Tuple[Dict, List[Entry]]:
+    def extract_markdown_entries(markdown_files, max_tokens=128) -> Tuple[Dict, List[Entry]]:
         "Extract entries by heading from specified Markdown files"
         entries: List[str] = []
         entry_to_file_map: List[Tuple[str, str]] = []
@@ -81,7 +81,7 @@ class MarkdownToEntries(TextToEntries):
         markdown_file: str,
         entries: List[str],
         entry_to_file_map: List[Tuple[str, str]],
-        max_tokens=256,
+        max_tokens=128,
         ancestry: Dict[int, str] = {},
     ) -> Tuple[List[str], List[Tuple[str, str]]]:
         # Prepend the markdown section's heading ancestry
@@ -112,7 +112,7 @@ class NotionToEntries(TextToEntries):
                 page_entries = self.process_page(p_or_d)
                 current_entries.extend(page_entries)
 
-        current_entries = TextToEntries.split_entries_by_max_tokens(current_entries, max_tokens=256)
+        current_entries = TextToEntries.split_entries_by_max_tokens(current_entries, max_tokens=128)
 
         return self.update_entries_with_ids(current_entries, user=user)
 
@@ -31,7 +31,7 @@ class OrgToEntries(TextToEntries):
             deletion_file_names = None
 
         # Extract Entries from specified Org files
-        max_tokens = 256
+        max_tokens = 128
         with timer("Extract entries from specified Org files", logger):
             file_to_text_map, current_entries = self.extract_org_entries(files, max_tokens=max_tokens)
 
@@ -56,7 +56,7 @@ class OrgToEntries(TextToEntries):
 
     @staticmethod
     def extract_org_entries(
-        org_files: dict[str, str], index_heading_entries: bool = False, max_tokens=256
+        org_files: dict[str, str], index_heading_entries: bool = False, max_tokens=128
     ) -> Tuple[Dict, List[Entry]]:
         "Extract entries from specified Org files"
         file_to_text_map, entries, entry_to_file_map = OrgToEntries.extract_org_nodes(org_files, max_tokens)
@@ -90,7 +90,7 @@ class OrgToEntries(TextToEntries):
         org_file: str,
         entries: List[List[Orgnode]],
         entry_to_file_map: List[Tuple[Orgnode, str]],
-        max_tokens=256,
+        max_tokens=128,
         ancestry: Dict[int, str] = {},
     ) -> Tuple[List[List[Orgnode]], List[Tuple[Orgnode, str]]]:
         """Parse org_content from org_file into OrgNode entries
@@ -39,7 +39,7 @@ class PdfToEntries(TextToEntries):
 
         # Split entries by max tokens supported by model
         with timer("Split entries by max token size supported by model", logger):
-            current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
+            current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=128)
 
         # Identify, mark and merge any new entries with previous entries
         with timer("Identify new or updated entries", logger):
@@ -36,7 +36,7 @@ class PlaintextToEntries(TextToEntries):
 
         # Split entries by max tokens supported by model
         with timer("Split entries by max token size supported by model", logger):
-            current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256, raw_is_compiled=True)
+            current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=128, raw_is_compiled=True)
 
         # Identify, mark and merge any new entries with previous entries
         with timer("Identify new or updated entries", logger):
@@ -63,7 +63,7 @@ class TextToEntries(ABC):
 
     @staticmethod
     def split_entries_by_max_tokens(
-        entries: List[Entry], max_tokens: int = 256, max_word_length: int = 500, raw_is_compiled: bool = False
+        entries: List[Entry], max_tokens: int = 128, max_word_length: int = 500, raw_is_compiled: bool = False
     ) -> List[Entry]:
         "Split entries if compiled entry length exceeds the max tokens supported by the ML model."
         chunked_entries: List[Entry] = []
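For context on what max_tokens controls here: the commit message and tests treat it as a word count, not tokenizer output. A minimal, hypothetical sketch of word-level chunking follows; it is not Khoj's actual split_entries_by_max_tokens implementation, which also preserves entry metadata and distinguishes compiled from raw text.

# Hypothetical word-level chunking sketch, for illustration only.
from typing import List

def split_text_by_max_tokens(text: str, max_tokens: int = 128) -> List[str]:
    # Split on whitespace and emit chunks of at most max_tokens words,
    # mirroring the word budget this commit reduces from 256 to 128.
    words = text.split()
    return [" ".join(words[i : i + max_tokens]) for i in range(0, len(words), max_tokens)]

# A 300-word entry becomes chunks of 128, 128 and 44 words.
chunks = split_text_by_max_tokens("word " * 300, max_tokens=128)
assert [len(chunk.split()) for chunk in chunks] == [128, 128, 44]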
@@ -172,7 +172,7 @@ async def test_text_search(search_config: SearchConfig):
 def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog):
     # Arrange
     # Insert org-mode entry with size exceeding max token limit to new org file
-    max_tokens = 256
+    max_tokens = 128
     new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
     with open(new_file_to_index, "w") as f:
         f.write(f"* Entry more than {max_tokens} words\n")
@@ -224,7 +224,7 @@ conda activate khoj
         user=default_user,
     )
 
-    max_tokens = 256
+    max_tokens = 128
     new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
     with open(new_file_to_index, "w") as f:
         f.write(f"* Entry more than {max_tokens} words\n")