mirror of
https://github.com/khoj-ai/khoj.git
synced 2025-02-17 16:14:21 +00:00
Only redirect to next url relative to current domain
This commit is contained in:
parent
86a3505d89
commit
4daf16e5f9
3 changed files with 21 additions and 8 deletions
|
@ -18,7 +18,7 @@ from khoj.database.adapters import (
|
||||||
get_or_create_user,
|
get_or_create_user,
|
||||||
)
|
)
|
||||||
from khoj.routers.email import send_welcome_email
|
from khoj.routers.email import send_welcome_email
|
||||||
from khoj.routers.helpers import update_telemetry_state
|
from khoj.routers.helpers import get_next_url, update_telemetry_state
|
||||||
from khoj.utils import state
|
from khoj.utils import state
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -94,7 +94,7 @@ async def delete_token(request: Request, token: str):
|
||||||
@auth_router.post("/redirect")
|
@auth_router.post("/redirect")
|
||||||
async def auth(request: Request):
|
async def auth(request: Request):
|
||||||
form = await request.form()
|
form = await request.form()
|
||||||
next_url = request.query_params.get("next", "/")
|
next_url = get_next_url(request)
|
||||||
for q in request.query_params:
|
for q in request.query_params:
|
||||||
if not q == "next":
|
if not q == "next":
|
||||||
next_url += f"&{q}={request.query_params[q]}"
|
next_url += f"&{q}={request.query_params[q]}"
|
||||||
|
@ -104,11 +104,11 @@ async def auth(request: Request):
|
||||||
csrf_token_cookie = request.cookies.get("g_csrf_token")
|
csrf_token_cookie = request.cookies.get("g_csrf_token")
|
||||||
if not csrf_token_cookie:
|
if not csrf_token_cookie:
|
||||||
logger.info("Missing CSRF token. Redirecting user to login page")
|
logger.info("Missing CSRF token. Redirecting user to login page")
|
||||||
return RedirectResponse(url=f"{next_url}")
|
return RedirectResponse(url=next_url)
|
||||||
csrf_token_body = form.get("g_csrf_token")
|
csrf_token_body = form.get("g_csrf_token")
|
||||||
if not csrf_token_body:
|
if not csrf_token_body:
|
||||||
logger.info("Missing CSRF token body. Redirecting user to login page")
|
logger.info("Missing CSRF token body. Redirecting user to login page")
|
||||||
return RedirectResponse(url=f"{next_url}")
|
return RedirectResponse(url=next_url)
|
||||||
if csrf_token_cookie != csrf_token_body:
|
if csrf_token_cookie != csrf_token_body:
|
||||||
return Response("Invalid CSRF token", status_code=400)
|
return Response("Invalid CSRF token", status_code=400)
|
||||||
|
|
||||||
|
@ -130,9 +130,9 @@ async def auth(request: Request):
|
||||||
metadata={"user_id": str(khoj_user.uuid)},
|
metadata={"user_id": str(khoj_user.uuid)},
|
||||||
)
|
)
|
||||||
logger.log(logging.INFO, f"🥳 New User Created: {khoj_user.uuid}")
|
logger.log(logging.INFO, f"🥳 New User Created: {khoj_user.uuid}")
|
||||||
return RedirectResponse(url=f"{next_url}", status_code=HTTP_302_FOUND)
|
return RedirectResponse(url=next_url, status_code=HTTP_302_FOUND)
|
||||||
|
|
||||||
return RedirectResponse(url=f"{next_url}", status_code=HTTP_302_FOUND)
|
return RedirectResponse(url=next_url, status_code=HTTP_302_FOUND)
|
||||||
|
|
||||||
|
|
||||||
@auth_router.get("/logout")
|
@auth_router.get("/logout")
|
||||||
|
|
|
@ -21,7 +21,7 @@ from typing import (
|
||||||
Tuple,
|
Tuple,
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
from urllib.parse import parse_qs, urlencode
|
from urllib.parse import parse_qs, urlencode, urljoin, urlparse
|
||||||
|
|
||||||
import cron_descriptor
|
import cron_descriptor
|
||||||
import openai
|
import openai
|
||||||
|
@ -161,6 +161,18 @@ def update_telemetry_state(
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def get_next_url(request: Request) -> str:
|
||||||
|
"Construct next url relative to current domain from request"
|
||||||
|
next_url_param = urlparse(request.query_params.get("next", "/"))
|
||||||
|
next_path = "/" # default next path
|
||||||
|
# If relative path or absolute path to current domain
|
||||||
|
if is_none_or_empty(next_url_param.scheme) or next_url_param.netloc == request.base_url.netloc:
|
||||||
|
# Use path in next query param
|
||||||
|
next_path = next_url_param.path
|
||||||
|
# Construct absolute url using current domain and next path from request
|
||||||
|
return urljoin(str(request.base_url).rstrip("/"), next_path)
|
||||||
|
|
||||||
|
|
||||||
def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str:
|
def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str:
|
||||||
chat_history = ""
|
chat_history = ""
|
||||||
for chat in conversation_history.get("chat", [])[-n:]:
|
for chat in conversation_history.get("chat", [])[-n:]:
|
||||||
|
|
|
@ -21,6 +21,7 @@ from khoj.database.adapters import (
|
||||||
get_user_subscription_state,
|
get_user_subscription_state,
|
||||||
)
|
)
|
||||||
from khoj.database.models import KhojUser
|
from khoj.database.models import KhojUser
|
||||||
|
from khoj.routers.helpers import get_next_url
|
||||||
from khoj.routers.notion import get_notion_auth_url
|
from khoj.routers.notion import get_notion_auth_url
|
||||||
from khoj.routers.twilio import is_twilio_enabled
|
from khoj.routers.twilio import is_twilio_enabled
|
||||||
from khoj.utils import constants, state
|
from khoj.utils import constants, state
|
||||||
|
@ -118,7 +119,7 @@ def chat_page(request: Request):
|
||||||
|
|
||||||
@web_client.get("/login", response_class=FileResponse)
|
@web_client.get("/login", response_class=FileResponse)
|
||||||
def login_page(request: Request):
|
def login_page(request: Request):
|
||||||
next_url = request.query_params.get("next", "/")
|
next_url = get_next_url(request)
|
||||||
if request.user.is_authenticated:
|
if request.user.is_authenticated:
|
||||||
return RedirectResponse(url=next_url)
|
return RedirectResponse(url=next_url)
|
||||||
google_client_id = os.environ.get("GOOGLE_CLIENT_ID")
|
google_client_id = os.environ.get("GOOGLE_CLIENT_ID")
|
||||||
|
|
Loading…
Add table
Reference in a new issue