Flatten nested loops, improve progress reporting in text_to_jsonl indexer

Flatten the nested loops to improve visibility into indexing progress

Reduce spurious logs, report logs at an aggregated level, and update
the log message text to improve indexing progress reporting
Debanjum Singh Solanky 2023-11-04 04:55:51 -07:00
parent 12b5ef6540
commit dc9946fc03
4 changed files with 61 additions and 64 deletions
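
The refactor below gives each indexing stage in TextToEntries its own timer and a single aggregated log line, and rewords the timer messages to end in "in" so the elapsed time reads naturally when appended. A minimal, self-contained sketch of that flattened, stage-wise flow (the timer and index function here are simplified stand-ins, not Khoj's actual implementation):

# Hypothetical sketch: each indexing stage runs once under its own timer and logs a
# single aggregated message, instead of logging per file inside nested loops.
import logging
import time
from contextlib import contextmanager

logger = logging.getLogger(__name__)


@contextmanager
def timer(message: str, log: logging.Logger):
    # Assumed behaviour of khoj.utils.helpers.timer: append the elapsed time to the
    # message, which is why the reworded messages in this commit end with "in".
    start = time.perf_counter()
    yield
    log.debug(f"{message} {time.perf_counter() - start:.3f} seconds")


def index(entries_by_file: dict[str, list[str]]) -> int:
    # Stage 1: collect entries to add across all files, without per-file logging
    to_add: set[str] = set()
    with timer("Identified entries to add to database in", logger):
        for file, entries in entries_by_file.items():
            to_add |= set(entries)

    # Stage 2: add them in one pass and report a single aggregated count
    added: list[str] = []
    with timer("Added entries to database in", logger):
        added.extend(sorted(to_add))  # stand-in for batched bulk_create calls
    logger.debug(f"Added {len(added)} entries to database")
    return len(added)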


@@ -1,11 +1,12 @@
# Standard Packages
from abc import ABC, abstractmethod
import hashlib
from itertools import repeat
import logging
import uuid
from tqdm import tqdm
from typing import Callable, List, Tuple, Set, Any
from khoj.utils.helpers import timer, batcher
from khoj.utils.helpers import is_none_or_empty, timer, batcher
# Internal Packages
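
The two new imports above serve the reworked date-indexing step further down: itertools.repeat pairs each added entry with every date extracted from it, and is_none_or_empty filters out blank dates before they are bulk-created. A small illustration of the pattern (the date list is made up, and is_none_or_empty mimics the helper's assumed behaviour):

# Illustrative only: extract_dates output is faked, and is_none_or_empty below mimics
# the assumed behaviour of the Khoj helper (treat None and empty values as empty).
from itertools import repeat


def is_none_or_empty(value) -> bool:
    return value is None or value == "" or value == []


entry = "Meeting notes from 2023-11-04 and 2023-11-05"
dates = ["2023-11-04", "2023-11-05", ""]  # pretend output of date_filter.extract_dates(entry)

# zip(dates, repeat(entry)) yields one (date, entry) pair per extracted date
pairs = [(date, e) for date, e in zip(dates, repeat(entry)) if not is_none_or_empty(date)]
assert pairs == [("2023-11-04", entry), ("2023-11-05", entry)]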
@@ -83,92 +84,88 @@ class TextToEntries(ABC):
user: KhojUser = None,
regenerate: bool = False,
):
with timer("Construct current entry hashes", logger):
with timer("Constructed current entry hashes in", logger):
hashes_by_file = dict[str, set[str]]()
current_entry_hashes = list(map(TextToEntries.hash_func(key), current_entries))
hash_to_current_entries = dict(zip(current_entry_hashes, current_entries))
for entry in tqdm(current_entries, desc="Hashing Entries"):
hashes_by_file.setdefault(entry.file, set()).add(TextToEntries.hash_func(key)(entry))
num_deleted_embeddings = 0
with timer("Preparing dataset for regeneration", logger):
if regenerate:
logger.debug(f"Deleting all embeddings for file type {file_type}")
num_deleted_embeddings = EntryAdapters.delete_all_entries(user, file_type)
num_deleted_entries = 0
if regenerate:
with timer("Prepared dataset for regeneration in", logger):
logger.debug(f"Deleting all entries for file type {file_type}")
num_deleted_entries = EntryAdapters.delete_all_entries(user, file_type)
num_new_embeddings = 0
with timer("Identify hashes for adding new entries", logger):
for file in tqdm(hashes_by_file, desc="Processing file with hashed values"):
hashes_to_process = set()
with timer("Identified entries to add to database in", logger):
for file in tqdm(hashes_by_file, desc="Identify new entries"):
hashes_for_file = hashes_by_file[file]
hashes_to_process = set()
existing_entries = DbEntry.objects.filter(
user=user, hashed_value__in=hashes_for_file, file_type=file_type
)
existing_entry_hashes = set([entry.hashed_value for entry in existing_entries])
hashes_to_process = hashes_for_file - existing_entry_hashes
hashes_to_process |= hashes_for_file - existing_entry_hashes
entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
data_to_embed = [getattr(entry, key) for entry in entries_to_process]
embeddings = self.embeddings_model.embed_documents(data_to_embed)
embeddings = []
with timer("Generated embeddings for entries to add to database in", logger):
entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
data_to_embed = [getattr(entry, key) for entry in entries_to_process]
embeddings += self.embeddings_model.embed_documents(data_to_embed)
with timer("Update the database with new vector embeddings", logger):
num_items = len(hashes_to_process)
assert num_items == len(embeddings)
batch_size = min(200, num_items)
entry_batches = zip(hashes_to_process, embeddings)
added_entries: list[DbEntry] = []
with timer("Added entries to database in", logger):
num_items = len(hashes_to_process)
assert num_items == len(embeddings)
batch_size = min(200, num_items)
entry_batches = zip(hashes_to_process, embeddings)
for entry_batch in tqdm(
batcher(entry_batches, batch_size), desc="Processing embeddings in batches"
):
batch_embeddings_to_create = []
for entry_hash, new_entry in entry_batch:
entry = hash_to_current_entries[entry_hash]
batch_embeddings_to_create.append(
DbEntry(
user=user,
embeddings=new_entry,
raw=entry.raw,
compiled=entry.compiled,
heading=entry.heading[:1000], # Truncate to max chars of field allowed
file_path=entry.file,
file_type=file_type,
hashed_value=entry_hash,
corpus_id=entry.corpus_id,
)
)
new_entries = DbEntry.objects.bulk_create(batch_embeddings_to_create)
logger.debug(f"Created {len(new_entries)} new embeddings")
num_new_embeddings += len(new_entries)
for entry_batch in tqdm(batcher(entry_batches, batch_size), desc="Add entries to database"):
batch_embeddings_to_create = []
for entry_hash, new_entry in entry_batch:
entry = hash_to_current_entries[entry_hash]
batch_embeddings_to_create.append(
DbEntry(
user=user,
embeddings=new_entry,
raw=entry.raw,
compiled=entry.compiled,
heading=entry.heading[:1000], # Truncate to max chars of field allowed
file_path=entry.file,
file_type=file_type,
hashed_value=entry_hash,
corpus_id=entry.corpus_id,
)
)
added_entries += DbEntry.objects.bulk_create(batch_embeddings_to_create)
logger.debug(f"Added {len(added_entries)} {file_type} entries to database")
dates_to_create = []
with timer("Create new date associations for new embeddings", logger):
for new_entry in new_entries:
dates = self.date_filter.extract_dates(new_entry.raw)
for date in dates:
dates_to_create.append(
EntryDates(
date=date,
entry=new_entry,
)
)
new_dates = EntryDates.objects.bulk_create(dates_to_create)
if len(new_dates) > 0:
logger.debug(f"Created {len(new_dates)} new date entries")
new_dates = []
with timer("Indexed dates from added entries in", logger):
for added_entry in added_entries:
dates_in_entries = zip(self.date_filter.extract_dates(added_entry.raw), repeat(added_entry))
dates_to_create = [
EntryDates(date=date, entry=added_entry)
for date, added_entry in dates_in_entries
if not is_none_or_empty(date)
]
new_dates += EntryDates.objects.bulk_create(dates_to_create)
logger.debug(f"Indexed {len(new_dates)} dates from added {file_type} entries")
with timer("Identify hashes for removed entries", logger):
with timer("Deleted entries identified by server from database in", logger):
for file in hashes_by_file:
existing_entry_hashes = EntryAdapters.get_existing_entry_hashes_by_file(user, file)
to_delete_entry_hashes = set(existing_entry_hashes) - hashes_by_file[file]
num_deleted_embeddings += len(to_delete_entry_hashes)
num_deleted_entries += len(to_delete_entry_hashes)
EntryAdapters.delete_entry_by_hash(user, hashed_values=list(to_delete_entry_hashes))
with timer("Identify hashes for deleting entries", logger):
with timer("Deleted entries requested by clients from database in", logger):
if deletion_filenames is not None:
for file_path in deletion_filenames:
deleted_count = EntryAdapters.delete_entry_by_file(user, file_path)
num_deleted_embeddings += deleted_count
num_deleted_entries += deleted_count
return num_new_embeddings, num_deleted_embeddings
return len(added_entries), num_deleted_entries
@staticmethod
def mark_entries_for_update(
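
The "Added entries to database in" stage above feeds zip(hashes_to_process, embeddings) through the batcher helper so DbEntry.objects.bulk_create is called on fixed-size chunks. A minimal sketch of that chunked-insert pattern (this batcher is a stand-in; the real one lives in khoj.utils.helpers and may differ):

# Stand-in batcher: lazily chunk an iterable into lists of at most batch_size items.
from itertools import islice
from typing import Iterable, Iterator


def batcher(iterable: Iterable, batch_size: int) -> Iterator[list]:
    it = iter(iterable)
    while batch := list(islice(it, batch_size)):
        yield batch


hashes = [f"hash-{i}" for i in range(5)]
embeddings = [[0.1, 0.2]] * 5

added: list[str] = []
for batch in batcher(zip(hashes, embeddings), batch_size=2):
    # stand-in for building DbEntry objects and calling DbEntry.objects.bulk_create(batch)
    added.extend(hashed_value for hashed_value, _embedding in batch)

assert len(added) == len(hashes)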


@@ -321,7 +321,6 @@ def load_content(
content_index: Optional[ContentIndex],
search_models: SearchModels,
):
logger.info(f"Loading content from existing embeddings...")
if content_config is None:
logger.warning("🚨 No Content configuration available.")
return None


@@ -207,7 +207,7 @@ def setup(
file_names = [file_name for file_name in files]
logger.info(
f"Created {num_new_embeddings} new embeddings. Deleted {num_deleted_embeddings} embeddings for user {user} and files {file_names}"
f"Deleted {num_deleted_embeddings} entries. Created {num_new_embeddings} new entries for user {user} from files {file_names}"
)


@@ -52,7 +52,8 @@ def cli(args=None):
args, remaining_args = parser.parse_known_args(args)
logger.debug(f"Ignoring unknown commandline args: {remaining_args}")
if len(remaining_args) > 0:
logger.info(f"⚠️ Ignoring unknown commandline args: {remaining_args}")
# Set default values for arguments
args.chat_on_gpu = not args.disable_chat_on_gpu
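
The CLI change above turns an unconditional debug log into a user-visible warning that only fires when unknown arguments are actually present. A condensed sketch of that behaviour (the argument name here is illustrative):

# Condensed sketch: only surface unknown commandline args when there are any.
import argparse
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

parser = argparse.ArgumentParser()
parser.add_argument("--disable-chat-on-gpu", action="store_true")

args, remaining_args = parser.parse_known_args(["--unknown-flag"])
if len(remaining_args) > 0:
    logger.info(f"⚠️ Ignoring unknown commandline args: {remaining_args}")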