diff --git a/src/processor/ledger/beancount_to_jsonl.py b/src/processor/ledger/beancount_to_jsonl.py
index ccad97da..9f37df70 100644
--- a/src/processor/ledger/beancount_to_jsonl.py
+++ b/src/processor/ledger/beancount_to_jsonl.py
@@ -35,6 +35,12 @@ class BeancountToJsonl(TextToJsonl):
         end = time.time()
         logger.debug(f"Parse transactions from Beancount files into dictionaries: {end - start} seconds")
 
+        # Split entries by max tokens supported by model
+        start = time.time()
+        current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
+        end = time.time()
+        logger.debug(f"Split entries by max token size supported by model: {end - start} seconds")
+
         # Identify, mark and merge any new entries with previous entries
         start = time.time()
         if not previous_entries:
diff --git a/src/processor/markdown/markdown_to_jsonl.py b/src/processor/markdown/markdown_to_jsonl.py
index 5c4d660d..17482de5 100644
--- a/src/processor/markdown/markdown_to_jsonl.py
+++ b/src/processor/markdown/markdown_to_jsonl.py
@@ -35,6 +35,12 @@ class MarkdownToJsonl(TextToJsonl):
         end = time.time()
         logger.debug(f"Parse entries from Markdown files into dictionaries: {end - start} seconds")
 
+        # Split entries by max tokens supported by model
+        start = time.time()
+        current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
+        end = time.time()
+        logger.debug(f"Split entries by max token size supported by model: {end - start} seconds")
+
         # Identify, mark and merge any new entries with previous entries
         start = time.time()
         if not previous_entries:
diff --git a/src/processor/org_mode/org_to_jsonl.py b/src/processor/org_mode/org_to_jsonl.py
index 52441a99..313c9a3f 100644
--- a/src/processor/org_mode/org_to_jsonl.py
+++ b/src/processor/org_mode/org_to_jsonl.py
@@ -41,7 +41,12 @@ class OrgToJsonl(TextToJsonl):
         start = time.time()
         current_entries = self.convert_org_nodes_to_entries(entry_nodes, file_to_entries, index_heading_entries)
         end = time.time()
-        logger.debug(f"Convert OrgNodes into entry dictionaries: {end - start} seconds")
+        logger.debug(f"Convert OrgNodes into list of entries: {end - start} seconds")
+
+        start = time.time()
+        current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
+        end = time.time()
+        logger.debug(f"Split entries by max token size supported by model: {end - start} seconds")
 
         # Identify, mark and merge any new entries with previous entries
         if not previous_entries:
diff --git a/src/processor/text_to_jsonl.py b/src/processor/text_to_jsonl.py
index 2f5e7e40..0eb60e6c 100644
--- a/src/processor/text_to_jsonl.py
+++ b/src/processor/text_to_jsonl.py
@@ -23,6 +23,19 @@ class TextToJsonl(ABC):
     def hash_func(key: str) -> Callable:
         return lambda entry: hashlib.md5(bytes(getattr(entry, key), encoding='utf-8')).hexdigest()
 
+    @staticmethod
+    def split_entries_by_max_tokens(entries: list[Entry], max_tokens: int = 256) -> list[Entry]:
+        "Split entries if compiled entry length exceeds the max tokens supported by the ML model."
+        chunked_entries: list[Entry] = []
+        for entry in entries:
+            compiled_entry_words = entry.compiled.split()
+            for chunk_index in range(0, len(compiled_entry_words), max_tokens):
+                compiled_entry_words_chunk = compiled_entry_words[chunk_index:chunk_index + max_tokens]
+                compiled_entry_chunk = ' '.join(compiled_entry_words_chunk)
+                entry_chunk = Entry(compiled=compiled_entry_chunk, raw=entry.raw, file=entry.file)
+                chunked_entries.append(entry_chunk)
+        return chunked_entries
+
     def mark_entries_for_update(self, current_entries: list[Entry], previous_entries: list[Entry], key='compiled', logger=None) -> list[tuple[int, Entry]]:
         # Hash all current and previous entries to identify new entries
         start = time.time()
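Note for reviewers: `split_entries_by_max_tokens` splits on whitespace, so `max_tokens` is effectively a word count used as a rough proxy for the model tokenizer's count. Below is a minimal standalone sketch of the behaviour, using a stripped-down, hypothetical stand-in for the repo's `Entry` class (the real class may carry more fields):

```python
# Minimal sketch of the chunking behaviour added above.
# The Entry stand-in is hypothetical; the real class in this repo may differ.
from dataclasses import dataclass


@dataclass
class Entry:
    compiled: str  # text that gets embedded by the ML model
    raw: str       # original entry text, shown to the user
    file: str      # source file the entry came from


def split_entries_by_max_tokens(entries: list[Entry], max_tokens: int = 256) -> list[Entry]:
    "Split entries whose compiled form exceeds max_tokens whitespace-delimited words."
    chunked_entries: list[Entry] = []
    for entry in entries:
        words = entry.compiled.split()
        # Walk the word list in max_tokens-sized windows
        for start in range(0, len(words), max_tokens):
            chunk = ' '.join(words[start:start + max_tokens])
            # Every chunk keeps the original raw text and source file
            chunked_entries.append(Entry(compiled=chunk, raw=entry.raw, file=entry.file))
    return chunked_entries


# A 600-word compiled entry yields 3 chunks (256 + 256 + 88 words)
long_entry = Entry(compiled=' '.join(str(i) for i in range(600)), raw='* Long Entry', file='notes.org')
chunks = split_entries_by_max_tokens([long_entry])
assert len(chunks) == 3
assert all(chunk.raw == long_entry.raw for chunk in chunks)
```

Every chunk keeping the original `raw` text is what makes the deduplication in `text_search.query` below possible.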
diff --git a/src/search_type/text_search.py b/src/search_type/text_search.py
index 8b29c517..5bbbdd64 100644
--- a/src/search_type/text_search.py
+++ b/src/search_type/text_search.py
@@ -150,6 +150,17 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False):
         end = time.time()
         logger.debug(f"Rank Time: {end - start:.3f} seconds on device: {state.device}")
 
+    # Deduplicate hits by raw entry text before showing them to users.
+    # Compiled entries are split by the max tokens supported by the ML model,
+    # which can surface multiple hits for chunks of the same raw entry.
+    start = time.time()
+    seen, original_hits_count = set(), len(hits)
+    hits = [hit for hit in hits
+            if entries[hit['corpus_id']].raw not in seen and not seen.add(entries[hit['corpus_id']].raw)]
+    duplicate_hits = original_hits_count - len(hits)
+    end = time.time()
+    logger.debug(f"Deduplication Time: {end - start:.3f} seconds. Removed {duplicate_hits} duplicates")
+
     return hits, entries
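The deduplication comprehension above leans on `set.add` returning `None`: the `not in seen` clause does the filtering, while `not seen.add(...)` always evaluates truthy and only records the raw text as a side effect. A standalone sketch with made-up hit data (the real hits come from the semantic search step earlier in `query`):

```python
# Sketch of the order-preserving dedup idiom used in text_search.query.
# The entries and hits below are made-up stand-ins for real search results.
class Entry:
    def __init__(self, raw):
        self.raw = raw

# corpus_ids 0 and 1 are chunks of the same raw entry
entries = [Entry('* Entry A'), Entry('* Entry A'), Entry('* Entry B')]
hits = [
    {'corpus_id': 0, 'score': 0.9},
    {'corpus_id': 1, 'score': 0.8},
    {'corpus_id': 2, 'score': 0.7},
]

seen = set()
# `not in seen` filters; `not seen.add(...)` is always True (add returns None)
# and only serves to record the raw text. Since hits stay in ranked order,
# the highest-scoring chunk of each raw entry is the one that survives.
hits = [hit for hit in hits
        if entries[hit['corpus_id']].raw not in seen
        and not seen.add(entries[hit['corpus_id']].raw)]

assert [hit['corpus_id'] for hit in hits] == [0, 2]
```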
diff --git a/tests/conftest.py b/tests/conftest.py
index 103a28e8..ec87f964 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -3,7 +3,6 @@ import pytest
 
 # Internal Packages
 from src.search_type import image_search, text_search
-from src.utils.config import SearchType
 from src.utils.helpers import resolve_absolute_path
 from src.utils.rawconfig import ContentConfig, TextContentConfig, ImageContentConfig, SearchConfig, TextSearchConfig, ImageSearchConfig
 from src.processor.org_mode.org_to_jsonl import OrgToJsonl
diff --git a/tests/test_org_to_jsonl.py b/tests/test_org_to_jsonl.py
index 2dbedcd0..fe64cc67 100644
--- a/tests/test_org_to_jsonl.py
+++ b/tests/test_org_to_jsonl.py
@@ -3,6 +3,7 @@ import json
 
 # Internal Packages
 from src.processor.org_mode.org_to_jsonl import OrgToJsonl
+from src.processor.text_to_jsonl import TextToJsonl
 from src.utils.helpers import is_none_or_empty
 
 
@@ -35,6 +36,31 @@ def test_configure_heading_entry_to_jsonl(tmp_path):
     assert is_none_or_empty(jsonl_data)
 
 
+def test_entry_split_when_exceeds_max_words(tmp_path):
+    "Ensure entries whose compiled form exceeds max_tokens are split."
+    # Arrange
+    entry = '''*** Heading
+    \t\r
+    Body Line 1
+    '''
+    orgfile = create_file(tmp_path, entry)
+
+    # Act
+    # Extract Entries from specified Org files
+    entries, entry_to_file_map = OrgToJsonl.extract_org_entries(org_files=[orgfile])
+
+    # Split each entry from specified Org files by max tokens
+    jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(
+        TextToJsonl.split_entries_by_max_tokens(
+            OrgToJsonl.convert_org_nodes_to_entries(entries, entry_to_file_map),
+            max_tokens=2)
+    )
+    jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
+
+    # Assert
+    assert len(jsonl_data) == 2
+
+
 def test_entry_with_body_to_jsonl(tmp_path):
     "Ensure entries with valid body text are loaded."
     # Arrange
diff --git a/tests/test_text_search.py b/tests/test_text_search.py
index e05831a1..dcacf7fb 100644
--- a/tests/test_text_search.py
+++ b/tests/test_text_search.py
@@ -80,10 +80,43 @@ def test_asymmetric_search(content_config: ContentConfig, search_config: SearchConfig):
     assert "git clone" in search_result
 
 
+# ----------------------------------------------------------------------------------------------------
+def test_entry_chunking_by_max_tokens(content_config: ContentConfig, search_config: SearchConfig):
+    # Arrange
+    initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True)
+
+    assert len(initial_notes_model.entries) == 10
+    assert len(initial_notes_model.corpus_embeddings) == 10
+
+    file_to_add_on_reload = Path(content_config.org.input_filter[0]).parent / "entry_exceeding_max_tokens.org"
+    content_config.org.input_files = [f'{file_to_add_on_reload}']
+
+    # Insert org-mode entry exceeding the max token limit into the new org file
+    max_tokens = 256
+    with open(file_to_add_on_reload, "w") as f:
+        f.write(f"* Entry more than {max_tokens} words\n")
+        for index in range(max_tokens + 1):
+            f.write(f"{index} ")
+
+    # Act
+    # Reload embeddings, entries and notes model after adding the new org-mode file
+    initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False)
+
+    # Assert
+    # Verify the newly added org-mode entry is split by max tokens
+    assert len(initial_notes_model.entries) == 12
+    assert len(initial_notes_model.corpus_embeddings) == 12
+
+    # Cleanup
+    # Delete the test file added for the reload
+    content_config.org.input_files = []
+    file_to_add_on_reload.unlink()
+
+
 # ----------------------------------------------------------------------------------------------------
 def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchConfig):
     # Arrange
-    initial_notes_model= text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False)
+    initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True)
 
     assert len(initial_notes_model.entries) == 10
     assert len(initial_notes_model.corpus_embeddings) == 10
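A quick sanity check on the arithmetic behind `test_entry_chunking_by_max_tokens` (a sketch; the numbers mirror the test above, and the heading's few extra words keep the compiled length well under two full chunks):

```python
# Why the test expects 12 entries after reload: 10 fixture entries plus
# ceil(257 / 256) = 2 chunks from the new oversized entry.
import math

existing_entries = 10                # entries indexed from the fixture org files
max_tokens = 256
words_in_new_entry = max_tokens + 1  # the test writes 257 numbered words

chunks = math.ceil(words_in_new_entry / max_tokens)
assert chunks == 2
assert existing_entries + chunks == 12  # matches the post-reload assertions
```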