Batch entries into smaller groups to process

This commit is contained in:
sabaimran 2024-10-27 20:43:41 -07:00
parent 7e0a692d16
commit a691ce4aa6

View file

@@ -19,6 +19,8 @@ from khoj.processor.embeddings import EmbeddingsModel
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
+BATCH_SIZE = 1000  # Define an appropriate batch size
+
 class Command(BaseCommand):
     help = "Convert all existing Entry objects to use a new default Search model."
@@ -42,17 +44,19 @@ class Command(BaseCommand):
     def handle(self, *args, **options):
         @transaction.atomic
         def regenerate_entries(entry_filter: Q, embeddings_model: EmbeddingsModel, search_model: SearchModelConfig):
-            entries = Entry.objects.filter(entry_filter).all()
+            total_entries = Entry.objects.filter(entry_filter).count()
+            for start in tqdm(range(0, total_entries, BATCH_SIZE)):
+                end = start + BATCH_SIZE
+                entries = Entry.objects.filter(entry_filter)[start:end]
                 compiled_entries = [entry.compiled for entry in entries]
                 updated_entries: List[Entry] = []
                 try:
                     embeddings = embeddings_model.embed_documents(compiled_entries)
                 except Exception as e:
                     logger.error(f"Error embedding documents: {e}")
                     return
-                for i, entry in enumerate(tqdm(entries)):
+                for i, entry in enumerate(entries):
                     entry.embeddings = embeddings[i]
                     entry.search_model_id = search_model.id
                     updated_entries.append(entry)