Auto redirect requests to use HTTPS if server is using SSL certs

This commit is contained in:
Debanjum 2024-12-16 12:49:55 -08:00
parent 132f2c987a
commit 9c64275dec
2 changed files with 5 additions and 3 deletions

View file

@ -24,6 +24,7 @@ from starlette.concurrency import run_in_threadpool
from starlette.middleware import Middleware from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware
from starlette.middleware.sessions import SessionMiddleware from starlette.middleware.sessions import SessionMiddleware
from starlette.requests import HTTPConnection from starlette.requests import HTTPConnection
from starlette.types import ASGIApp, Receive, Scope, Send from starlette.types import ASGIApp, Receive, Scope, Send
@ -43,7 +44,6 @@ from khoj.database.adapters import (
from khoj.database.models import ClientApplication, KhojUser, ProcessLock, Subscription from khoj.database.models import ClientApplication, KhojUser, ProcessLock, Subscription
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
from khoj.routers.api_content import configure_content, configure_search from khoj.routers.api_content import configure_content, configure_search
from khoj.routers.helpers import update_telemetry_state
from khoj.routers.twilio import is_twilio_enabled from khoj.routers.twilio import is_twilio_enabled
from khoj.utils import constants, state from khoj.utils import constants, state
from khoj.utils.config import SearchType from khoj.utils.config import SearchType
@ -343,7 +343,7 @@ def configure_routes(app):
logger.info("📞 Enabled Twilio") logger.info("📞 Enabled Twilio")
def configure_middleware(app): def configure_middleware(app, ssl_enabled: bool = False):
class NextJsMiddleware(Middleware): class NextJsMiddleware(Middleware):
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] == "http" and scope["path"].startswith("/_next"): if scope["type"] == "http" and scope["path"].startswith("/_next"):
@ -354,6 +354,8 @@ def configure_middleware(app):
super().__init__(app) super().__init__(app)
self.app = app self.app = app
if ssl_enabled:
app.add_middleware(HTTPSRedirectMiddleware)
app.add_middleware(AsyncCloseConnectionsMiddleware) app.add_middleware(AsyncCloseConnectionsMiddleware)
app.add_middleware(AuthenticationMiddleware, backend=UserAuthenticationBackend()) app.add_middleware(AuthenticationMiddleware, backend=UserAuthenticationBackend())
app.add_middleware(NextJsMiddleware) app.add_middleware(NextJsMiddleware)

View file

@ -191,7 +191,7 @@ def run(should_start_server=True):
app.mount(f"/static", StaticFiles(directory=static_dir), name=static_dir) app.mount(f"/static", StaticFiles(directory=static_dir), name=static_dir)
# Configure Middleware # Configure Middleware
configure_middleware(app) configure_middleware(app, state.ssl_config)
initialize_server(args.config) initialize_server(args.config)