mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 15:38:55 +01:00
Support incremental update of org-mode entries and embeddings
- What
  - Hash the entries and compare the hashes to find new or updated entries
  - Reuse the embeddings already encoded for existing entries
  - Only encode embeddings for updated or new entries
  - Merge the existing and new entries and embeddings to get the updated entries and embeddings
- Why
  - Most note text entries are expected to remain unchanged over time, so reusing their previously encoded embeddings should significantly speed up embeddings updates
  - Previously we were regenerating embeddings for all entries, even if they already existed in previous runs
This commit is contained in:
parent
762607fc9f
commit
2f7a6af56a
5 changed files with 80 additions and 30 deletions
|
@ -18,7 +18,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
# Define Functions
|
||||
def beancount_to_jsonl(beancount_files, beancount_file_filter, output_file):
|
||||
def beancount_to_jsonl(beancount_files, beancount_file_filter, output_file, previous_entries=None):
|
||||
# Input Validation
|
||||
if is_none_or_empty(beancount_files) and is_none_or_empty(beancount_file_filter):
|
||||
print("At least one of beancount-files or beancount-file-filter is required to be specified")
|
||||
|
@ -39,7 +39,7 @@ def beancount_to_jsonl(beancount_files, beancount_file_filter, output_file):
|
|||
elif output_file.suffix == ".jsonl":
|
||||
dump_jsonl(jsonl_data, output_file)
|
||||
|
||||
return entries
|
||||
return list(enumerate(entries))
|
||||
|
||||
|
||||
def get_beancount_files(beancount_files=None, beancount_file_filter=None):
|
||||
|
|
|
@ -39,7 +39,7 @@ def markdown_to_jsonl(markdown_files, markdown_file_filter, output_file):
|
|||
elif output_file.suffix == ".jsonl":
|
||||
dump_jsonl(jsonl_data, output_file)
|
||||
|
||||
return entries
|
||||
return list(enumerate(entries))
|
||||
|
||||
|
||||
def get_markdown_files(markdown_files=None, markdown_file_filter=None):
|
||||
|
|
|
@ -7,6 +7,7 @@ import argparse
|
|||
import pathlib
|
||||
import glob
|
||||
import logging
|
||||
import hashlib
|
||||
|
||||
# Internal Packages
|
||||
from src.processor.org_mode import orgnode
|
||||
|
@ -19,7 +20,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
# Define Functions
|
||||
def org_to_jsonl(org_files, org_file_filter, output_file):
|
||||
def org_to_jsonl(org_files, org_file_filter, output_file, previous_entries=None):
|
||||
# Input Validation
|
||||
if is_none_or_empty(org_files) and is_none_or_empty(org_file_filter):
|
||||
print("At least one of org-files or org-file-filter is required to be specified")
|
||||
|
@ -29,10 +30,41 @@ def org_to_jsonl(org_files, org_file_filter, output_file):
|
|||
org_files = get_org_files(org_files, org_file_filter)
|
||||
|
||||
# Extract Entries from specified Org files
|
||||
entries, file_to_entries = extract_org_entries(org_files)
|
||||
entry_nodes, file_to_entries = extract_org_entries(org_files)
|
||||
current_entries = convert_org_nodes_to_entries(entry_nodes, file_to_entries)
|
||||
|
||||
# Identify, mark and merge any new entries with previous entries
|
||||
if not previous_entries:
|
||||
entries_with_ids = list(enumerate(current_entries))
|
||||
else:
|
||||
# Hash all current and previous entries to identify new entries
|
||||
current_entry_hashes = list(map(lambda e: hashlib.md5(bytes(json.dumps(e), encoding='utf-8')).hexdigest(), current_entries))
|
||||
previous_entry_hashes = list(map(lambda e: hashlib.md5(bytes(json.dumps(e), encoding='utf-8')).hexdigest(), previous_entries))
|
||||
|
||||
hash_to_current_entries = dict(zip(current_entry_hashes, current_entries))
|
||||
hash_to_previous_entries = dict(zip(previous_entry_hashes, previous_entries))
|
||||
|
||||
# All entries that did not exist in the previous set are to be added
|
||||
new_entry_hashes = set(current_entry_hashes) - set(previous_entry_hashes)
|
||||
# All entries that exist in both current and previous sets are kept
|
||||
existing_entry_hashes = set(current_entry_hashes) & set(previous_entry_hashes)
|
||||
|
||||
# Mark new entries with no ids for later embeddings generation
|
||||
new_entries = [
|
||||
(None, hash_to_current_entries[entry_hash])
|
||||
for entry_hash in new_entry_hashes
|
||||
]
|
||||
# Set id of existing entries to their previous ids to reuse their existing encoded embeddings
|
||||
existing_entries = [
|
||||
(previous_entry_hashes.index(entry_hash), hash_to_previous_entries[entry_hash])
|
||||
for entry_hash in existing_entry_hashes
|
||||
]
|
||||
existing_entries_sorted = sorted(existing_entries, key=lambda e: e[0])
|
||||
entries_with_ids = existing_entries_sorted + new_entries
|
||||
|
||||
# Process Each Entry from All Notes Files
|
||||
jsonl_data = convert_org_entries_to_jsonl(entries, file_to_entries)
|
||||
entries = map(lambda entry: entry[1], entries_with_ids)
|
||||
jsonl_data = convert_org_entries_to_jsonl(entries)
|
||||
|
||||
# Compress JSONL formatted Data
|
||||
if output_file.suffix == ".gz":
|
||||
|
@ -40,7 +72,7 @@ def org_to_jsonl(org_files, org_file_filter, output_file):
|
|||
elif output_file.suffix == ".jsonl":
|
||||
dump_jsonl(jsonl_data, output_file)
|
||||
|
||||
return entries
|
||||
return entries_with_ids
|
||||
|
||||
|
||||
def get_org_files(org_files=None, org_file_filter=None):
|
||||
|
@ -70,16 +102,16 @@ def extract_org_entries(org_files):
|
|||
entry_to_file_map = []
|
||||
for org_file in org_files:
|
||||
org_file_entries = orgnode.makelist(str(org_file))
|
||||
entry_to_file_map += [org_file]*len(org_file_entries)
|
||||
entry_to_file_map += zip(org_file_entries, [org_file]*len(org_file_entries))
|
||||
entries.extend(org_file_entries)
|
||||
|
||||
return entries, entry_to_file_map
|
||||
return entries, dict(entry_to_file_map)
|
||||
|
||||
|
||||
def convert_org_entries_to_jsonl(entries, entry_to_file_map) -> str:
|
||||
"Convert each Org-Mode entries to JSON and collate as JSONL"
|
||||
jsonl = ''
|
||||
for entry_id, entry in enumerate(entries):
|
||||
def convert_org_nodes_to_entries(entries: list[orgnode.Orgnode], entry_to_file_map) -> list[dict]:
|
||||
"Convert Org-Mode entries into list of dictionary"
|
||||
entry_maps = []
|
||||
for entry in entries:
|
||||
entry_dict = dict()
|
||||
|
||||
# Ignore title notes i.e notes with just headings and empty body
|
||||
|
@ -113,14 +145,17 @@ def convert_org_entries_to_jsonl(entries, entry_to_file_map) -> str:
|
|||
|
||||
if entry_dict:
|
||||
entry_dict["raw"] = f'{entry}'
|
||||
entry_dict["file"] = f'{entry_to_file_map[entry_id]}'
|
||||
entry_dict["file"] = f'{entry_to_file_map[entry]}'
|
||||
|
||||
# Convert Dictionary to JSON and Append to JSONL string
|
||||
jsonl += f'{json.dumps(entry_dict, ensure_ascii=False)}\n'
|
||||
entry_maps.append(entry_dict)
|
||||
|
||||
logger.info(f"Converted {len(entries)} to jsonl format")
|
||||
return entry_maps
|
||||
|
||||
return jsonl
|
||||
|
||||
def convert_org_entries_to_jsonl(entries) -> str:
|
||||
"Convert each Org-Mode entry to JSON and collate as JSONL"
|
||||
return ''.join([f'{json.dumps(entry_dict, ensure_ascii=False)}\n' for entry_dict in entries])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -55,15 +55,28 @@ def extract_entries(jsonl_file):
|
|||
return load_jsonl(jsonl_file)
|
||||
|
||||
|
||||
def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False):
|
||||
def compute_embeddings(entries_with_ids, bi_encoder, embeddings_file, regenerate=False):
|
||||
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
|
||||
# Load pre-computed embeddings from file if exists
|
||||
new_entries = []
|
||||
# Load pre-computed embeddings from file if exists and update them if required
|
||||
if embeddings_file.exists() and not regenerate:
|
||||
corpus_embeddings = torch.load(get_absolute_path(embeddings_file), map_location=state.device)
|
||||
logger.info(f"Loaded embeddings from {embeddings_file}")
|
||||
|
||||
else: # Else compute the corpus_embeddings from scratch, which can take a while
|
||||
corpus_embeddings = bi_encoder.encode([entry['compiled'] for entry in entries], convert_to_tensor=True, device=state.device, show_progress_bar=True)
|
||||
# Encode any new entries in the corpus and update corpus embeddings
|
||||
new_entries = [entry['compiled'] for id, entry in entries_with_ids if id is None]
|
||||
if new_entries:
|
||||
new_embeddings = bi_encoder.encode(new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True)
|
||||
existing_entry_ids = [id for id, _ in entries_with_ids if id is not None]
|
||||
existing_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(existing_entry_ids)) if existing_entry_ids else torch.Tensor()
|
||||
corpus_embeddings = torch.cat([existing_embeddings, new_embeddings], dim=0)
|
||||
# Else compute the corpus embeddings from scratch
|
||||
else:
|
||||
new_entries = [entry['compiled'] for _, entry in entries_with_ids]
|
||||
corpus_embeddings = bi_encoder.encode(new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True)
|
||||
|
||||
# Save regenerated or updated embeddings to file
|
||||
if new_entries:
|
||||
corpus_embeddings = util.normalize_embeddings(corpus_embeddings)
|
||||
torch.save(corpus_embeddings, embeddings_file)
|
||||
logger.info(f"Computed embeddings and saved them to {embeddings_file}")
|
||||
|
@ -169,16 +182,16 @@ def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchCon
|
|||
|
||||
# Map notes in text files to (compressed) JSONL formatted file
|
||||
config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl)
|
||||
if not config.compressed_jsonl.exists() or regenerate:
|
||||
text_to_jsonl(config.input_files, config.input_filter, config.compressed_jsonl)
|
||||
previous_entries = extract_entries(config.compressed_jsonl) if config.compressed_jsonl.exists() else None
|
||||
entries_with_indices = text_to_jsonl(config.input_files, config.input_filter, config.compressed_jsonl, previous_entries)
|
||||
|
||||
# Extract Entries
|
||||
# Extract Updated Entries
|
||||
entries = extract_entries(config.compressed_jsonl)
|
||||
top_k = min(len(entries), top_k) # top_k hits can't be more than the total entries in corpus
|
||||
|
||||
# Compute or Load Embeddings
|
||||
config.embeddings_file = resolve_absolute_path(config.embeddings_file)
|
||||
corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate)
|
||||
corpus_embeddings = compute_embeddings(entries_with_indices, bi_encoder, config.embeddings_file, regenerate=regenerate)
|
||||
|
||||
for filter in filters:
|
||||
filter.load(entries, regenerate=regenerate)
|
||||
|
|
|
@ -3,7 +3,7 @@ import json
|
|||
from posixpath import split
|
||||
|
||||
# Internal Packages
|
||||
from src.processor.org_mode.org_to_jsonl import convert_org_entries_to_jsonl, extract_org_entries
|
||||
from src.processor.org_mode.org_to_jsonl import convert_org_entries_to_jsonl, convert_org_nodes_to_entries, extract_org_entries
|
||||
from src.utils.helpers import is_none_or_empty
|
||||
|
||||
|
||||
|
@ -21,10 +21,11 @@ def test_entry_with_empty_body_line_to_jsonl(tmp_path):
|
|||
|
||||
# Act
|
||||
# Extract Entries from specified Org files
|
||||
entries, entry_to_file_map = extract_org_entries(org_files=[orgfile])
|
||||
entry_nodes, file_to_entries = extract_org_entries(org_files=[orgfile])
|
||||
|
||||
# Process Each Entry from All Notes Files
|
||||
jsonl_data = convert_org_entries_to_jsonl(entries, entry_to_file_map)
|
||||
entries = convert_org_nodes_to_entries(entry_nodes, file_to_entries)
|
||||
jsonl_data = convert_org_entries_to_jsonl(entries)
|
||||
|
||||
# Assert
|
||||
assert is_none_or_empty(jsonl_data)
|
||||
|
@ -43,10 +44,11 @@ def test_entry_with_body_to_jsonl(tmp_path):
|
|||
|
||||
# Act
|
||||
# Extract Entries from specified Org files
|
||||
entries, entry_to_file_map = extract_org_entries(org_files=[orgfile])
|
||||
entry_nodes, file_to_entries = extract_org_entries(org_files=[orgfile])
|
||||
|
||||
# Process Each Entry from All Notes Files
|
||||
jsonl_string = convert_org_entries_to_jsonl(entries, entry_to_file_map)
|
||||
entries = convert_org_nodes_to_entries(entry_nodes, file_to_entries)
|
||||
jsonl_string = convert_org_entries_to_jsonl(entries)
|
||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||
|
||||
# Assert
|
||||
|
|
Loading…
Reference in a new issue