mirror of
https://github.com/khoj-ai/khoj.git
synced 2025-02-17 08:04:21 +00:00
Flatten nested loops, improve progress reporting in text_to_jsonl indexer
Flatten the nested loops to improve visibilty into indexing progress Reduce spurious logs, report the logs at aggregated level and update the logging description text to improve indexing progress reporting
This commit is contained in:
parent
12b5ef6540
commit
dc9946fc03
4 changed files with 61 additions and 64 deletions
|
@ -1,11 +1,12 @@
|
|||
# Standard Packages
|
||||
from abc import ABC, abstractmethod
|
||||
import hashlib
|
||||
from itertools import repeat
|
||||
import logging
|
||||
import uuid
|
||||
from tqdm import tqdm
|
||||
from typing import Callable, List, Tuple, Set, Any
|
||||
from khoj.utils.helpers import timer, batcher
|
||||
from khoj.utils.helpers import is_none_or_empty, timer, batcher
|
||||
|
||||
|
||||
# Internal Packages
|
||||
|
@ -83,92 +84,88 @@ class TextToEntries(ABC):
|
|||
user: KhojUser = None,
|
||||
regenerate: bool = False,
|
||||
):
|
||||
with timer("Construct current entry hashes", logger):
|
||||
with timer("Constructed current entry hashes in", logger):
|
||||
hashes_by_file = dict[str, set[str]]()
|
||||
current_entry_hashes = list(map(TextToEntries.hash_func(key), current_entries))
|
||||
hash_to_current_entries = dict(zip(current_entry_hashes, current_entries))
|
||||
for entry in tqdm(current_entries, desc="Hashing Entries"):
|
||||
hashes_by_file.setdefault(entry.file, set()).add(TextToEntries.hash_func(key)(entry))
|
||||
|
||||
num_deleted_embeddings = 0
|
||||
with timer("Preparing dataset for regeneration", logger):
|
||||
if regenerate:
|
||||
logger.debug(f"Deleting all embeddings for file type {file_type}")
|
||||
num_deleted_embeddings = EntryAdapters.delete_all_entries(user, file_type)
|
||||
num_deleted_entries = 0
|
||||
if regenerate:
|
||||
with timer("Prepared dataset for regeneration in", logger):
|
||||
logger.debug(f"Deleting all entries for file type {file_type}")
|
||||
num_deleted_entries = EntryAdapters.delete_all_entries(user, file_type)
|
||||
|
||||
num_new_embeddings = 0
|
||||
with timer("Identify hashes for adding new entries", logger):
|
||||
for file in tqdm(hashes_by_file, desc="Processing file with hashed values"):
|
||||
hashes_to_process = set()
|
||||
with timer("Identified entries to add to database in", logger):
|
||||
for file in tqdm(hashes_by_file, desc="Identify new entries"):
|
||||
hashes_for_file = hashes_by_file[file]
|
||||
hashes_to_process = set()
|
||||
existing_entries = DbEntry.objects.filter(
|
||||
user=user, hashed_value__in=hashes_for_file, file_type=file_type
|
||||
)
|
||||
existing_entry_hashes = set([entry.hashed_value for entry in existing_entries])
|
||||
hashes_to_process = hashes_for_file - existing_entry_hashes
|
||||
hashes_to_process |= hashes_for_file - existing_entry_hashes
|
||||
|
||||
entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
|
||||
data_to_embed = [getattr(entry, key) for entry in entries_to_process]
|
||||
embeddings = self.embeddings_model.embed_documents(data_to_embed)
|
||||
embeddings = []
|
||||
with timer("Generated embeddings for entries to add to database in", logger):
|
||||
entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
|
||||
data_to_embed = [getattr(entry, key) for entry in entries_to_process]
|
||||
embeddings += self.embeddings_model.embed_documents(data_to_embed)
|
||||
|
||||
with timer("Update the database with new vector embeddings", logger):
|
||||
num_items = len(hashes_to_process)
|
||||
assert num_items == len(embeddings)
|
||||
batch_size = min(200, num_items)
|
||||
entry_batches = zip(hashes_to_process, embeddings)
|
||||
added_entries: list[DbEntry] = []
|
||||
with timer("Added entries to database in", logger):
|
||||
num_items = len(hashes_to_process)
|
||||
assert num_items == len(embeddings)
|
||||
batch_size = min(200, num_items)
|
||||
entry_batches = zip(hashes_to_process, embeddings)
|
||||
|
||||
for entry_batch in tqdm(
|
||||
batcher(entry_batches, batch_size), desc="Processing embeddings in batches"
|
||||
):
|
||||
batch_embeddings_to_create = []
|
||||
for entry_hash, new_entry in entry_batch:
|
||||
entry = hash_to_current_entries[entry_hash]
|
||||
batch_embeddings_to_create.append(
|
||||
DbEntry(
|
||||
user=user,
|
||||
embeddings=new_entry,
|
||||
raw=entry.raw,
|
||||
compiled=entry.compiled,
|
||||
heading=entry.heading[:1000], # Truncate to max chars of field allowed
|
||||
file_path=entry.file,
|
||||
file_type=file_type,
|
||||
hashed_value=entry_hash,
|
||||
corpus_id=entry.corpus_id,
|
||||
)
|
||||
)
|
||||
new_entries = DbEntry.objects.bulk_create(batch_embeddings_to_create)
|
||||
logger.debug(f"Created {len(new_entries)} new embeddings")
|
||||
num_new_embeddings += len(new_entries)
|
||||
for entry_batch in tqdm(batcher(entry_batches, batch_size), desc="Add entries to database"):
|
||||
batch_embeddings_to_create = []
|
||||
for entry_hash, new_entry in entry_batch:
|
||||
entry = hash_to_current_entries[entry_hash]
|
||||
batch_embeddings_to_create.append(
|
||||
DbEntry(
|
||||
user=user,
|
||||
embeddings=new_entry,
|
||||
raw=entry.raw,
|
||||
compiled=entry.compiled,
|
||||
heading=entry.heading[:1000], # Truncate to max chars of field allowed
|
||||
file_path=entry.file,
|
||||
file_type=file_type,
|
||||
hashed_value=entry_hash,
|
||||
corpus_id=entry.corpus_id,
|
||||
)
|
||||
)
|
||||
added_entries += DbEntry.objects.bulk_create(batch_embeddings_to_create)
|
||||
logger.debug(f"Added {len(added_entries)} {file_type} entries to database")
|
||||
|
||||
dates_to_create = []
|
||||
with timer("Create new date associations for new embeddings", logger):
|
||||
for new_entry in new_entries:
|
||||
dates = self.date_filter.extract_dates(new_entry.raw)
|
||||
for date in dates:
|
||||
dates_to_create.append(
|
||||
EntryDates(
|
||||
date=date,
|
||||
entry=new_entry,
|
||||
)
|
||||
)
|
||||
new_dates = EntryDates.objects.bulk_create(dates_to_create)
|
||||
if len(new_dates) > 0:
|
||||
logger.debug(f"Created {len(new_dates)} new date entries")
|
||||
new_dates = []
|
||||
with timer("Indexed dates from added entries in", logger):
|
||||
for added_entry in added_entries:
|
||||
dates_in_entries = zip(self.date_filter.extract_dates(added_entry.raw), repeat(added_entry))
|
||||
dates_to_create = [
|
||||
EntryDates(date=date, entry=added_entry)
|
||||
for date, added_entry in dates_in_entries
|
||||
if not is_none_or_empty(date)
|
||||
]
|
||||
new_dates += EntryDates.objects.bulk_create(dates_to_create)
|
||||
logger.debug(f"Indexed {len(new_dates)} dates from added {file_type} entries")
|
||||
|
||||
with timer("Identify hashes for removed entries", logger):
|
||||
with timer("Deleted entries identified by server from database in", logger):
|
||||
for file in hashes_by_file:
|
||||
existing_entry_hashes = EntryAdapters.get_existing_entry_hashes_by_file(user, file)
|
||||
to_delete_entry_hashes = set(existing_entry_hashes) - hashes_by_file[file]
|
||||
num_deleted_embeddings += len(to_delete_entry_hashes)
|
||||
num_deleted_entries += len(to_delete_entry_hashes)
|
||||
EntryAdapters.delete_entry_by_hash(user, hashed_values=list(to_delete_entry_hashes))
|
||||
|
||||
with timer("Identify hashes for deleting entries", logger):
|
||||
with timer("Deleted entries requested by clients from database in", logger):
|
||||
if deletion_filenames is not None:
|
||||
for file_path in deletion_filenames:
|
||||
deleted_count = EntryAdapters.delete_entry_by_file(user, file_path)
|
||||
num_deleted_embeddings += deleted_count
|
||||
num_deleted_entries += deleted_count
|
||||
|
||||
return num_new_embeddings, num_deleted_embeddings
|
||||
return len(added_entries), num_deleted_entries
|
||||
|
||||
@staticmethod
|
||||
def mark_entries_for_update(
|
||||
|
|
|
@ -321,7 +321,6 @@ def load_content(
|
|||
content_index: Optional[ContentIndex],
|
||||
search_models: SearchModels,
|
||||
):
|
||||
logger.info(f"Loading content from existing embeddings...")
|
||||
if content_config is None:
|
||||
logger.warning("🚨 No Content configuration available.")
|
||||
return None
|
||||
|
|
|
@ -207,7 +207,7 @@ def setup(
|
|||
file_names = [file_name for file_name in files]
|
||||
|
||||
logger.info(
|
||||
f"Created {num_new_embeddings} new embeddings. Deleted {num_deleted_embeddings} embeddings for user {user} and files {file_names}"
|
||||
f"Deleted {num_deleted_embeddings} entries. Created {num_new_embeddings} new entries for user {user} from files {file_names}"
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -52,7 +52,8 @@ def cli(args=None):
|
|||
|
||||
args, remaining_args = parser.parse_known_args(args)
|
||||
|
||||
logger.debug(f"Ignoring unknown commandline args: {remaining_args}")
|
||||
if len(remaining_args) > 0:
|
||||
logger.info(f"⚠️ Ignoring unknown commandline args: {remaining_args}")
|
||||
|
||||
# Set default values for arguments
|
||||
args.chat_on_gpu = not args.disable_chat_on_gpu
|
||||
|
|
Loading…
Add table
Reference in a new issue