Add a database lock for jobs that shouldn't be run by multiple workers (#706)

* Add a database lock for jobs that shouldn't be run by multiple workers

* Import relevant functions from utils.helpers
This commit is contained in:
sabaimran 2024-04-16 08:59:27 -07:00 committed by GitHub
parent adb2e8cc5f
commit 91c8b137f1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 88 additions and 13 deletions

View file

@ -1,7 +1,7 @@
import json
import logging
import os
from datetime import datetime
from datetime import datetime, timedelta
from enum import Enum
from typing import Optional
@ -24,6 +24,7 @@ from khoj.database.adapters import (
AgentAdapters,
ClientApplicationAdapters,
ConversationAdapters,
ProcessLockAdapters,
SubscriptionState,
aget_or_create_user_by_phone_number,
aget_user_by_phone_number,
@ -32,14 +33,14 @@ from khoj.database.adapters import (
get_all_users,
get_or_create_search_models,
)
from khoj.database.models import ClientApplication, KhojUser, Subscription
from khoj.database.models import ClientApplication, KhojUser, ProcessLock, Subscription
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
from khoj.routers.indexer import configure_content, configure_search
from khoj.routers.twilio import is_twilio_enabled
from khoj.utils import constants, state
from khoj.utils.config import SearchType
from khoj.utils.fs_syncer import collect_files
from khoj.utils.helpers import is_none_or_empty, telemetry_disabled
from khoj.utils.helpers import is_none_or_empty, telemetry_disabled, timer
from khoj.utils.rawconfig import FullConfig
logger = logging.getLogger(__name__)
@ -306,18 +307,28 @@ def configure_middleware(app):
app.add_middleware(SessionMiddleware, secret_key=os.environ.get("KHOJ_DJANGO_SECRET_KEY", "!secret"))
@schedule.repeat(schedule.every(22).to(25).hours)
def update_content_index():
    """Re-index content for every user, then for the default (``user=None``) scope.

    Scheduled to repeat every 22 to 25 hours. The run is skipped when an
    UPDATE_EMBEDDINGS process lock is already held, so that only one worker
    performs the update at a time. Otherwise a lock with a two hour maximum
    duration is taken for the length of the run.

    The lock is always released once it has been taken, including when
    indexing fails, so a failed run does not block later runs until the
    lock expires. Errors are logged and never propagated to the scheduler.
    """
    try:
        lock_name = ProcessLock.Operation.UPDATE_EMBEDDINGS
        if ProcessLockAdapters.is_process_locked(lock_name):
            logger.info("🔒 Skipping update content index due to lock")
            return

        ProcessLockAdapters.set_process_lock(lock_name, max_duration_in_seconds=60 * 60 * 2)
        try:
            with timer("📬 Updating content index via Scheduler"):
                # Track every result: checking only the last call would hide
                # failures for all users indexed before it.
                all_succeeded = True
                for user in get_all_users():
                    all_files = collect_files(user=user)
                    if not configure_content(all_files, user=user):
                        all_succeeded = False

                all_files = collect_files(user=None)
                if not configure_content(all_files, user=None):
                    all_succeeded = False

                if not all_succeeded:
                    raise RuntimeError("Failed to update content index")

            logger.info("📪 Content index updated via Scheduler")
        finally:
            # Release on success and on failure alike; otherwise an error would
            # leave the lock in place until its maximum duration elapses.
            ProcessLockAdapters.remove_process_lock(lock_name)
    except Exception as e:
        logger.error(f"🚨 Error updating content index via Scheduler: {e}", exc_info=True)

View file

@ -30,6 +30,7 @@ from khoj.database.models import (
NotionConfig,
OfflineChatProcessorConversationConfig,
OpenAIProcessorConversationConfig,
ProcessLock,
ReflectiveQuestion,
SearchModelConfig,
SpeechToTextModelOptions,
@ -402,6 +403,32 @@ async def aget_user_search_model(user: KhojUser):
return config.setting
class ProcessLockAdapters:
    """Helpers to look up, take, check and release named ProcessLock rows."""

    @staticmethod
    def get_process_lock(process_name: str):
        """Return the first ProcessLock named `process_name`, or None if there is none."""
        matching_locks = ProcessLock.objects.filter(name=process_name)
        return matching_locks.first()

    @staticmethod
    def set_process_lock(process_name: str, max_duration_in_seconds: int = 600):
        """Create and return a ProcessLock row for `process_name`.

        `max_duration_in_seconds` is stored on the row; `is_process_locked`
        treats the lock as expired once that many seconds have passed since
        the row's `started_at`.
        """
        lock_fields = {
            "name": process_name,
            "max_duration_in_seconds": max_duration_in_seconds,
        }
        return ProcessLock.objects.create(**lock_fields)

    @staticmethod
    def is_process_locked(process_name: str):
        """Return True when an unexpired lock exists for `process_name`.

        An expired lock (started_at + max_duration_in_seconds is before the
        current UTC time) is deleted and reported as not locked.
        """
        held_lock = ProcessLockAdapters.get_process_lock(process_name)
        if held_lock:
            deadline = held_lock.started_at + timedelta(seconds=held_lock.max_duration_in_seconds)
            if not deadline < datetime.now(tz=timezone.utc):
                return True
            # Lock outlived its maximum duration: clear it so work can resume.
            held_lock.delete()
        return False

    @staticmethod
    def remove_process_lock(process_name: str):
        """Delete every ProcessLock row named `process_name` and return the delete() result."""
        locks_to_release = ProcessLock.objects.filter(name=process_name)
        return locks_to_release.delete()
class ClientApplicationAdapters:
@staticmethod
async def aget_client_application_by_id(client_id: str, client_secret: str):

View file

@ -0,0 +1,26 @@
# Generated by Django 4.2.10 on 2024-04-15 08:48
from django.db import migrations, models
class Migration(migrations.Migration):
    # Schema migration that creates the ProcessLock model.

    # Must be applied after migration 0034 of the "database" app.
    dependencies = [
        ("database", "0034_alter_chatmodeloptions_chat_model"),
    ]
    operations = [
        migrations.CreateModel(
            name="ProcessLock",
            fields=[
                ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
                # Row bookkeeping timestamps: set on creation / refreshed on every save.
                ("created_at", models.DateTimeField(auto_now_add=True)),
                ("updated_at", models.DateTimeField(auto_now=True)),
                # Which operation the lock guards; "update_embeddings" is the only choice so far.
                ("name", models.CharField(choices=[("update_embeddings", "Update Embeddings")], max_length=200)),
                # When the lock was taken (set on creation).
                ("started_at", models.DateTimeField(auto_now_add=True)),
                # Default of 43200 seconds is 12 hours.
                ("max_duration_in_seconds", models.IntegerField(default=43200)),
            ],
            options={
                "abstract": False,
            },
        ),
    ]

View file

@ -98,6 +98,17 @@ class Agent(BaseModel):
slug = models.CharField(max_length=200)
class ProcessLock(BaseModel):
    # A row in this table marks that the named operation is currently being run.

    class Operation(models.TextChoices):
        # Operations that can be guarded by a lock.
        UPDATE_EMBEDDINGS = "update_embeddings"

    # Some jobs must not be run by multiple workers at once. To guard them, a lock row is stored
    # in the database for each potentially shared operation.
    # For example, we need to make sure that only one process is updating the embeddings at a time.
    name = models.CharField(max_length=200, choices=Operation.choices)
    # Set automatically when the lock row is created.
    started_at = models.DateTimeField(auto_now_add=True)
    # Lifetime of the lock, measured from started_at; ProcessLockAdapters.is_process_locked
    # deletes locks older than this.
    max_duration_in_seconds = models.IntegerField(default=60 * 60 * 12)  # 12 hours
@receiver(pre_save, sender=Agent)
def verify_agent(sender, instance, **kwargs):
# check if this is a new instance