Merge pull request #710 from khoj-ai/add-run-with-process-lock-and-fix-edge-cases

Extract run with process lock logic into func. Use it to re-index content
This commit is contained in:
sabaimran 2024-04-17 01:29:02 -07:00 committed by GitHub
commit c9a8abafa4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 116 additions and 61 deletions

View file

@ -307,17 +307,7 @@ def configure_middleware(app):
app.add_middleware(SessionMiddleware, secret_key=os.environ.get("KHOJ_DJANGO_SECRET_KEY", "!secret")) app.add_middleware(SessionMiddleware, secret_key=os.environ.get("KHOJ_DJANGO_SECRET_KEY", "!secret"))
@schedule.repeat(schedule.every(22).to(25).hours)
def update_content_index(): def update_content_index():
try:
if ProcessLockAdapters.is_process_locked(ProcessLock.Operation.UPDATE_EMBEDDINGS):
logger.info("🔒 Skipping update content index due to lock")
return
ProcessLockAdapters.set_process_lock(
ProcessLock.Operation.UPDATE_EMBEDDINGS, max_duration_in_seconds=60 * 60 * 2
)
with timer("📬 Updating content index via Scheduler"):
for user in get_all_users(): for user in get_all_users():
all_files = collect_files(user=user) all_files = collect_files(user=user)
success = configure_content(all_files, user=user) success = configure_content(all_files, user=user)
@ -325,12 +315,14 @@ def update_content_index():
success = configure_content(all_files, user=None) success = configure_content(all_files, user=None)
if not success: if not success:
raise RuntimeError("Failed to update content index") raise RuntimeError("Failed to update content index")
logger.info("📪 Content index updated via Scheduler") logger.info("📪 Content index updated via Scheduler")
ProcessLockAdapters.remove_process_lock(ProcessLock.Operation.UPDATE_EMBEDDINGS)
except Exception as e: @schedule.repeat(schedule.every(22).to(25).hours)
logger.error(f"🚨 Error updating content index via Scheduler: {e}", exc_info=True) def update_content_index_regularly():
ProcessLockAdapters.run_with_lock(
update_content_index, ProcessLock.Operation.UPDATE_EMBEDDINGS, max_duration_in_seconds=60 * 60 * 2
)
def configure_search_types(): def configure_search_types():

View file

@ -5,7 +5,7 @@ import secrets
import sys import sys
from datetime import date, datetime, timedelta, timezone from datetime import date, datetime, timedelta, timezone
from enum import Enum from enum import Enum
from typing import List, Optional, Type from typing import Callable, List, Optional, Type
from asgiref.sync import sync_to_async from asgiref.sync import sync_to_async
from django.contrib.sessions.backends.db import SessionStore from django.contrib.sessions.backends.db import SessionStore
@ -46,7 +46,7 @@ from khoj.search_filter.file_filter import FileFilter
from khoj.search_filter.word_filter import WordFilter from khoj.search_filter.word_filter import WordFilter
from khoj.utils import state from khoj.utils import state
from khoj.utils.config import OfflineChatProcessorModel from khoj.utils.config import OfflineChatProcessorModel
from khoj.utils.helpers import generate_random_name, is_none_or_empty from khoj.utils.helpers import generate_random_name, is_none_or_empty, timer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -421,6 +421,7 @@ class ProcessLockAdapters:
tz=timezone.utc tz=timezone.utc
): ):
process_lock.delete() process_lock.delete()
logger.info(f"🔓 Deleted stale {process_name} process lock on timeout")
return False return False
return True return True
@ -428,6 +429,31 @@ class ProcessLockAdapters:
def remove_process_lock(process_name: str): def remove_process_lock(process_name: str):
return ProcessLock.objects.filter(name=process_name).delete() return ProcessLock.objects.filter(name=process_name).delete()
@staticmethod
def run_with_lock(func: Callable, operation: ProcessLock.Operation, max_duration_in_seconds: int = 600):
# Exit early if process lock is already taken
if ProcessLockAdapters.is_process_locked(operation):
logger.info(f"🔒 Skip executing {func} as {operation} lock is already taken")
return
success = False
try:
# Set process lock
ProcessLockAdapters.set_process_lock(operation, max_duration_in_seconds)
logger.info(f"🔐 Locked {operation} to execute {func}")
# Execute Function
with timer(f"🔒 Run {func} with {operation} process lock", logger):
func()
success = True
except Exception as e:
logger.error(f"🚨 Error executing {func} with {operation} process lock: {e}", exc_info=True)
success = False
finally:
# Remove Process Lock
ProcessLockAdapters.remove_process_lock(operation)
logger.info(f"🔓 Unlocked {operation} process after executing {func} {'Succeeded' if success else 'Failed'}")
class ClientApplicationAdapters: class ClientApplicationAdapters:
@staticmethod @staticmethod

View file

@ -53,17 +53,8 @@ class NotionContentConfig(ConfigBase):
token: str token: str
class ImageContentConfig(ConfigBase):
input_directories: Optional[List[Path]] = None
input_filter: Optional[List[str]] = None
embeddings_file: Path
use_xmp_metadata: bool
batch_size: int
class ContentConfig(ConfigBase): class ContentConfig(ConfigBase):
org: Optional[TextContentConfig] = None org: Optional[TextContentConfig] = None
image: Optional[ImageContentConfig] = None
markdown: Optional[TextContentConfig] = None markdown: Optional[TextContentConfig] = None
pdf: Optional[TextContentConfig] = None pdf: Optional[TextContentConfig] = None
plaintext: Optional[TextContentConfig] = None plaintext: Optional[TextContentConfig] = None

