Apply mitigations for piling up open connections

- Because we're using a FastAPI api framework with a Django ORM, we're running into some interesting conditions around connection pooling and clean-up. We're ending up with a large pile-up of open, stale connections to the DB recurringly when the server has been running for a while. To mitigate this problem, given starlette and django run in different python threads, add a middleware that will go and call the connection clean up method in each of the threads.
This commit is contained in:
sabaimran 2024-07-09 12:22:58 +05:30
parent 0b1b262512
commit 4471c1e37f
3 changed files with 40 additions and 12 deletions

View file

@ -65,7 +65,7 @@ dependencies = [
"tenacity == 8.3.0", "tenacity == 8.3.0",
"anyio == 3.7.1", "anyio == 3.7.1",
"pymupdf >= 1.23.5", "pymupdf >= 1.23.5",
"django == 4.2.11", "django == 5.0.6",
"authlib == 1.2.1", "authlib == 1.2.1",
"llama-cpp-python == 0.2.76", "llama-cpp-python == 0.2.76",
"itsdangerous == 2.1.2", "itsdangerous == 2.1.2",

View file

@ -110,6 +110,8 @@ TEMPLATES = [
ASGI_APPLICATION = "app.asgi.application" ASGI_APPLICATION = "app.asgi.application"
CLOSE_CONNECTIONS_AFTER_REQUEST = True
# Database # Database
# https://docs.djangoproject.com/en/4.2/ref/settings/#databases # https://docs.djangoproject.com/en/4.2/ref/settings/#databases
DATA_UPLOAD_MAX_NUMBER_FIELDS = 20000 DATA_UPLOAD_MAX_NUMBER_FIELDS = 20000

View file

@ -1,13 +1,16 @@
import json import json
import logging import logging
import os import os
from datetime import datetime, timedelta from datetime import datetime
from enum import Enum from enum import Enum
from typing import Optional from typing import Optional
import openai import openai
import requests import requests
import schedule import schedule
from asgiref.sync import sync_to_async
from django.conf import settings
from django.db import close_old_connections, connections
from django.utils.timezone import make_aware from django.utils.timezone import make_aware
from fastapi import Response from fastapi import Response
from starlette.authentication import ( from starlette.authentication import (
@ -16,8 +19,10 @@ from starlette.authentication import (
SimpleUser, SimpleUser,
UnauthenticatedUser, UnauthenticatedUser,
) )
from starlette.concurrency import run_in_threadpool
from starlette.middleware import Middleware from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.sessions import SessionMiddleware from starlette.middleware.sessions import SessionMiddleware
from starlette.requests import HTTPConnection from starlette.requests import HTTPConnection
from starlette.types import ASGIApp, Receive, Scope, Send from starlette.types import ASGIApp, Receive, Scope, Send
@ -42,7 +47,7 @@ from khoj.routers.twilio import is_twilio_enabled
from khoj.utils import constants, state from khoj.utils import constants, state
from khoj.utils.config import SearchType from khoj.utils.config import SearchType
from khoj.utils.fs_syncer import collect_files from khoj.utils.fs_syncer import collect_files
from khoj.utils.helpers import is_none_or_empty, telemetry_disabled, timer from khoj.utils.helpers import is_none_or_empty, telemetry_disabled
from khoj.utils.rawconfig import FullConfig from khoj.utils.rawconfig import FullConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -55,6 +60,35 @@ class AuthenticatedKhojUser(SimpleUser):
super().__init__(user.username) super().__init__(user.username)
class AsyncCloseConnectionsMiddleware(BaseHTTPMiddleware):
"""
Using this middleware to call close_old_connections() twice is a pretty yucky hack,
as it appears that run_in_threadpool (used by Starlette/FastAPI) and sync_to_async
(used by Django) have divergent behavior, ultimately acquiring the incorrect thread
in mixed sync/async which has the effect of duplicating connections.
We could fix the duplicate connections too if we normalized the thread behavior,
but at minimum we need to clean up connections in each case to prevent persistent
"InterfaceError: connection already closed" errors when the database connection is
reset via a database restart or something -- so here we are!
If we always use smart_sync_to_async(), this double calling isn't necessary, but
depending on what levels of abstraction we introduce, we might silently break the
assumptions. Better to be safe than sorry!
Attribution: https://gist.github.com/bryanhelmig/6fb091f23c1a4b7462dddce51cfaa1ca
"""
async def dispatch(self, request, call_next):
await run_in_threadpool(close_old_connections)
await sync_to_async(close_old_connections)()
try:
response = await call_next(request)
finally:
# in tests, use @override_settings(CLOSE_CONNECTIONS_AFTER_REQUEST=True)
if getattr(settings, "CLOSE_CONNECTIONS_AFTER_REQUEST", False):
await run_in_threadpool(connections.close_all)
await sync_to_async(connections.close_all)()
return response
class UserAuthenticationBackend(AuthenticationBackend): class UserAuthenticationBackend(AuthenticationBackend):
def __init__( def __init__(
self, self,
@ -77,7 +111,6 @@ class UserAuthenticationBackend(AuthenticationBackend):
Subscription.objects.create(user=default_user, type="standard", renewal_date=renewal_date) Subscription.objects.create(user=default_user, type="standard", renewal_date=renewal_date)
async def authenticate(self, request: HTTPConnection): async def authenticate(self, request: HTTPConnection):
# Request from Web client
current_user = request.session.get("user") current_user = request.session.get("user")
if current_user and current_user.get("email"): if current_user and current_user.get("email"):
user = ( user = (
@ -318,6 +351,7 @@ def configure_middleware(app):
super().__init__(app) super().__init__(app)
self.app = app self.app = app
app.add_middleware(AsyncCloseConnectionsMiddleware)
app.add_middleware(AuthenticationMiddleware, backend=UserAuthenticationBackend()) app.add_middleware(AuthenticationMiddleware, backend=UserAuthenticationBackend())
app.add_middleware(NextJsMiddleware) app.add_middleware(NextJsMiddleware)
app.add_middleware(SessionMiddleware, secret_key=os.environ.get("KHOJ_DJANGO_SECRET_KEY", "!secret")) app.add_middleware(SessionMiddleware, secret_key=os.environ.get("KHOJ_DJANGO_SECRET_KEY", "!secret"))
@ -341,14 +375,6 @@ def update_content_index_regularly():
) )
@schedule.repeat(schedule.every(30).to(59).minutes)
def close_all_db_connections():
from django import db
db.close_old_connections()
logger.info("🔌 Closed all database connections for explicit recycling.")
def configure_search_types(): def configure_search_types():
# Extract core search types # Extract core search types
core_search_types = {e.name: e.value for e in SearchType} core_search_types = {e.name: e.value for e in SearchType}