From 4daf16e5f916641304e11d56a6071ad365c21a18 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Fri, 14 Jun 2024 12:08:20 +0530 Subject: [PATCH] Only redirect to next url relative to current domain --- src/khoj/routers/auth.py | 12 ++++++------ src/khoj/routers/helpers.py | 14 +++++++++++++- src/khoj/routers/web_client.py | 3 ++- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/src/khoj/routers/auth.py b/src/khoj/routers/auth.py index 1b8ba19b..0d879afd 100644 --- a/src/khoj/routers/auth.py +++ b/src/khoj/routers/auth.py @@ -18,7 +18,7 @@ from khoj.database.adapters import ( get_or_create_user, ) from khoj.routers.email import send_welcome_email -from khoj.routers.helpers import update_telemetry_state +from khoj.routers.helpers import get_next_url, update_telemetry_state from khoj.utils import state logger = logging.getLogger(__name__) @@ -94,7 +94,7 @@ async def delete_token(request: Request, token: str): @auth_router.post("/redirect") async def auth(request: Request): form = await request.form() - next_url = request.query_params.get("next", "/") + next_url = get_next_url(request) for q in request.query_params: if not q == "next": next_url += f"&{q}={request.query_params[q]}" @@ -104,11 +104,11 @@ async def auth(request: Request): csrf_token_cookie = request.cookies.get("g_csrf_token") if not csrf_token_cookie: logger.info("Missing CSRF token. Redirecting user to login page") - return RedirectResponse(url=f"{next_url}") + return RedirectResponse(url=next_url) csrf_token_body = form.get("g_csrf_token") if not csrf_token_body: logger.info("Missing CSRF token body. Redirecting user to login page") - return RedirectResponse(url=f"{next_url}") + return RedirectResponse(url=next_url) if csrf_token_cookie != csrf_token_body: return Response("Invalid CSRF token", status_code=400) @@ -130,9 +130,9 @@ async def auth(request: Request): metadata={"user_id": str(khoj_user.uuid)}, ) logger.log(logging.INFO, f"🥳 New User Created: {khoj_user.uuid}") - return RedirectResponse(url=f"{next_url}", status_code=HTTP_302_FOUND) + return RedirectResponse(url=next_url, status_code=HTTP_302_FOUND) - return RedirectResponse(url=f"{next_url}", status_code=HTTP_302_FOUND) + return RedirectResponse(url=next_url, status_code=HTTP_302_FOUND) @auth_router.get("/logout") diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 56c1c524..febb16ea 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -21,7 +21,7 @@ from typing import ( Tuple, Union, ) -from urllib.parse import parse_qs, urlencode +from urllib.parse import parse_qs, urlencode, urljoin, urlparse import cron_descriptor import openai @@ -161,6 +161,18 @@ def update_telemetry_state( ] +def get_next_url(request: Request) -> str: + "Construct next url relative to current domain from request" + next_url_param = urlparse(request.query_params.get("next", "/")) + next_path = "/" # default next path + # If relative path or absolute path to current domain + if is_none_or_empty(next_url_param.scheme) or next_url_param.netloc == request.base_url.netloc: + # Use path in next query param + next_path = next_url_param.path + # Construct absolute url using current domain and next path from request + return urljoin(str(request.base_url).rstrip("/"), next_path) + + def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str: chat_history = "" for chat in conversation_history.get("chat", [])[-n:]: diff --git a/src/khoj/routers/web_client.py b/src/khoj/routers/web_client.py index fa27e3d8..321353d9 100644 --- a/src/khoj/routers/web_client.py +++ b/src/khoj/routers/web_client.py @@ -21,6 +21,7 @@ from khoj.database.adapters import ( get_user_subscription_state, ) from khoj.database.models import KhojUser +from khoj.routers.helpers import get_next_url from khoj.routers.notion import get_notion_auth_url from khoj.routers.twilio import is_twilio_enabled from khoj.utils import constants, state @@ -118,7 +119,7 @@ def chat_page(request: Request): @web_client.get("/login", response_class=FileResponse) def login_page(request: Request): - next_url = request.query_params.get("next", "/") + next_url = get_next_url(request) if request.user.is_authenticated: return RedirectResponse(url=next_url) google_client_id = os.environ.get("GOOGLE_CLIENT_ID")