mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 23:48:56 +01:00
Add a database lock for jobs that shouldn't be run by multiple workers (#706)
* Add a database lock for jobs that shouldn't be run by multiple workers
* Import relevant functions from utils.helpers
This commit is contained in:
parent
adb2e8cc5f
commit
91c8b137f1
4 changed files with 88 additions and 13 deletions
|
@ -1,7 +1,7 @@
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
|
@ -24,6 +24,7 @@ from khoj.database.adapters import (
|
|||
AgentAdapters,
|
||||
ClientApplicationAdapters,
|
||||
ConversationAdapters,
|
||||
ProcessLockAdapters,
|
||||
SubscriptionState,
|
||||
aget_or_create_user_by_phone_number,
|
||||
aget_user_by_phone_number,
|
||||
|
@ -32,14 +33,14 @@ from khoj.database.adapters import (
|
|||
get_all_users,
|
||||
get_or_create_search_models,
|
||||
)
|
||||
from khoj.database.models import ClientApplication, KhojUser, Subscription
|
||||
from khoj.database.models import ClientApplication, KhojUser, ProcessLock, Subscription
|
||||
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
|
||||
from khoj.routers.indexer import configure_content, configure_search
|
||||
from khoj.routers.twilio import is_twilio_enabled
|
||||
from khoj.utils import constants, state
|
||||
from khoj.utils.config import SearchType
|
||||
from khoj.utils.fs_syncer import collect_files
|
||||
from khoj.utils.helpers import is_none_or_empty, telemetry_disabled
|
||||
from khoj.utils.helpers import is_none_or_empty, telemetry_disabled, timer
|
||||
from khoj.utils.rawconfig import FullConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -306,18 +307,28 @@ def configure_middleware(app):
|
|||
app.add_middleware(SessionMiddleware, secret_key=os.environ.get("KHOJ_DJANGO_SECRET_KEY", "!secret"))
|
||||
|
||||
|
||||
@schedule.repeat(schedule.every(22).to(25).hours)
def update_content_index():
    """Re-index content for all users on a schedule, guarded by a database lock.

    The job is skipped when an UPDATE_EMBEDDINGS process lock is already held,
    so that multiple workers do not run the update at the same time. Otherwise
    a lock with a two hour maximum duration is taken, content is indexed for
    every user returned by ``get_all_users()`` and then once more with
    ``user=None``, and the lock is released.

    The lock is released in a ``finally`` clause, so a failure while indexing
    does not leave the lock behind and block later scheduled runs until the
    lock expires.

    Any exception is logged and swallowed; nothing is returned or raised.
    """
    operation = ProcessLock.Operation.UPDATE_EMBEDDINGS
    try:
        if ProcessLockAdapters.is_process_locked(operation):
            logger.info("🔒 Skipping update content index due to lock")
            return
        ProcessLockAdapters.set_process_lock(operation, max_duration_in_seconds=60 * 60 * 2)

        try:
            with timer("📬 Updating content index via Scheduler"):
                for user in get_all_users():
                    all_files = collect_files(user=user)
                    success = configure_content(all_files, user=user)
                all_files = collect_files(user=None)
                success = configure_content(all_files, user=None)
                # NOTE(review): only the result of the final user=None run is checked here;
                # the per-user results assigned in the loop above are overwritten. Confirm intended.
                if not success:
                    raise RuntimeError("Failed to update content index")

            logger.info("📪 Content index updated via Scheduler")
        finally:
            # Always release the lock we took above, whether indexing succeeded or failed.
            ProcessLockAdapters.remove_process_lock(operation)
    except Exception as e:
        logger.error(f"🚨 Error updating content index via Scheduler: {e}", exc_info=True)
|
||||
|
||||
|
|
|
@ -30,6 +30,7 @@ from khoj.database.models import (
|
|||
NotionConfig,
|
||||
OfflineChatProcessorConversationConfig,
|
||||
OpenAIProcessorConversationConfig,
|
||||
ProcessLock,
|
||||
ReflectiveQuestion,
|
||||
SearchModelConfig,
|
||||
SpeechToTextModelOptions,
|
||||
|
@ -402,6 +403,32 @@ async def aget_user_search_model(user: KhojUser):
|
|||
return config.setting
|
||||
|
||||
|
||||
class ProcessLockAdapters:
    """Helpers to read, create, check and remove ``ProcessLock`` rows by name.

    NOTE(review): checking for a lock and creating one are separate calls, so
    two workers could both pass the check before either creates the lock.
    """

    @staticmethod
    def get_process_lock(process_name: str):
        """Return the first ``ProcessLock`` named ``process_name`` (result of ``.first()``)."""
        locks_for_process = ProcessLock.objects.filter(name=process_name)
        return locks_for_process.first()

    @staticmethod
    def set_process_lock(process_name: str, max_duration_in_seconds: int = 600):
        """Create and return a new ``ProcessLock`` row for ``process_name``.

        ``max_duration_in_seconds`` is stored on the row; ``is_process_locked``
        uses it to decide when the lock has expired.
        """
        return ProcessLock.objects.create(
            name=process_name,
            max_duration_in_seconds=max_duration_in_seconds,
        )

    @staticmethod
    def is_process_locked(process_name: str):
        """Return True if an unexpired lock exists for ``process_name``.

        A lock whose ``started_at`` plus ``max_duration_in_seconds`` lies in
        the past is deleted and reported as not locked.
        """
        existing_lock = ProcessLock.objects.filter(name=process_name).first()
        if not existing_lock:
            return False

        allowed_lifetime = timedelta(seconds=existing_lock.max_duration_in_seconds)
        expires_at = existing_lock.started_at + allowed_lifetime
        if expires_at >= datetime.now(tz=timezone.utc):
            # Lock is still within its allowed duration.
            return True

        # Stale lock: clean it up so the caller can take a fresh one.
        existing_lock.delete()
        return False

    @staticmethod
    def remove_process_lock(process_name: str):
        """Delete every ``ProcessLock`` row named ``process_name`` and return the delete result."""
        locks_for_process = ProcessLock.objects.filter(name=process_name)
        return locks_for_process.delete()
|
||||
|
||||
|
||||
class ClientApplicationAdapters:
|
||||
@staticmethod
|
||||
async def aget_client_application_by_id(client_id: str, client_secret: str):
|
||||
|
|
26
src/khoj/database/migrations/0035_processlock.py
Normal file
26
src/khoj/database/migrations/0035_processlock.py
Normal file
|
@ -0,0 +1,26 @@
|
|||
# Generated by Django 4.2.10 on 2024-04-15 08:48
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
    # Creates the ProcessLock model. Must be applied after migration 0034 of the "database" app.
    dependencies = [("database", "0034_alter_chatmodeloptions_chat_model")]

    operations = [
        migrations.CreateModel(
            name="ProcessLock",
            fields=[
                (
                    "id",
                    models.BigAutoField(
                        auto_created=True,
                        primary_key=True,
                        serialize=False,
                        verbose_name="ID",
                    ),
                ),
                ("created_at", models.DateTimeField(auto_now_add=True)),
                ("updated_at", models.DateTimeField(auto_now=True)),
                (
                    "name",
                    models.CharField(
                        choices=[("update_embeddings", "Update Embeddings")],
                        max_length=200,
                    ),
                ),
                ("started_at", models.DateTimeField(auto_now_add=True)),
                # 43200 seconds = 12 hours
                ("max_duration_in_seconds", models.IntegerField(default=43200)),
            ],
            options={"abstract": False},
        ),
    ]
|
|
@ -98,6 +98,17 @@ class Agent(BaseModel):
|
|||
slug = models.CharField(max_length=200)
|
||||
|
||||
|
||||
# Lock record for operations that should not be run by several workers at once.
# For example, only one process should be updating the embeddings at a time.
class ProcessLock(BaseModel):
    # Operations that can be guarded by a lock; the value is stored in `name`.
    class Operation(models.TextChoices):
        UPDATE_EMBEDDINGS = "update_embeddings"

    # Which operation this lock guards.
    name = models.CharField(choices=Operation.choices, max_length=200)
    # Set automatically when the lock row is created.
    started_at = models.DateTimeField(auto_now_add=True)
    # Maximum duration of the lock, in seconds. Defaults to 12 hours.
    max_duration_in_seconds = models.IntegerField(default=12 * 60 * 60)
|
||||
|
||||
|
||||
@receiver(pre_save, sender=Agent)
|
||||
def verify_agent(sender, instance, **kwargs):
|
||||
# check if this is a new instance
|
||||
|
|
Loading…
Reference in a new issue