mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-27 17:35:07 +01:00
Split text entries by max tokens supported by ML models
### Background There is a limit to the maximum input tokens (words) that an ML model can encode into an embedding vector. For the models used for text search in khoj, a max token size of 256 words is appropriate [1](https://huggingface.co/sentence-transformers/multi-qa-MiniLM-L6-cos-v1#:~:text=model%20was%20just%20trained%20on%20input%20text%20up%20to%20250%20word%20pieces),[2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2#:~:text=input%20text%20longer%20than%20256%20word%20pieces%20is%20truncated) ### Issue Until now entries exceeding max token size would silently get truncated during embedding generation. So the truncated portion of the entries would be ignored when matching queries with entries This would degrade the quality of the results ### Fix -e057c8e
Add method to split entries by specified max tokens limit - Split entries by max tokens while converting [Org](https://github.com/debanjum/khoj/commit/c79919b), [Markdown](https://github.com/debanjum/khoj/commit/f209e30) and [Beancount](https://github.com/debanjum/khoj/commit/17fa123) entries to JSONL -b283650
Deduplicate results for user query by raw text before returning results ### Results - The quality of the search results should improve - Relevant, long entries should show up in results more often
This commit is contained in:
commit
06c25682c9
8 changed files with 102 additions and 3 deletions
|
@ -35,6 +35,12 @@ class BeancountToJsonl(TextToJsonl):
|
||||||
end = time.time()
|
end = time.time()
|
||||||
logger.debug(f"Parse transactions from Beancount files into dictionaries: {end - start} seconds")
|
logger.debug(f"Parse transactions from Beancount files into dictionaries: {end - start} seconds")
|
||||||
|
|
||||||
|
# Split entries by max tokens supported by model
|
||||||
|
start = time.time()
|
||||||
|
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
|
||||||
|
end = time.time()
|
||||||
|
logger.debug(f"Split entries by max token size supported by model: {end - start} seconds")
|
||||||
|
|
||||||
# Identify, mark and merge any new entries with previous entries
|
# Identify, mark and merge any new entries with previous entries
|
||||||
start = time.time()
|
start = time.time()
|
||||||
if not previous_entries:
|
if not previous_entries:
|
||||||
|
|
|
@ -35,6 +35,12 @@ class MarkdownToJsonl(TextToJsonl):
|
||||||
end = time.time()
|
end = time.time()
|
||||||
logger.debug(f"Parse entries from Markdown files into dictionaries: {end - start} seconds")
|
logger.debug(f"Parse entries from Markdown files into dictionaries: {end - start} seconds")
|
||||||
|
|
||||||
|
# Split entries by max tokens supported by model
|
||||||
|
start = time.time()
|
||||||
|
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
|
||||||
|
end = time.time()
|
||||||
|
logger.debug(f"Split entries by max token size supported by model: {end - start} seconds")
|
||||||
|
|
||||||
# Identify, mark and merge any new entries with previous entries
|
# Identify, mark and merge any new entries with previous entries
|
||||||
start = time.time()
|
start = time.time()
|
||||||
if not previous_entries:
|
if not previous_entries:
|
||||||
|
|
|
@ -41,7 +41,12 @@ class OrgToJsonl(TextToJsonl):
|
||||||
start = time.time()
|
start = time.time()
|
||||||
current_entries = self.convert_org_nodes_to_entries(entry_nodes, file_to_entries, index_heading_entries)
|
current_entries = self.convert_org_nodes_to_entries(entry_nodes, file_to_entries, index_heading_entries)
|
||||||
end = time.time()
|
end = time.time()
|
||||||
logger.debug(f"Convert OrgNodes into entry dictionaries: {end - start} seconds")
|
logger.debug(f"Convert OrgNodes into list of entries: {end - start} seconds")
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
|
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
|
||||||
|
end = time.time()
|
||||||
|
logger.debug(f"Split entries by max token size supported by model: {end - start} seconds")
|
||||||
|
|
||||||
# Identify, mark and merge any new entries with previous entries
|
# Identify, mark and merge any new entries with previous entries
|
||||||
if not previous_entries:
|
if not previous_entries:
|
||||||
|
|
|
@ -23,6 +23,19 @@ class TextToJsonl(ABC):
|
||||||
def hash_func(key: str) -> Callable:
|
def hash_func(key: str) -> Callable:
|
||||||
return lambda entry: hashlib.md5(bytes(getattr(entry, key), encoding='utf-8')).hexdigest()
|
return lambda entry: hashlib.md5(bytes(getattr(entry, key), encoding='utf-8')).hexdigest()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def split_entries_by_max_tokens(entries: list[Entry], max_tokens: int=256) -> list[Entry]:
|
||||||
|
"Split entries if compiled entry length exceeds the max tokens supported by the ML model."
|
||||||
|
chunked_entries: list[Entry] = []
|
||||||
|
for entry in entries:
|
||||||
|
compiled_entry_words = entry.compiled.split()
|
||||||
|
for chunk_index in range(0, len(compiled_entry_words), max_tokens):
|
||||||
|
compiled_entry_words_chunk = compiled_entry_words[chunk_index:chunk_index + max_tokens]
|
||||||
|
compiled_entry_chunk = ' '.join(compiled_entry_words_chunk)
|
||||||
|
entry_chunk = Entry(compiled=compiled_entry_chunk, raw=entry.raw, file=entry.file)
|
||||||
|
chunked_entries.append(entry_chunk)
|
||||||
|
return chunked_entries
|
||||||
|
|
||||||
def mark_entries_for_update(self, current_entries: list[Entry], previous_entries: list[Entry], key='compiled', logger=None) -> list[tuple[int, Entry]]:
|
def mark_entries_for_update(self, current_entries: list[Entry], previous_entries: list[Entry], key='compiled', logger=None) -> list[tuple[int, Entry]]:
|
||||||
# Hash all current and previous entries to identify new entries
|
# Hash all current and previous entries to identify new entries
|
||||||
start = time.time()
|
start = time.time()
|
||||||
|
|
|
@ -150,6 +150,17 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False):
|
||||||
end = time.time()
|
end = time.time()
|
||||||
logger.debug(f"Rank Time: {end - start:.3f} seconds on device: {state.device}")
|
logger.debug(f"Rank Time: {end - start:.3f} seconds on device: {state.device}")
|
||||||
|
|
||||||
|
# Deduplicate entries by raw entry text before showing to users
|
||||||
|
# Compiled entries are split by max tokens supported by ML models.
|
||||||
|
# This can result in duplicate hits, entries shown to user.
|
||||||
|
start = time.time()
|
||||||
|
seen, original_hits_count = set(), len(hits)
|
||||||
|
hits = [hit for hit in hits
|
||||||
|
if entries[hit['corpus_id']].raw not in seen and not seen.add(entries[hit['corpus_id']].raw)]
|
||||||
|
duplicate_hits = original_hits_count - len(hits)
|
||||||
|
end = time.time()
|
||||||
|
logger.debug(f"Deduplication Time: {end - start:.3f} seconds. Removed {duplicate_hits} duplicates")
|
||||||
|
|
||||||
return hits, entries
|
return hits, entries
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,6 @@ import pytest
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from src.search_type import image_search, text_search
|
from src.search_type import image_search, text_search
|
||||||
from src.utils.config import SearchType
|
|
||||||
from src.utils.helpers import resolve_absolute_path
|
from src.utils.helpers import resolve_absolute_path
|
||||||
from src.utils.rawconfig import ContentConfig, TextContentConfig, ImageContentConfig, SearchConfig, TextSearchConfig, ImageSearchConfig
|
from src.utils.rawconfig import ContentConfig, TextContentConfig, ImageContentConfig, SearchConfig, TextSearchConfig, ImageSearchConfig
|
||||||
from src.processor.org_mode.org_to_jsonl import OrgToJsonl
|
from src.processor.org_mode.org_to_jsonl import OrgToJsonl
|
||||||
|
|
|
@ -3,6 +3,7 @@ import json
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from src.processor.org_mode.org_to_jsonl import OrgToJsonl
|
from src.processor.org_mode.org_to_jsonl import OrgToJsonl
|
||||||
|
from src.processor.text_to_jsonl import TextToJsonl
|
||||||
from src.utils.helpers import is_none_or_empty
|
from src.utils.helpers import is_none_or_empty
|
||||||
|
|
||||||
|
|
||||||
|
@ -35,6 +36,31 @@ def test_configure_heading_entry_to_jsonl(tmp_path):
|
||||||
assert is_none_or_empty(jsonl_data)
|
assert is_none_or_empty(jsonl_data)
|
||||||
|
|
||||||
|
|
||||||
|
def test_entry_split_when_exceeds_max_words(tmp_path):
|
||||||
|
"Ensure entries with compiled words exceeding max_words are split."
|
||||||
|
# Arrange
|
||||||
|
entry = f'''*** Heading
|
||||||
|
\t\r
|
||||||
|
Body Line 1
|
||||||
|
'''
|
||||||
|
orgfile = create_file(tmp_path, entry)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
# Extract Entries from specified Org files
|
||||||
|
entries, entry_to_file_map = OrgToJsonl.extract_org_entries(org_files=[orgfile])
|
||||||
|
|
||||||
|
# Split each entry from specified Org files by max words
|
||||||
|
jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(
|
||||||
|
TextToJsonl.split_entries_by_max_tokens(
|
||||||
|
OrgToJsonl.convert_org_nodes_to_entries(entries, entry_to_file_map),
|
||||||
|
max_tokens = 2)
|
||||||
|
)
|
||||||
|
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert len(jsonl_data) == 2
|
||||||
|
|
||||||
|
|
||||||
def test_entry_with_body_to_jsonl(tmp_path):
|
def test_entry_with_body_to_jsonl(tmp_path):
|
||||||
"Ensure entries with valid body text are loaded."
|
"Ensure entries with valid body text are loaded."
|
||||||
# Arrange
|
# Arrange
|
||||||
|
|
|
@ -80,10 +80,43 @@ def test_asymmetric_search(content_config: ContentConfig, search_config: SearchC
|
||||||
assert "git clone" in search_result
|
assert "git clone" in search_result
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------------------------------------
|
||||||
|
def test_entry_chunking_by_max_tokens(content_config: ContentConfig, search_config: SearchConfig):
|
||||||
|
# Arrange
|
||||||
|
initial_notes_model= text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True)
|
||||||
|
|
||||||
|
assert len(initial_notes_model.entries) == 10
|
||||||
|
assert len(initial_notes_model.corpus_embeddings) == 10
|
||||||
|
|
||||||
|
file_to_add_on_reload = Path(content_config.org.input_filter[0]).parent / "entry_exceeding_max_tokens.org"
|
||||||
|
content_config.org.input_files = [f'{file_to_add_on_reload}']
|
||||||
|
|
||||||
|
# Insert org-mode entry with size exceeding max token limit to new org file
|
||||||
|
max_tokens = 256
|
||||||
|
with open(file_to_add_on_reload, "w") as f:
|
||||||
|
f.write(f"* Entry more than {max_tokens} words\n")
|
||||||
|
for index in range(max_tokens+1):
|
||||||
|
f.write(f"{index} ")
|
||||||
|
|
||||||
|
# Act
|
||||||
|
# reload embeddings, entries, notes model after adding new org-mode file
|
||||||
|
initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
# verify newly added org-mode entry is split by max tokens
|
||||||
|
assert len(initial_notes_model.entries) == 12
|
||||||
|
assert len(initial_notes_model.corpus_embeddings) == 12
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
# delete reload test file added
|
||||||
|
content_config.org.input_files = []
|
||||||
|
file_to_add_on_reload.unlink()
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchConfig):
|
def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchConfig):
|
||||||
# Arrange
|
# Arrange
|
||||||
initial_notes_model= text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False)
|
initial_notes_model= text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True)
|
||||||
|
|
||||||
assert len(initial_notes_model.entries) == 10
|
assert len(initial_notes_model.entries) == 10
|
||||||
assert len(initial_notes_model.corpus_embeddings) == 10
|
assert len(initial_notes_model.corpus_embeddings) == 10
|
||||||
|
|
Loading…
Reference in a new issue