mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 15:38:55 +01:00
Batch entries into smaller groups to process
This commit is contained in:
parent
7e0a692d16
commit
a691ce4aa6
1 changed files with 18 additions and 14 deletions
|
@ -19,6 +19,8 @@ from khoj.processor.embeddings import EmbeddingsModel
|
|||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
BATCH_SIZE = 1000 # Define an appropriate batch size
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
help = "Convert all existing Entry objects to use a new default Search model."
|
||||
|
@ -42,22 +44,24 @@ class Command(BaseCommand):
|
|||
def handle(self, *args, **options):
|
||||
@transaction.atomic
|
||||
def regenerate_entries(entry_filter: Q, embeddings_model: EmbeddingsModel, search_model: SearchModelConfig):
|
||||
entries = Entry.objects.filter(entry_filter).all()
|
||||
compiled_entries = [entry.compiled for entry in entries]
|
||||
updated_entries: List[Entry] = []
|
||||
try:
|
||||
embeddings = embeddings_model.embed_documents(compiled_entries)
|
||||
total_entries = Entry.objects.filter(entry_filter).count()
|
||||
for start in tqdm(range(0, total_entries, BATCH_SIZE)):
|
||||
end = start + BATCH_SIZE
|
||||
entries = Entry.objects.filter(entry_filter)[start:end]
|
||||
compiled_entries = [entry.compiled for entry in entries]
|
||||
updated_entries: List[Entry] = []
|
||||
try:
|
||||
embeddings = embeddings_model.embed_documents(compiled_entries)
|
||||
except Exception as e:
|
||||
logger.error(f"Error embedding documents: {e}")
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error embedding documents: {e}")
|
||||
return
|
||||
for i, entry in enumerate(entries):
|
||||
entry.embeddings = embeddings[i]
|
||||
entry.search_model_id = search_model.id
|
||||
updated_entries.append(entry)
|
||||
|
||||
for i, entry in enumerate(tqdm(entries)):
|
||||
entry.embeddings = embeddings[i]
|
||||
entry.search_model_id = search_model.id
|
||||
updated_entries.append(entry)
|
||||
|
||||
Entry.objects.bulk_update(updated_entries, ["embeddings", "search_model_id", "file_path"])
|
||||
Entry.objects.bulk_update(updated_entries, ["embeddings", "search_model_id", "file_path"])
|
||||
|
||||
search_model_config_id = options.get("search_model_id")
|
||||
apply = options.get("apply")
|
||||
|
|
Loading…
Reference in a new issue