From 9c64275dec0258041042e93e34adc4ed84d9bc73 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Mon, 16 Dec 2024 12:49:55 -0800 Subject: [PATCH] Auto redirect requests to use HTTPS if server is using SSL certs --- src/khoj/configure.py | 6 ++++-- src/khoj/main.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 33008b6c..900b3e30 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -24,6 +24,7 @@ from starlette.concurrency import run_in_threadpool from starlette.middleware import Middleware from starlette.middleware.authentication import AuthenticationMiddleware from starlette.middleware.base import BaseHTTPMiddleware +from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware from starlette.middleware.sessions import SessionMiddleware from starlette.requests import HTTPConnection from starlette.types import ASGIApp, Receive, Scope, Send @@ -43,7 +44,6 @@ from khoj.database.adapters import ( from khoj.database.models import ClientApplication, KhojUser, ProcessLock, Subscription from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel from khoj.routers.api_content import configure_content, configure_search -from khoj.routers.helpers import update_telemetry_state from khoj.routers.twilio import is_twilio_enabled from khoj.utils import constants, state from khoj.utils.config import SearchType @@ -343,7 +343,7 @@ def configure_routes(app): logger.info("📞 Enabled Twilio") -def configure_middleware(app): +def configure_middleware(app, ssl_enabled: bool = False): class NextJsMiddleware(Middleware): async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] == "http" and scope["path"].startswith("/_next"): @@ -354,6 +354,8 @@ def configure_middleware(app): super().__init__(app) self.app = app + if ssl_enabled: + app.add_middleware(HTTPSRedirectMiddleware) app.add_middleware(AsyncCloseConnectionsMiddleware) app.add_middleware(AuthenticationMiddleware, backend=UserAuthenticationBackend()) app.add_middleware(NextJsMiddleware) diff --git a/src/khoj/main.py b/src/khoj/main.py index dac4694b..771816ef 100644 --- a/src/khoj/main.py +++ b/src/khoj/main.py @@ -191,7 +191,7 @@ def run(should_start_server=True): app.mount(f"/static", StaticFiles(directory=static_dir), name=static_dir) # Configure Middleware - configure_middleware(app) + configure_middleware(app, state.ssl_config) initialize_server(args.config)