mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 23:48:56 +01:00
Only redirect to next url relative to current domain
This commit is contained in:
parent
86a3505d89
commit
4daf16e5f9
3 changed files with 21 additions and 8 deletions
|
@ -18,7 +18,7 @@ from khoj.database.adapters import (
|
|||
get_or_create_user,
|
||||
)
|
||||
from khoj.routers.email import send_welcome_email
|
||||
from khoj.routers.helpers import update_telemetry_state
|
||||
from khoj.routers.helpers import get_next_url, update_telemetry_state
|
||||
from khoj.utils import state
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -94,7 +94,7 @@ async def delete_token(request: Request, token: str):
|
|||
@auth_router.post("/redirect")
|
||||
async def auth(request: Request):
|
||||
form = await request.form()
|
||||
next_url = request.query_params.get("next", "/")
|
||||
next_url = get_next_url(request)
|
||||
for q in request.query_params:
|
||||
if not q == "next":
|
||||
next_url += f"&{q}={request.query_params[q]}"
|
||||
|
@ -104,11 +104,11 @@ async def auth(request: Request):
|
|||
csrf_token_cookie = request.cookies.get("g_csrf_token")
|
||||
if not csrf_token_cookie:
|
||||
logger.info("Missing CSRF token. Redirecting user to login page")
|
||||
return RedirectResponse(url=f"{next_url}")
|
||||
return RedirectResponse(url=next_url)
|
||||
csrf_token_body = form.get("g_csrf_token")
|
||||
if not csrf_token_body:
|
||||
logger.info("Missing CSRF token body. Redirecting user to login page")
|
||||
return RedirectResponse(url=f"{next_url}")
|
||||
return RedirectResponse(url=next_url)
|
||||
if csrf_token_cookie != csrf_token_body:
|
||||
return Response("Invalid CSRF token", status_code=400)
|
||||
|
||||
|
@ -130,9 +130,9 @@ async def auth(request: Request):
|
|||
metadata={"user_id": str(khoj_user.uuid)},
|
||||
)
|
||||
logger.log(logging.INFO, f"🥳 New User Created: {khoj_user.uuid}")
|
||||
return RedirectResponse(url=f"{next_url}", status_code=HTTP_302_FOUND)
|
||||
return RedirectResponse(url=next_url, status_code=HTTP_302_FOUND)
|
||||
|
||||
return RedirectResponse(url=f"{next_url}", status_code=HTTP_302_FOUND)
|
||||
return RedirectResponse(url=next_url, status_code=HTTP_302_FOUND)
|
||||
|
||||
|
||||
@auth_router.get("/logout")
|
||||
|
|
|
@ -21,7 +21,7 @@ from typing import (
|
|||
Tuple,
|
||||
Union,
|
||||
)
|
||||
from urllib.parse import parse_qs, urlencode
|
||||
from urllib.parse import parse_qs, urlencode, urljoin, urlparse
|
||||
|
||||
import cron_descriptor
|
||||
import openai
|
||||
|
@ -161,6 +161,18 @@ def update_telemetry_state(
|
|||
]
|
||||
|
||||
|
||||
def get_next_url(request: Request) -> str:
|
||||
"Construct next url relative to current domain from request"
|
||||
next_url_param = urlparse(request.query_params.get("next", "/"))
|
||||
next_path = "/" # default next path
|
||||
# If relative path or absolute path to current domain
|
||||
if is_none_or_empty(next_url_param.scheme) or next_url_param.netloc == request.base_url.netloc:
|
||||
# Use path in next query param
|
||||
next_path = next_url_param.path
|
||||
# Construct absolute url using current domain and next path from request
|
||||
return urljoin(str(request.base_url).rstrip("/"), next_path)
|
||||
|
||||
|
||||
def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str:
|
||||
chat_history = ""
|
||||
for chat in conversation_history.get("chat", [])[-n:]:
|
||||
|
|
|
@ -21,6 +21,7 @@ from khoj.database.adapters import (
|
|||
get_user_subscription_state,
|
||||
)
|
||||
from khoj.database.models import KhojUser
|
||||
from khoj.routers.helpers import get_next_url
|
||||
from khoj.routers.notion import get_notion_auth_url
|
||||
from khoj.routers.twilio import is_twilio_enabled
|
||||
from khoj.utils import constants, state
|
||||
|
@ -118,7 +119,7 @@ def chat_page(request: Request):
|
|||
|
||||
@web_client.get("/login", response_class=FileResponse)
|
||||
def login_page(request: Request):
|
||||
next_url = request.query_params.get("next", "/")
|
||||
next_url = get_next_url(request)
|
||||
if request.user.is_authenticated:
|
||||
return RedirectResponse(url=next_url)
|
||||
google_client_id = os.environ.get("GOOGLE_CLIENT_ID")
|
||||
|
|
Loading…
Reference in a new issue