From eb1b21baaaef32c0e474166bd64521b0ae976150 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Sat, 23 Nov 2024 11:55:34 -0800 Subject: [PATCH] Add a new sign in modal that is triggered from the login prompt screen, rather than redirecting user to another screen to sign in --- .../components/loginPrompt/loginPrompt.tsx | 171 +++++++++++++++--- src/interface/web/app/layout.tsx | 4 +- src/interface/web/yarn.lock | 1 + src/khoj/routers/auth.py | 54 ++++-- 4 files changed, 188 insertions(+), 42 deletions(-) diff --git a/src/interface/web/app/components/loginPrompt/loginPrompt.tsx b/src/interface/web/app/components/loginPrompt/loginPrompt.tsx index 18cdd48a..3871f48a 100644 --- a/src/interface/web/app/components/loginPrompt/loginPrompt.tsx +++ b/src/interface/web/app/components/loginPrompt/loginPrompt.tsx @@ -8,40 +8,159 @@ import { AlertDialogHeader, AlertDialogTitle, } from "@/components/ui/alert-dialog"; +import { Button } from "@/components/ui/button"; +import { + Dialog, + DialogContent, + DialogDescription, + DialogFooter, + DialogHeader, + DialogTitle, +} from "@/components/ui/dialog"; +import { Input } from "@/components/ui/input"; +import { ArrowLeft, GoogleCardboardLogo, GoogleLogo, Spinner } from "@phosphor-icons/react"; import Link from "next/link"; +import { useState } from "react"; +import useSWR from "swr"; export interface LoginPromptProps { loginRedirectMessage: string; onOpenChange: (open: boolean) => void; } +const fetcher = (url: string) => fetch(url).then((res) => res.json()); + +interface Provider { + client_id: string; + redirect_uri: string; +} + +interface CredentialsData { + [provider: string]: Provider; +} + export default function LoginPrompt(props: LoginPromptProps) { + const { data, error, isLoading } = useSWR("/auth/oauth/metadata", fetcher); + + const [useEmailSignIn, setUseEmailSignIn] = useState(false); + + const [email, setEmail] = useState(""); + const [checkEmail, setCheckEmail] = useState(false); + + const handleGoogleSignIn = () => { + if (!data?.google?.client_id || !data?.google?.redirect_uri) return; + + // Create full redirect URL using current origin + const fullRedirectUri = `${window.location.origin}${data.google.redirect_uri}`; + + const params = new URLSearchParams({ + client_id: data.google.client_id, + redirect_uri: fullRedirectUri, + response_type: "code", + scope: "email profile openid", + state: window.location.pathname, + access_type: "offline", + prompt: "consent select_account", + include_granted_scopes: "true", + }); + + window.location.href = `https://accounts.google.com/o/oauth2/v2/auth?${params}`; + }; + + function handleMagicLinkSignIn() { + fetch("/auth/magic", { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ email: email }), + }) + .then((res) => { + if (res.ok) { + setCheckEmail(true); + return res.json(); + } else { + throw new Error("Failed to send magic link"); + } + }) + .then((data) => { + console.log(data); + }) + .catch((err) => { + console.error(err); + }); + } + return ( - - - - Sign in to Khoj to continue - - - {props.loginRedirectMessage}. By logging in, you agree to our{" "} - Terms of Service. - - - Dismiss - { - window.location.href = `/login?next=${encodeURIComponent(window.location.pathname)}`; - }} - > - - {" "} - {/* Redirect to login page */} - Login - - - - - + + +
+ + Sign in to Khoj to continue + + + {props.loginRedirectMessage}. + + {useEmailSignIn && ( +
+ + setEmail(e.target.value)} + /> + +
+ )} + {!useEmailSignIn && ( +
+ + + +
+ )} + + By logging in, you agree to our{" "} + Terms of Service. + +
+
+ +
+
+
); } diff --git a/src/interface/web/app/layout.tsx b/src/interface/web/app/layout.tsx index 0efa1e89..1d0d970f 100644 --- a/src/interface/web/app/layout.tsx +++ b/src/interface/web/app/layout.tsx @@ -40,7 +40,7 @@ export default function RootLayout({ }>) { return ( - + > */} {children} ); diff --git a/src/interface/web/yarn.lock b/src/interface/web/yarn.lock index 690eab76..a41f9ea6 100644 --- a/src/interface/web/yarn.lock +++ b/src/interface/web/yarn.lock @@ -4174,6 +4174,7 @@ string-argv@~0.3.2: integrity sha512-aqD2Q0144Z+/RqG52NeHEkZauTAUWJO8c6yTftGJKO3Tja5tUgIfmIl6kExvhtxSDP7fXB6DvzkfMpCd/F3G+Q== "string-width-cjs@npm:string-width@^4.2.0", string-width@^4.1.0: + name string-width-cjs version "4.2.3" resolved "https://registry.yarnpkg.com/string-width/-/string-width-4.2.3.tgz#269c7117d27b05ad2e536830a8ec895ef9c6d010" integrity sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g== diff --git a/src/khoj/routers/auth.py b/src/khoj/routers/auth.py index 42df8bf9..8fc2b9be 100644 --- a/src/khoj/routers/auth.py +++ b/src/khoj/routers/auth.py @@ -4,6 +4,7 @@ import logging import os from typing import Optional +import requests from fastapi import APIRouter from pydantic import BaseModel, EmailStr from starlette.authentication import requires @@ -139,26 +140,40 @@ async def delete_token(request: Request, token: str): return await delete_khoj_token(user=request.user.object, token=token) -@auth_router.post("/redirect") +@auth_router.get("/redirect") async def auth(request: Request): - form = await request.form() next_url = get_next_url(request) for q in request.query_params: + if q in ["code", "state", "scope", "authuser", "prompt", "session_state", "access_type"]: + continue if not q == "next": next_url += f"&{q}={request.query_params[q]}" - credential = form.get("credential") + code = request.query_params.get("code") - csrf_token_cookie = request.cookies.get("g_csrf_token") - if not csrf_token_cookie: - logger.info("Missing CSRF token. Redirecting user to login page") - return RedirectResponse(url=next_url) - csrf_token_body = form.get("g_csrf_token") - if not csrf_token_body: - logger.info("Missing CSRF token body. Redirecting user to login page") - return RedirectResponse(url=next_url) - if csrf_token_cookie != csrf_token_body: - return Response("Invalid CSRF token", status_code=400) + # 1. Construct the full redirect URI including domain + base_url = str(request.base_url).rstrip("/") + redirect_uri = f"{base_url}{request.app.url_path_for('auth')}" + + verified_data = requests.post( + "https://oauth2.googleapis.com/token", + headers={"Content-Type": "application/x-www-form-urlencoded"}, + data={ + "code": code, + "client_id": os.environ["GOOGLE_CLIENT_ID"], + "client_secret": os.environ["GOOGLE_CLIENT_SECRET"], + "redirect_uri": redirect_uri, + "grant_type": "authorization_code", + }, + ) + + verified_data.raise_for_status() + + credential = verified_data.json().get("id_token") + + if not credential: + logger.error("Missing id_token in OAuth response") + return RedirectResponse(url="/login?error=invalid_token", status_code=HTTP_302_FOUND) try: idinfo = id_token.verify_oauth2_token(credential, google_requests.Request(), os.environ["GOOGLE_CLIENT_ID"]) @@ -178,7 +193,6 @@ async def auth(request: Request): metadata={"server_id": str(khoj_user.uuid)}, ) logger.log(logging.INFO, f"🥳 New User Created: {khoj_user.uuid}") - return RedirectResponse(url=next_url, status_code=HTTP_302_FOUND) return RedirectResponse(url=next_url, status_code=HTTP_302_FOUND) @@ -187,3 +201,15 @@ async def auth(request: Request): async def logout(request: Request): request.session.pop("user", None) return RedirectResponse(url="/") + + +@auth_router.get("/oauth/metadata") +async def oauth_metadata(request: Request): + redirect_uri = str(request.app.url_path_for("auth")) + + return { + "google": { + "client_id": os.environ.get("GOOGLE_CLIENT_ID"), + "redirect_uri": f"{redirect_uri}", + } + }