View file

@ -30,16 +30,12 @@ from khoj.utils import fs_syncer, state
from khoj.utils.config import SearchModels from khoj.utils.config import SearchModels
from khoj.utils.constants import web_directory from khoj.utils.constants import web_directory
from khoj.utils.helpers import resolve_absolute_path from khoj.utils.helpers import resolve_absolute_path
from khoj.utils.rawconfig import ( from khoj.utils.rawconfig import ContentConfig, ImageSearchConfig, SearchConfig
ContentConfig,
ImageContentConfig,
ImageSearchConfig,
SearchConfig,
)
from tests.helpers import ( from tests.helpers import (
ChatModelOptionsFactory, ChatModelOptionsFactory,
OfflineChatProcessorConversationConfigFactory, OfflineChatProcessorConversationConfigFactory,
OpenAIProcessorConversationConfigFactory, OpenAIProcessorConversationConfigFactory,
ProcessLockFactory,
SubscriptionFactory, SubscriptionFactory,
UserConversationProcessorConfigFactory, UserConversationProcessorConfigFactory,
UserFactory, UserFactory,
@ -211,6 +207,12 @@ def search_models(search_config: SearchConfig):
return search_models return search_models
@pytest.mark.django_db
@pytest.fixture
def default_process_lock():
return ProcessLockFactory()
@pytest.fixture @pytest.fixture
def anyio_backend(): def anyio_backend():
return "asyncio" return "asyncio"
@ -223,13 +225,6 @@ def content_config(tmp_path_factory, search_models: SearchModels, default_user:
# Generate Image Embeddings from Test Images # Generate Image Embeddings from Test Images
content_config = ContentConfig() content_config = ContentConfig()
content_config.image = ImageContentConfig(
input_filter=None,
input_directories=["tests/data/images"],
embeddings_file=content_dir.joinpath("image_embeddings.pt"),
batch_size=1,
use_xmp_metadata=False,
)
LocalOrgConfig.objects.create( LocalOrgConfig.objects.create(
input_files=None, input_files=None,

View file

@ -11,6 +11,7 @@ from khoj.database.models import (
KhojUser, KhojUser,
OfflineChatProcessorConversationConfig, OfflineChatProcessorConversationConfig,
OpenAIProcessorConversationConfig, OpenAIProcessorConversationConfig,
ProcessLock,
SearchModelConfig, SearchModelConfig,
Subscription, Subscription,
UserConversationConfig, UserConversationConfig,
@ -93,3 +94,10 @@ class SubscriptionFactory(factory.django.DjangoModelFactory):
type = "standard" type = "standard"
is_recurring = False is_recurring = False
renewal_date = make_aware(datetime.strptime("2100-04-01", "%Y-%m-%d")) renewal_date = make_aware(datetime.strptime("2100-04-01", "%Y-%m-%d"))
class ProcessLockFactory(factory.django.DjangoModelFactory):
class Meta:
model = ProcessLock
name = "test_lock"

View file

@ -273,7 +273,7 @@ def test_get_api_config_types(client, sample_org_data, default_user: KhojUser):
# Assert # Assert
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == ["all", "org", "image", "plaintext"] assert response.json() == ["all", "org", "plaintext"]
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------

58
tests/test_db_lock.py Normal file
View file

@ -0,0 +1,58 @@
import time
import pytest
from khoj.database.adapters import ProcessLockAdapters
from khoj.database.models import ProcessLock
from tests.helpers import ProcessLockFactory
@pytest.mark.django_db(transaction=True)
def test_process_lock(default_process_lock):
# Arrange
lock: ProcessLock = default_process_lock
# Assert
assert True == ProcessLockAdapters.is_process_locked(lock.name)
@pytest.mark.django_db(transaction=True)
def test_expired_process_lock():
# Arrange
lock: ProcessLock = ProcessLockFactory(name="test_expired_lock", max_duration_in_seconds=2)
# Act
time.sleep(3)
# Assert
assert False == ProcessLockAdapters.is_process_locked(lock.name)
@pytest.mark.django_db(transaction=True)
def test_in_progress_lock(default_process_lock):
# Arrange
lock: ProcessLock = default_process_lock
# Act
ProcessLockAdapters.run_with_lock(lock.name, lambda: time.sleep(2))
# Assert
assert True == ProcessLockAdapters.is_process_locked(lock.name)
@pytest.mark.django_db(transaction=True)
def test_run_with_completed():
# Arrange
ProcessLockAdapters.run_with_lock("test_run_with", lambda: time.sleep(2))
# Act
time.sleep(4)
# Assert
assert False == ProcessLockAdapters.is_process_locked("test_run_with")
@pytest.mark.django_db(transaction=True)
def test_nonexistent_lock():
# Assert
assert False == ProcessLockAdapters.is_process_locked("nonexistent_lock")

View file

@ -1,15 +0,0 @@
import pytest
from khoj.utils.rawconfig import ImageContentConfig, TextContentConfig
# Test
# ----------------------------------------------------------------------------------------------------
def test_input_filter_or_directories_required_in_image_content_config():
# Act
with pytest.raises(ValueError):
ImageContentConfig(
input_directories=None,
input_filter=None,
embeddings_file="note_embeddings.pt",
)