Batch entries into smaller groups to process

This commit is contained in:
sabaimran 2024-10-27 20:43:41 -07:00
parent 7e0a692d16
commit a691ce4aa6

View file

@ -19,6 +19,8 @@ from khoj.processor.embeddings import EmbeddingsModel
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
BATCH_SIZE = 1000 # Define an appropriate batch size
class Command(BaseCommand): class Command(BaseCommand):
help = "Convert all existing Entry objects to use a new default Search model." help = "Convert all existing Entry objects to use a new default Search model."
@ -42,22 +44,24 @@ class Command(BaseCommand):
def handle(self, *args, **options): def handle(self, *args, **options):
@transaction.atomic @transaction.atomic
def regenerate_entries(entry_filter: Q, embeddings_model: EmbeddingsModel, search_model: SearchModelConfig): def regenerate_entries(entry_filter: Q, embeddings_model: EmbeddingsModel, search_model: SearchModelConfig):
entries = Entry.objects.filter(entry_filter).all() total_entries = Entry.objects.filter(entry_filter).count()
compiled_entries = [entry.compiled for entry in entries] for start in tqdm(range(0, total_entries, BATCH_SIZE)):
updated_entries: List[Entry] = [] end = start + BATCH_SIZE
try: entries = Entry.objects.filter(entry_filter)[start:end]
embeddings = embeddings_model.embed_documents(compiled_entries) compiled_entries = [entry.compiled for entry in entries]
updated_entries: List[Entry] = []
try:
embeddings = embeddings_model.embed_documents(compiled_entries)
except Exception as e:
logger.error(f"Error embedding documents: {e}")
return
except Exception as e: for i, entry in enumerate(entries):
logger.error(f"Error embedding documents: {e}") entry.embeddings = embeddings[i]
return entry.search_model_id = search_model.id
updated_entries.append(entry)
for i, entry in enumerate(tqdm(entries)): Entry.objects.bulk_update(updated_entries, ["embeddings", "search_model_id", "file_path"])
entry.embeddings = embeddings[i]
entry.search_model_id = search_model.id
updated_entries.append(entry)
Entry.objects.bulk_update(updated_entries, ["embeddings", "search_model_id", "file_path"])
search_model_config_id = options.get("search_model_id") search_model_config_id = options.get("search_model_id")
apply = options.get("apply") apply = options.get("apply")