Batch entries into smaller groups to process

sabaimran 2024-10-27 20:43:41 -07:00
parent 7e0a692d16
commit a691ce4aa6


@@ -19,6 +19,8 @@ from khoj.processor.embeddings import EmbeddingsModel
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
+BATCH_SIZE = 1000  # Define an appropriate batch size
+
 class Command(BaseCommand):
     help = "Convert all existing Entry objects to use a new default Search model."
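As a quick aside, a sketch of the chunking arithmetic this constant drives; the entry count below is hypothetical, purely for illustration:

    import math

    BATCH_SIZE = 1000  # value from the diff above

    total_entries = 4321  # hypothetical count, for illustration only
    starts = range(0, total_entries, BATCH_SIZE)  # start offsets: 0, 1000, 2000, 3000, 4000
    assert len(starts) == math.ceil(total_entries / BATCH_SIZE)  # 5 batches; the last holds 321 entries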
@@ -42,22 +44,24 @@ class Command(BaseCommand):
     def handle(self, *args, **options):
         @transaction.atomic
         def regenerate_entries(entry_filter: Q, embeddings_model: EmbeddingsModel, search_model: SearchModelConfig):
-            entries = Entry.objects.filter(entry_filter).all()
-            compiled_entries = [entry.compiled for entry in entries]
-            updated_entries: List[Entry] = []
-            try:
-                embeddings = embeddings_model.embed_documents(compiled_entries)
-            except Exception as e:
-                logger.error(f"Error embedding documents: {e}")
-                return
+            total_entries = Entry.objects.filter(entry_filter).count()
+            for start in tqdm(range(0, total_entries, BATCH_SIZE)):
+                end = start + BATCH_SIZE
+                entries = Entry.objects.filter(entry_filter)[start:end]
+                compiled_entries = [entry.compiled for entry in entries]
+                updated_entries: List[Entry] = []
+                try:
+                    embeddings = embeddings_model.embed_documents(compiled_entries)
+                except Exception as e:
+                    logger.error(f"Error embedding documents: {e}")
+                    return
 
-            for i, entry in enumerate(tqdm(entries)):
-                entry.embeddings = embeddings[i]
-                entry.search_model_id = search_model.id
-                updated_entries.append(entry)
+                for i, entry in enumerate(entries):
+                    entry.embeddings = embeddings[i]
+                    entry.search_model_id = search_model.id
+                    updated_entries.append(entry)
 
-            Entry.objects.bulk_update(updated_entries, ["embeddings", "search_model_id", "file_path"])
+                Entry.objects.bulk_update(updated_entries, ["embeddings", "search_model_id", "file_path"])
 
         search_model_config_id = options.get("search_model_id")
         apply = options.get("apply")
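
For context, a minimal standalone sketch of the same offset-pagination pattern outside the Khoj codebase. The names here are hypothetical stand-ins: `model` plays the role of Entry and is assumed to have `text` and `vector` fields, and `embed_documents` stands in for EmbeddingsModel.embed_documents:

    from typing import Callable, List

    from django.db import transaction
    from django.db.models import Q


    @transaction.atomic
    def reembed_in_batches(model, item_filter: Q, embed_documents: Callable[[List[str]], List], batch_size: int = 1000) -> None:
        # `model` is a hypothetical Django model with `text` and `vector` fields.
        # Count once, then walk the queryset in OFFSET/LIMIT slices so only one
        # batch of rows (and its embeddings) is held in memory at a time.
        total = model.objects.filter(item_filter).count()
        for start in range(0, total, batch_size):
            # An explicit order keeps offset pagination stable across queries.
            batch = list(model.objects.filter(item_filter).order_by("id")[start : start + batch_size])
            vectors = embed_documents([item.text for item in batch])
            for item, vector in zip(batch, vectors):
                item.vector = vector
            # bulk_update issues one UPDATE per batch instead of one per row.
            model.objects.bulk_update(batch, ["vector"])

One design note on this pattern: queryset slicing without an explicit order is not guaranteed stable between queries, which is why the sketch adds .order_by("id"). Offset pagination also assumes the filter's result set does not shift while the loop runs; if entry_filter matched on a field the loop rewrites (such as search_model_id), updating earlier batches would shift later offsets and skip rows.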