Split text entries by max tokens supported by ML models

### Background
There is a limit to the maximum input tokens (words) that an ML model can encode into an embedding vector.
For the models used for text search in khoj, a max token size of 256 words is appropriate [1](https://huggingface.co/sentence-transformers/multi-qa-MiniLM-L6-cos-v1#:~:text=model%20was%20just%20trained%20on%20input%20text%20up%20to%20250%20word%20pieces),[2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2#:~:text=input%20text%20longer%20than%20256%20word%20pieces%20is%20truncated)

### Issue
Until now entries exceeding max token size would silently get truncated during embedding generation.
So the truncated portion of the entries would be ignored when matching queries with entries
This would degrade the quality of the results

### Fix
- e057c8e Add method to split entries by specified max tokens limit
- Split entries by max tokens while converting [Org](https://github.com/debanjum/khoj/commit/c79919b), [Markdown](https://github.com/debanjum/khoj/commit/f209e30) and [Beancount](https://github.com/debanjum/khoj/commit/17fa123) entries to JSONL
- b283650 Deduplicate results for user query by raw text before returning results

### Results
- The quality of the search results should improve
- Relevant, long entries should show up in results more often
This commit is contained in:
Debanjum 2022-12-26 18:23:43 +00:00 committed by GitHub
commit 06c25682c9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 102 additions and 3 deletions

View file

@ -35,6 +35,12 @@ class BeancountToJsonl(TextToJsonl):
end = time.time() end = time.time()
logger.debug(f"Parse transactions from Beancount files into dictionaries: {end - start} seconds") logger.debug(f"Parse transactions from Beancount files into dictionaries: {end - start} seconds")
# Split entries by max tokens supported by model
start = time.time()
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
end = time.time()
logger.debug(f"Split entries by max token size supported by model: {end - start} seconds")
# Identify, mark and merge any new entries with previous entries # Identify, mark and merge any new entries with previous entries
start = time.time() start = time.time()
if not previous_entries: if not previous_entries:

View file

@ -35,6 +35,12 @@ class MarkdownToJsonl(TextToJsonl):
end = time.time() end = time.time()
logger.debug(f"Parse entries from Markdown files into dictionaries: {end - start} seconds") logger.debug(f"Parse entries from Markdown files into dictionaries: {end - start} seconds")
# Split entries by max tokens supported by model
start = time.time()
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
end = time.time()
logger.debug(f"Split entries by max token size supported by model: {end - start} seconds")
# Identify, mark and merge any new entries with previous entries # Identify, mark and merge any new entries with previous entries
start = time.time() start = time.time()
if not previous_entries: if not previous_entries:

View file

@ -41,7 +41,12 @@ class OrgToJsonl(TextToJsonl):
start = time.time() start = time.time()
current_entries = self.convert_org_nodes_to_entries(entry_nodes, file_to_entries, index_heading_entries) current_entries = self.convert_org_nodes_to_entries(entry_nodes, file_to_entries, index_heading_entries)
end = time.time() end = time.time()
logger.debug(f"Convert OrgNodes into entry dictionaries: {end - start} seconds") logger.debug(f"Convert OrgNodes into list of entries: {end - start} seconds")
start = time.time()
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
end = time.time()
logger.debug(f"Split entries by max token size supported by model: {end - start} seconds")
# Identify, mark and merge any new entries with previous entries # Identify, mark and merge any new entries with previous entries
if not previous_entries: if not previous_entries:

View file

@ -23,6 +23,19 @@ class TextToJsonl(ABC):
def hash_func(key: str) -> Callable: def hash_func(key: str) -> Callable:
return lambda entry: hashlib.md5(bytes(getattr(entry, key), encoding='utf-8')).hexdigest() return lambda entry: hashlib.md5(bytes(getattr(entry, key), encoding='utf-8')).hexdigest()
@staticmethod
def split_entries_by_max_tokens(entries: list[Entry], max_tokens: int=256) -> list[Entry]:
"Split entries if compiled entry length exceeds the max tokens supported by the ML model."
chunked_entries: list[Entry] = []
for entry in entries:
compiled_entry_words = entry.compiled.split()
for chunk_index in range(0, len(compiled_entry_words), max_tokens):
compiled_entry_words_chunk = compiled_entry_words[chunk_index:chunk_index + max_tokens]
compiled_entry_chunk = ' '.join(compiled_entry_words_chunk)
entry_chunk = Entry(compiled=compiled_entry_chunk, raw=entry.raw, file=entry.file)
chunked_entries.append(entry_chunk)
return chunked_entries
def mark_entries_for_update(self, current_entries: list[Entry], previous_entries: list[Entry], key='compiled', logger=None) -> list[tuple[int, Entry]]: def mark_entries_for_update(self, current_entries: list[Entry], previous_entries: list[Entry], key='compiled', logger=None) -> list[tuple[int, Entry]]:
# Hash all current and previous entries to identify new entries # Hash all current and previous entries to identify new entries
start = time.time() start = time.time()

View file

@ -150,6 +150,17 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False):
end = time.time() end = time.time()
logger.debug(f"Rank Time: {end - start:.3f} seconds on device: {state.device}") logger.debug(f"Rank Time: {end - start:.3f} seconds on device: {state.device}")
# Deduplicate entries by raw entry text before showing to users
# Compiled entries are split by max tokens supported by ML models.
# This can result in duplicate hits, entries shown to user.
start = time.time()
seen, original_hits_count = set(), len(hits)
hits = [hit for hit in hits
if entries[hit['corpus_id']].raw not in seen and not seen.add(entries[hit['corpus_id']].raw)]
duplicate_hits = original_hits_count - len(hits)
end = time.time()
logger.debug(f"Deduplication Time: {end - start:.3f} seconds. Removed {duplicate_hits} duplicates")
return hits, entries return hits, entries

View file

@ -3,7 +3,6 @@ import pytest
# Internal Packages # Internal Packages
from src.search_type import image_search, text_search from src.search_type import image_search, text_search
from src.utils.config import SearchType
from src.utils.helpers import resolve_absolute_path from src.utils.helpers import resolve_absolute_path
from src.utils.rawconfig import ContentConfig, TextContentConfig, ImageContentConfig, SearchConfig, TextSearchConfig, ImageSearchConfig from src.utils.rawconfig import ContentConfig, TextContentConfig, ImageContentConfig, SearchConfig, TextSearchConfig, ImageSearchConfig
from src.processor.org_mode.org_to_jsonl import OrgToJsonl from src.processor.org_mode.org_to_jsonl import OrgToJsonl

View file

@ -3,6 +3,7 @@ import json
# Internal Packages # Internal Packages
from src.processor.org_mode.org_to_jsonl import OrgToJsonl from src.processor.org_mode.org_to_jsonl import OrgToJsonl
from src.processor.text_to_jsonl import TextToJsonl
from src.utils.helpers import is_none_or_empty from src.utils.helpers import is_none_or_empty
@ -35,6 +36,31 @@ def test_configure_heading_entry_to_jsonl(tmp_path):
assert is_none_or_empty(jsonl_data) assert is_none_or_empty(jsonl_data)
def test_entry_split_when_exceeds_max_words(tmp_path):
"Ensure entries with compiled words exceeding max_words are split."
# Arrange
entry = f'''*** Heading
\t\r
Body Line 1
'''
orgfile = create_file(tmp_path, entry)
# Act
# Extract Entries from specified Org files
entries, entry_to_file_map = OrgToJsonl.extract_org_entries(org_files=[orgfile])
# Split each entry from specified Org files by max words
jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(
TextToJsonl.split_entries_by_max_tokens(
OrgToJsonl.convert_org_nodes_to_entries(entries, entry_to_file_map),
max_tokens = 2)
)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
# Assert
assert len(jsonl_data) == 2
def test_entry_with_body_to_jsonl(tmp_path): def test_entry_with_body_to_jsonl(tmp_path):
"Ensure entries with valid body text are loaded." "Ensure entries with valid body text are loaded."
# Arrange # Arrange

View file

@ -80,10 +80,43 @@ def test_asymmetric_search(content_config: ContentConfig, search_config: SearchC
assert "git clone" in search_result assert "git clone" in search_result
# ----------------------------------------------------------------------------------------------------
def test_entry_chunking_by_max_tokens(content_config: ContentConfig, search_config: SearchConfig):
# Arrange
initial_notes_model= text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True)
assert len(initial_notes_model.entries) == 10
assert len(initial_notes_model.corpus_embeddings) == 10
file_to_add_on_reload = Path(content_config.org.input_filter[0]).parent / "entry_exceeding_max_tokens.org"
content_config.org.input_files = [f'{file_to_add_on_reload}']
# Insert org-mode entry with size exceeding max token limit to new org file
max_tokens = 256
with open(file_to_add_on_reload, "w") as f:
f.write(f"* Entry more than {max_tokens} words\n")
for index in range(max_tokens+1):
f.write(f"{index} ")
# Act
# reload embeddings, entries, notes model after adding new org-mode file
initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False)
# Assert
# verify newly added org-mode entry is split by max tokens
assert len(initial_notes_model.entries) == 12
assert len(initial_notes_model.corpus_embeddings) == 12
# Cleanup
# delete reload test file added
content_config.org.input_files = []
file_to_add_on_reload.unlink()
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchConfig): def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchConfig):
# Arrange # Arrange
initial_notes_model= text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False) initial_notes_model= text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True)
assert len(initial_notes_model.entries) == 10 assert len(initial_notes_model.entries) == 10
assert len(initial_notes_model.corpus_embeddings) == 10 assert len(initial_notes_model.corpus_embeddings) == 10