Only redirect to next url relative to current domain

This commit is contained in:
Debanjum Singh Solanky 2024-06-14 12:08:20 +05:30
parent 86a3505d89
commit 4daf16e5f9
3 changed files with 21 additions and 8 deletions

View file

@ -18,7 +18,7 @@ from khoj.database.adapters import (
get_or_create_user,
)
from khoj.routers.email import send_welcome_email
from khoj.routers.helpers import update_telemetry_state
from khoj.routers.helpers import get_next_url, update_telemetry_state
from khoj.utils import state
logger = logging.getLogger(__name__)
@ -94,7 +94,7 @@ async def delete_token(request: Request, token: str):
@auth_router.post("/redirect")
async def auth(request: Request):
form = await request.form()
next_url = request.query_params.get("next", "/")
next_url = get_next_url(request)
for q in request.query_params:
if not q == "next":
next_url += f"&{q}={request.query_params[q]}"
@ -104,11 +104,11 @@ async def auth(request: Request):
csrf_token_cookie = request.cookies.get("g_csrf_token")
if not csrf_token_cookie:
logger.info("Missing CSRF token. Redirecting user to login page")
return RedirectResponse(url=f"{next_url}")
return RedirectResponse(url=next_url)
csrf_token_body = form.get("g_csrf_token")
if not csrf_token_body:
logger.info("Missing CSRF token body. Redirecting user to login page")
return RedirectResponse(url=f"{next_url}")
return RedirectResponse(url=next_url)
if csrf_token_cookie != csrf_token_body:
return Response("Invalid CSRF token", status_code=400)
@ -130,9 +130,9 @@ async def auth(request: Request):
metadata={"user_id": str(khoj_user.uuid)},
)
logger.log(logging.INFO, f"🥳 New User Created: {khoj_user.uuid}")
return RedirectResponse(url=f"{next_url}", status_code=HTTP_302_FOUND)
return RedirectResponse(url=next_url, status_code=HTTP_302_FOUND)
return RedirectResponse(url=f"{next_url}", status_code=HTTP_302_FOUND)
return RedirectResponse(url=next_url, status_code=HTTP_302_FOUND)
@auth_router.get("/logout")

View file

@ -21,7 +21,7 @@ from typing import (
Tuple,
Union,
)
from urllib.parse import parse_qs, urlencode
from urllib.parse import parse_qs, urlencode, urljoin, urlparse
import cron_descriptor
import openai
@ -161,6 +161,18 @@ def update_telemetry_state(
]
def get_next_url(request: Request) -> str:
"Construct next url relative to current domain from request"
next_url_param = urlparse(request.query_params.get("next", "/"))
next_path = "/" # default next path
# If relative path or absolute path to current domain
if is_none_or_empty(next_url_param.scheme) or next_url_param.netloc == request.base_url.netloc:
# Use path in next query param
next_path = next_url_param.path
# Construct absolute url using current domain and next path from request
return urljoin(str(request.base_url).rstrip("/"), next_path)
def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str:
chat_history = ""
for chat in conversation_history.get("chat", [])[-n:]:

View file

@ -21,6 +21,7 @@ from khoj.database.adapters import (
get_user_subscription_state,
)
from khoj.database.models import KhojUser
from khoj.routers.helpers import get_next_url
from khoj.routers.notion import get_notion_auth_url
from khoj.routers.twilio import is_twilio_enabled
from khoj.utils import constants, state
@ -118,7 +119,7 @@ def chat_page(request: Request):
@web_client.get("/login", response_class=FileResponse)
def login_page(request: Request):
next_url = request.query_params.get("next", "/")
next_url = get_next_url(request)
if request.user.is_authenticated:
return RedirectResponse(url=next_url)
google_client_id = os.environ.get("GOOGLE_CLIENT_ID")