mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 23:48:56 +01:00
Batch entries into smaller groups to process
This commit is contained in:
parent
7e0a692d16
commit
a691ce4aa6
1 changed file with 18 additions and 14 deletions
|
@ -19,6 +19,8 @@ from khoj.processor.embeddings import EmbeddingsModel
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
BATCH_SIZE = 1000 # Define an appropriate batch size
|
||||||
|
|
||||||
|
|
||||||
class Command(BaseCommand):
|
class Command(BaseCommand):
|
||||||
help = "Convert all existing Entry objects to use a new default Search model."
|
help = "Convert all existing Entry objects to use a new default Search model."
|
||||||
|
@ -42,22 +44,24 @@ class Command(BaseCommand):
|
||||||
def handle(self, *args, **options):
|
def handle(self, *args, **options):
|
||||||
@transaction.atomic
|
@transaction.atomic
|
||||||
def regenerate_entries(entry_filter: Q, embeddings_model: EmbeddingsModel, search_model: SearchModelConfig):
|
def regenerate_entries(entry_filter: Q, embeddings_model: EmbeddingsModel, search_model: SearchModelConfig):
|
||||||
entries = Entry.objects.filter(entry_filter).all()
|
total_entries = Entry.objects.filter(entry_filter).count()
|
||||||
compiled_entries = [entry.compiled for entry in entries]
|
for start in tqdm(range(0, total_entries, BATCH_SIZE)):
|
||||||
updated_entries: List[Entry] = []
|
end = start + BATCH_SIZE
|
||||||
try:
|
entries = Entry.objects.filter(entry_filter)[start:end]
|
||||||
embeddings = embeddings_model.embed_documents(compiled_entries)
|
compiled_entries = [entry.compiled for entry in entries]
|
||||||
|
updated_entries: List[Entry] = []
|
||||||
|
try:
|
||||||
|
embeddings = embeddings_model.embed_documents(compiled_entries)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error embedding documents: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
except Exception as e:
|
for i, entry in enumerate(entries):
|
||||||
logger.error(f"Error embedding documents: {e}")
|
entry.embeddings = embeddings[i]
|
||||||
return
|
entry.search_model_id = search_model.id
|
||||||
|
updated_entries.append(entry)
|
||||||
|
|
||||||
for i, entry in enumerate(tqdm(entries)):
|
Entry.objects.bulk_update(updated_entries, ["embeddings", "search_model_id", "file_path"])
|
||||||
entry.embeddings = embeddings[i]
|
|
||||||
entry.search_model_id = search_model.id
|
|
||||||
updated_entries.append(entry)
|
|
||||||
|
|
||||||
Entry.objects.bulk_update(updated_entries, ["embeddings", "search_model_id", "file_path"])
|
|
||||||
|
|
||||||
search_model_config_id = options.get("search_model_id")
|
search_model_config_id = options.get("search_model_id")
|
||||||
apply = options.get("apply")
|
apply = options.get("apply")
|
||||||
|
|
Loading…
Reference in a new issue