mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 23:48:56 +01:00
[Multi-User Part 2]: Add login pages and gate access to application behind login wall (#503)
- Make most routes conditional on authentication *if anonymous mode is not enabled*. If anonymous mode is enabled, it scaffolds a default user and uses that for all application interactions. - Add a basic login page and add routes for redirecting the user if logged in
This commit is contained in:
parent
216acf545f
commit
a8a82d274a
13 changed files with 327 additions and 59 deletions
|
@ -68,6 +68,8 @@ dependencies = [
|
||||||
"httpx == 0.25.0",
|
"httpx == 0.25.0",
|
||||||
"pgvector == 0.2.3",
|
"pgvector == 0.2.3",
|
||||||
"psycopg2-binary == 2.9.9",
|
"psycopg2-binary == 2.9.9",
|
||||||
|
"google-auth == 2.23.3",
|
||||||
|
"python-multipart == 0.0.6",
|
||||||
]
|
]
|
||||||
dynamic = ["version"]
|
dynamic = ["version"]
|
||||||
|
|
||||||
|
|
|
@ -73,8 +73,7 @@ async def create_google_user(token: dict) -> KhojUser:
|
||||||
|
|
||||||
|
|
||||||
async def get_user_by_token(token: dict) -> KhojUser:
|
async def get_user_by_token(token: dict) -> KhojUser:
|
||||||
user_info = token.get("userinfo")
|
google_user = await GoogleUser.objects.filter(sub=token.get("sub")).select_related("user").afirst()
|
||||||
google_user = await GoogleUser.objects.filter(sub=user_info.get("sub")).select_related("user").afirst()
|
|
||||||
if not google_user:
|
if not google_user:
|
||||||
return None
|
return None
|
||||||
return google_user.user
|
return google_user.user
|
||||||
|
|
|
@ -67,7 +67,7 @@ class UserAuthenticationBackend(AuthenticationBackend):
|
||||||
user = await self.khojuser_manager.filter(email=current_user.get("email")).afirst()
|
user = await self.khojuser_manager.filter(email=current_user.get("email")).afirst()
|
||||||
if user:
|
if user:
|
||||||
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
|
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
|
||||||
elif not state.anonymous_mode:
|
elif state.anonymous_mode:
|
||||||
user = await self.khojuser_manager.filter(username="default").afirst()
|
user = await self.khojuser_manager.filter(username="default").afirst()
|
||||||
if user:
|
if user:
|
||||||
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
|
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
|
||||||
|
@ -77,11 +77,6 @@ class UserAuthenticationBackend(AuthenticationBackend):
|
||||||
|
|
||||||
def initialize_server(config: Optional[FullConfig]):
|
def initialize_server(config: Optional[FullConfig]):
|
||||||
if config is None:
|
if config is None:
|
||||||
logger.error(
|
|
||||||
f"🚨 Exiting as Khoj is not configured.\nConfigure it via http://{state.host}:{state.port}/config or by editing {state.config_file}."
|
|
||||||
)
|
|
||||||
sys.exit(1)
|
|
||||||
elif config is None:
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"🚨 Khoj is not configured.\nConfigure it via http://{state.host}:{state.port}/config, plugins or by editing {state.config_file}."
|
f"🚨 Khoj is not configured.\nConfigure it via http://{state.host}:{state.port}/config, plugins or by editing {state.config_file}."
|
||||||
)
|
)
|
||||||
|
@ -230,6 +225,8 @@ def configure_conversation_processor(
|
||||||
conversation_logfile=conversation_logfile,
|
conversation_logfile=conversation_logfile,
|
||||||
openai=(conversation_config.openai if (conversation_config is not None) else None),
|
openai=(conversation_config.openai if (conversation_config is not None) else None),
|
||||||
offline_chat=conversation_config.offline_chat if conversation_config else OfflineChatProcessorConfig(),
|
offline_chat=conversation_config.offline_chat if conversation_config else OfflineChatProcessorConfig(),
|
||||||
|
max_prompt_size=conversation_config.max_prompt_size if conversation_config else None,
|
||||||
|
tokenizer=conversation_config.tokenizer if conversation_config else None,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -211,6 +211,14 @@
|
||||||
grid-gap: 24px;
|
grid-gap: 24px;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
button#logout {
|
||||||
|
font-size: 16px;
|
||||||
|
cursor: pointer;
|
||||||
|
margin: 0;
|
||||||
|
padding: 0;
|
||||||
|
height: 32px;
|
||||||
|
}
|
||||||
|
|
||||||
@media screen and (max-width: 600px) {
|
@media screen and (max-width: 600px) {
|
||||||
.section-cards {
|
.section-cards {
|
||||||
grid-template-columns: 1fr;
|
grid-template-columns: 1fr;
|
||||||
|
@ -242,6 +250,10 @@
|
||||||
width: 320px;
|
width: 320px;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
div.khoj-header-wrapper{
|
||||||
|
grid-template-columns: auto;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
</style>
|
</style>
|
||||||
</html>
|
</html>
|
||||||
|
|
|
@ -297,6 +297,11 @@
|
||||||
<div class="finalize-buttons">
|
<div class="finalize-buttons">
|
||||||
<button id="reinitialize" type="submit" title="Regenerate index from scratch">🔄 Reinitialize</button>
|
<button id="reinitialize" type="submit" title="Regenerate index from scratch">🔄 Reinitialize</button>
|
||||||
</div>
|
</div>
|
||||||
|
{% if anonymous_mode == False %}
|
||||||
|
<div class="finalize-buttons">
|
||||||
|
<button id="logout" class="logout" onclick="window.location.href='/auth/logout'">Logout</button>
|
||||||
|
</div>
|
||||||
|
{% endif %}
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
197
src/khoj/interface/web/login.html
Normal file
197
src/khoj/interface/web/login.html
Normal file
|
@ -0,0 +1,197 @@
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<meta charset="utf-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0 maximum-scale=1.0">
|
||||||
|
<title>Khoj - Search</title>
|
||||||
|
|
||||||
|
<link rel="icon" type="image/png" sizes="128x128" href="/static/assets/icons/favicon-128x128.png">
|
||||||
|
<link rel="manifest" href="/static/khoj.webmanifest">
|
||||||
|
<link rel="stylesheet" href="/static/assets/khoj.css">
|
||||||
|
</head>
|
||||||
|
|
||||||
|
<body>
|
||||||
|
{% if demo %}
|
||||||
|
<!-- Banner linking to https://khoj.dev -->
|
||||||
|
<div class="khoj-banner-container">
|
||||||
|
<a class="khoj-banner" href="https://khoj.dev" target="_blank">
|
||||||
|
<p id="khoj-banner" class="khoj-banner">
|
||||||
|
Enroll in Khoj cloud to get your own assistant
|
||||||
|
</p>
|
||||||
|
</a>
|
||||||
|
<input type="text" id="khoj-banner-email" placeholder="email" class="khoj-banner-email"></input>
|
||||||
|
<button id="khoj-banner-submit" class="khoj-banner-button">Submit</button>
|
||||||
|
</div>
|
||||||
|
{% endif %}
|
||||||
|
<!--Add Header Logo and Nav Pane-->
|
||||||
|
<div class="khoj-header">
|
||||||
|
{% if demo %}
|
||||||
|
<a class="khoj-logo" href="https://khoj.dev" target="_blank">
|
||||||
|
<img class="khoj-logo" src="/static/assets/icons/khoj-logo-sideways-500.png" alt="Khoj"></img>
|
||||||
|
</a>
|
||||||
|
{% else %}
|
||||||
|
<a class="khoj-logo" href="/">
|
||||||
|
<img class="khoj-logo" src="/static/assets/icons/khoj-logo-sideways-500.png" alt="Khoj"></img>
|
||||||
|
</a>
|
||||||
|
{% endif %}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- Sign Up button for Google OAuth -->
|
||||||
|
<div id="login-modal">
|
||||||
|
<h1 class="login-modal-title">Become superhuman with your personal knowledge base copilot</h1>
|
||||||
|
<div
|
||||||
|
class="g_id_signin"
|
||||||
|
data-shape="circle"
|
||||||
|
data-text="continue_with"
|
||||||
|
data-logo_alignment="center"
|
||||||
|
data-size="large"
|
||||||
|
data-width="300"
|
||||||
|
data-type="standard">
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div id="g_id_onload"
|
||||||
|
data-client_id="{{ google_client_id }}"
|
||||||
|
data-ux_mode="redirect"
|
||||||
|
data-login_uri="{{ redirect_uri }}">
|
||||||
|
</div>
|
||||||
|
|
||||||
|
|
||||||
|
</body>
|
||||||
|
|
||||||
|
<style>
|
||||||
|
@media only screen and (max-width: 600px) {
|
||||||
|
body {
|
||||||
|
display: grid;
|
||||||
|
grid-template-columns: 1fr;
|
||||||
|
grid-template-rows: 1fr auto auto auto minmax(80px, 100%);
|
||||||
|
font-size: small!important;
|
||||||
|
}
|
||||||
|
body > * {
|
||||||
|
grid-column: 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@media only screen and (min-width: 600px) {
|
||||||
|
body {
|
||||||
|
display: grid;
|
||||||
|
grid-template-columns: 1fr min(70vw, 100%) 1fr;
|
||||||
|
grid-template-rows: 1fr auto auto auto minmax(80px, 100%);
|
||||||
|
padding-top: 60vw;
|
||||||
|
}
|
||||||
|
body > * {
|
||||||
|
grid-column: 2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
body {
|
||||||
|
padding: 0px;
|
||||||
|
margin: 0px;
|
||||||
|
background: #fff;
|
||||||
|
color: #475569;
|
||||||
|
font-family: roboto, karma, segoe ui, sans-serif;
|
||||||
|
font-size: 20px;
|
||||||
|
font-weight: 300;
|
||||||
|
line-height: 1.5em;
|
||||||
|
}
|
||||||
|
body > * {
|
||||||
|
padding: 10px;
|
||||||
|
margin: 10px;
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes gradient {
|
||||||
|
0% {
|
||||||
|
background-position: 0% 50%;
|
||||||
|
}
|
||||||
|
50% {
|
||||||
|
background-position: 100% 50%;
|
||||||
|
}
|
||||||
|
100% {
|
||||||
|
background-position: 0% 50%;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
a.khoj-logo {
|
||||||
|
text-align: center;
|
||||||
|
}
|
||||||
|
|
||||||
|
button#khoj-banner-submit,
|
||||||
|
input#khoj-banner-email {
|
||||||
|
padding: 10px;
|
||||||
|
border-radius: 5px;
|
||||||
|
border: 1px solid #475569;
|
||||||
|
background: #f9fafc;
|
||||||
|
}
|
||||||
|
|
||||||
|
button#khoj-banner-submit:hover,
|
||||||
|
input#khoj-banner-email:hover {
|
||||||
|
box-shadow: 0 0 11px #aaa;
|
||||||
|
}
|
||||||
|
|
||||||
|
div#login-modal {
|
||||||
|
display: grid;
|
||||||
|
grid-template-columns: 1fr;
|
||||||
|
grid-template-rows: 1fr auto auto auto;
|
||||||
|
gap: 10px;
|
||||||
|
padding: 10px;
|
||||||
|
margin: 10px;
|
||||||
|
background: #fff;
|
||||||
|
border-radius: 5px;
|
||||||
|
border: 1px solid #475569;
|
||||||
|
box-shadow: 0 0 11px #aaa;
|
||||||
|
margin-left: 25%;
|
||||||
|
margin-right: 25%;
|
||||||
|
}
|
||||||
|
|
||||||
|
div.g_id_signin {
|
||||||
|
margin: 0 auto;
|
||||||
|
display: block;
|
||||||
|
}
|
||||||
|
|
||||||
|
h1.login-modal-title {
|
||||||
|
text-align: center;
|
||||||
|
line-height: 28px;
|
||||||
|
font-size: x-large;
|
||||||
|
}
|
||||||
|
|
||||||
|
@media only screen and (max-width: 600px) {
|
||||||
|
a.khoj-banner {
|
||||||
|
display: block;
|
||||||
|
}
|
||||||
|
p.khoj-banner {
|
||||||
|
padding: 0;
|
||||||
|
}
|
||||||
|
div#login-modal {
|
||||||
|
margin-left: 10%;
|
||||||
|
margin-right: 10%;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
</style>
|
||||||
|
<script>
|
||||||
|
var khojBannerSubmit = document.getElementById("khoj-banner-submit");
|
||||||
|
khojBannerSubmit?.addEventListener("click", function(event) {
|
||||||
|
event.preventDefault();
|
||||||
|
var email = document.getElementById("khoj-banner-email").value;
|
||||||
|
fetch("https://app.khoj.dev/beta/users/", {
|
||||||
|
method: "POST",
|
||||||
|
body: JSON.stringify({
|
||||||
|
email: email
|
||||||
|
}),
|
||||||
|
headers: {
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
}).then(function(response) {
|
||||||
|
return response.json();
|
||||||
|
}).then(function(data) {
|
||||||
|
console.log(data);
|
||||||
|
if (data.user != null) {
|
||||||
|
document.getElementById("khoj-banner").innerHTML = "Thanks for signing up. We'll be in touch soon! 🚀";
|
||||||
|
document.getElementById("khoj-banner-submit").remove();
|
||||||
|
} else {
|
||||||
|
document.getElementById("khoj-banner").innerHTML = "There was an error signing up. Please contact team@khoj.dev";
|
||||||
|
}
|
||||||
|
}).catch(function(error) {
|
||||||
|
console.log(error);
|
||||||
|
document.getElementById("khoj-banner").innerHTML = "There was an error signing up. Please contact team@khoj.dev";
|
||||||
|
});
|
||||||
|
});
|
||||||
|
</script>
|
||||||
|
<script src="https://accounts.google.com/gsi/client" async defer></script>
|
||||||
|
</html>
|
|
@ -14,13 +14,13 @@ warnings.filterwarnings("ignore", message=r"legacy way to download files from th
|
||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
import django
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
import schedule
|
|
||||||
import django
|
|
||||||
|
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
from rich.logging import RichHandler
|
from rich.logging import RichHandler
|
||||||
|
import schedule
|
||||||
|
|
||||||
from django.core.asgi import get_asgi_application
|
from django.core.asgi import get_asgi_application
|
||||||
from django.core.management import call_command
|
from django.core.management import call_command
|
||||||
|
|
||||||
|
@ -34,13 +34,6 @@ call_command("migrate", "--noinput")
|
||||||
# Initialize Django Static Files
|
# Initialize Django Static Files
|
||||||
call_command("collectstatic", "--noinput")
|
call_command("collectstatic", "--noinput")
|
||||||
|
|
||||||
# Initialize Django
|
|
||||||
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "app.settings")
|
|
||||||
django.setup()
|
|
||||||
|
|
||||||
# Initialize Django Database
|
|
||||||
call_command("migrate", "--noinput")
|
|
||||||
|
|
||||||
# Initialize the Application Server
|
# Initialize the Application Server
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
|
@ -127,6 +127,7 @@ if not state.demo:
|
||||||
state.processor_config = configure_processor(state.config.processor)
|
state.processor_config = configure_processor(state.config.processor)
|
||||||
|
|
||||||
@api.get("/config/data", response_model=FullConfig)
|
@api.get("/config/data", response_model=FullConfig)
|
||||||
|
@requires(["authenticated"], redirect="login_page")
|
||||||
def get_config_data(request: Request):
|
def get_config_data(request: Request):
|
||||||
user = request.user.object if request.user.is_authenticated else None
|
user = request.user.object if request.user.is_authenticated else None
|
||||||
enabled_content = EmbeddingsAdapters.get_unique_file_types(user)
|
enabled_content = EmbeddingsAdapters.get_unique_file_types(user)
|
||||||
|
@ -134,6 +135,7 @@ if not state.demo:
|
||||||
return state.config
|
return state.config
|
||||||
|
|
||||||
@api.post("/config/data")
|
@api.post("/config/data")
|
||||||
|
@requires(["authenticated"], redirect="login_page")
|
||||||
async def set_config_data(
|
async def set_config_data(
|
||||||
request: Request,
|
request: Request,
|
||||||
updated_config: FullConfig,
|
updated_config: FullConfig,
|
||||||
|
@ -166,7 +168,7 @@ if not state.demo:
|
||||||
return state.config
|
return state.config
|
||||||
|
|
||||||
@api.post("/config/data/content_type/github", status_code=200)
|
@api.post("/config/data/content_type/github", status_code=200)
|
||||||
@requires("authenticated")
|
@requires(["authenticated"], redirect="login_page")
|
||||||
async def set_content_config_github_data(
|
async def set_content_config_github_data(
|
||||||
request: Request,
|
request: Request,
|
||||||
updated_config: Union[GithubContentConfig, None],
|
updated_config: Union[GithubContentConfig, None],
|
||||||
|
@ -193,6 +195,7 @@ if not state.demo:
|
||||||
return {"status": "ok"}
|
return {"status": "ok"}
|
||||||
|
|
||||||
@api.post("/config/data/content_type/notion", status_code=200)
|
@api.post("/config/data/content_type/notion", status_code=200)
|
||||||
|
@requires(["authenticated"], redirect="login_page")
|
||||||
async def set_content_config_notion_data(
|
async def set_content_config_notion_data(
|
||||||
request: Request,
|
request: Request,
|
||||||
updated_config: Union[NotionContentConfig, None],
|
updated_config: Union[NotionContentConfig, None],
|
||||||
|
@ -218,6 +221,7 @@ if not state.demo:
|
||||||
return {"status": "ok"}
|
return {"status": "ok"}
|
||||||
|
|
||||||
@api.post("/delete/config/data/content_type/{content_type}", status_code=200)
|
@api.post("/delete/config/data/content_type/{content_type}", status_code=200)
|
||||||
|
@requires(["authenticated"], redirect="login_page")
|
||||||
async def remove_content_config_data(
|
async def remove_content_config_data(
|
||||||
request: Request,
|
request: Request,
|
||||||
content_type: str,
|
content_type: str,
|
||||||
|
@ -272,7 +276,7 @@ if not state.demo:
|
||||||
return {"status": "error", "message": str(e)}
|
return {"status": "error", "message": str(e)}
|
||||||
|
|
||||||
@api.post("/config/data/content_type/{content_type}", status_code=200)
|
@api.post("/config/data/content_type/{content_type}", status_code=200)
|
||||||
# @requires("authenticated")
|
@requires(["authenticated"], redirect="login_page")
|
||||||
async def set_content_config_data(
|
async def set_content_config_data(
|
||||||
request: Request,
|
request: Request,
|
||||||
content_type: str,
|
content_type: str,
|
||||||
|
@ -378,6 +382,7 @@ def get_default_config_data():
|
||||||
|
|
||||||
|
|
||||||
@api.get("/config/types", response_model=List[str])
|
@api.get("/config/types", response_model=List[str])
|
||||||
|
@requires(["authenticated"], redirect="login_page")
|
||||||
def get_config_types(
|
def get_config_types(
|
||||||
request: Request,
|
request: Request,
|
||||||
):
|
):
|
||||||
|
@ -399,6 +404,7 @@ def get_config_types(
|
||||||
|
|
||||||
|
|
||||||
@api.get("/search", response_model=List[SearchResponse])
|
@api.get("/search", response_model=List[SearchResponse])
|
||||||
|
@requires(["authenticated"], redirect="login_page")
|
||||||
async def search(
|
async def search(
|
||||||
q: str,
|
q: str,
|
||||||
request: Request,
|
request: Request,
|
||||||
|
@ -532,6 +538,7 @@ async def search(
|
||||||
|
|
||||||
|
|
||||||
@api.get("/update")
|
@api.get("/update")
|
||||||
|
@requires(["authenticated"], redirect="login_page")
|
||||||
def update(
|
def update(
|
||||||
request: Request,
|
request: Request,
|
||||||
t: Optional[SearchType] = None,
|
t: Optional[SearchType] = None,
|
||||||
|
@ -577,6 +584,7 @@ def update(
|
||||||
|
|
||||||
|
|
||||||
@api.get("/chat/history")
|
@api.get("/chat/history")
|
||||||
|
@requires(["authenticated"], redirect="login_page")
|
||||||
def chat_history(
|
def chat_history(
|
||||||
request: Request,
|
request: Request,
|
||||||
client: Optional[str] = None,
|
client: Optional[str] = None,
|
||||||
|
@ -605,6 +613,7 @@ def chat_history(
|
||||||
|
|
||||||
|
|
||||||
@api.get("/chat/options", response_class=Response)
|
@api.get("/chat/options", response_class=Response)
|
||||||
|
@requires(["authenticated"], redirect="login_page")
|
||||||
async def chat_options(
|
async def chat_options(
|
||||||
request: Request,
|
request: Request,
|
||||||
client: Optional[str] = None,
|
client: Optional[str] = None,
|
||||||
|
@ -629,6 +638,7 @@ async def chat_options(
|
||||||
|
|
||||||
|
|
||||||
@api.get("/chat", response_class=Response)
|
@api.get("/chat", response_class=Response)
|
||||||
|
@requires(["authenticated"], redirect="login_page")
|
||||||
async def chat(
|
async def chat(
|
||||||
request: Request,
|
request: Request,
|
||||||
q: str,
|
q: str,
|
||||||
|
|
|
@ -4,9 +4,12 @@ import os
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
from starlette.config import Config
|
from starlette.config import Config
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from starlette.responses import HTMLResponse, RedirectResponse
|
from starlette.responses import HTMLResponse, RedirectResponse, Response
|
||||||
from authlib.integrations.starlette_client import OAuth, OAuthError
|
from authlib.integrations.starlette_client import OAuth, OAuthError
|
||||||
|
|
||||||
|
from google.oauth2 import id_token
|
||||||
|
from google.auth.transport import requests as google_requests
|
||||||
|
|
||||||
from database.adapters import get_or_create_user
|
from database.adapters import get_or_create_user
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -24,32 +27,40 @@ else:
|
||||||
oauth.register(name="google", server_metadata_url=CONF_URL, client_kwargs={"scope": "openid email profile"})
|
oauth.register(name="google", server_metadata_url=CONF_URL, client_kwargs={"scope": "openid email profile"})
|
||||||
|
|
||||||
|
|
||||||
@auth_router.get("/")
|
|
||||||
async def homepage(request: Request):
|
|
||||||
user = request.session.get("user")
|
|
||||||
if user:
|
|
||||||
data = json.dumps(user)
|
|
||||||
html = f"<pre>{data}</pre>" '<a href="/logout">logout</a>'
|
|
||||||
return HTMLResponse(html)
|
|
||||||
return HTMLResponse('<a href="/login">login</a>')
|
|
||||||
|
|
||||||
|
|
||||||
@auth_router.get("/login")
|
@auth_router.get("/login")
|
||||||
|
async def login_get(request: Request):
|
||||||
|
redirect_uri = request.url_for("auth")
|
||||||
|
return await oauth.google.authorize_redirect(request, redirect_uri)
|
||||||
|
|
||||||
|
|
||||||
|
@auth_router.post("/login")
|
||||||
async def login(request: Request):
|
async def login(request: Request):
|
||||||
redirect_uri = request.url_for("auth")
|
redirect_uri = request.url_for("auth")
|
||||||
return await oauth.google.authorize_redirect(request, redirect_uri)
|
return await oauth.google.authorize_redirect(request, redirect_uri)
|
||||||
|
|
||||||
|
|
||||||
@auth_router.get("/redirect")
|
@auth_router.post("/redirect")
|
||||||
async def auth(request: Request):
|
async def auth(request: Request):
|
||||||
|
form = await request.form()
|
||||||
|
credential = form.get("credential")
|
||||||
|
|
||||||
|
csrf_token_cookie = request.cookies.get("g_csrf_token")
|
||||||
|
if not csrf_token_cookie:
|
||||||
|
return Response("Missing CSRF token", status_code=400)
|
||||||
|
csrf_token_body = form.get("g_csrf_token")
|
||||||
|
if not csrf_token_body:
|
||||||
|
return Response("Missing CSRF token", status_code=400)
|
||||||
|
if csrf_token_cookie != csrf_token_body:
|
||||||
|
return Response("Invalid CSRF token", status_code=400)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
token = await oauth.google.authorize_access_token(request)
|
idinfo = id_token.verify_oauth2_token(credential, google_requests.Request(), os.environ["GOOGLE_CLIENT_ID"])
|
||||||
except OAuthError as error:
|
except OAuthError as error:
|
||||||
return HTMLResponse(f"<h1>{error.error}</h1>")
|
return HTMLResponse(f"<h1>{error.error}</h1>")
|
||||||
khoj_user = await get_or_create_user(token)
|
khoj_user = await get_or_create_user(idinfo)
|
||||||
user = token.get("userinfo")
|
if khoj_user:
|
||||||
if user:
|
request.session["user"] = dict(idinfo)
|
||||||
request.session["user"] = dict(user)
|
|
||||||
return RedirectResponse(url="/")
|
return RedirectResponse(url="/")
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -141,6 +141,7 @@ async def update(
|
||||||
)
|
)
|
||||||
logger.info(f"Finished processing batch indexing request")
|
logger.info(f"Finished processing batch indexing request")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to process batch indexing request: {e}", exc_info=True)
|
||||||
logger.error(
|
logger.error(
|
||||||
f"🚨 Failed to {force} update {t} content index triggered via API call by {client} client: {e}",
|
f"🚨 Failed to {force} update {t} content index triggered via API call by {client} client: {e}",
|
||||||
exc_info=True,
|
exc_info=True,
|
||||||
|
@ -156,7 +157,6 @@ async def update(
|
||||||
host=host,
|
host=host,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"📪 Content index updated via API call by {client} client")
|
|
||||||
return Response(content="OK", status_code=200)
|
return Response(content="OK", status_code=200)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,11 @@
|
||||||
|
# System Packages
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from fastapi.responses import HTMLResponse, FileResponse
|
from fastapi.responses import HTMLResponse, FileResponse, RedirectResponse
|
||||||
from fastapi.templating import Jinja2Templates
|
from fastapi.templating import Jinja2Templates
|
||||||
from starlette.authentication import requires
|
from starlette.authentication import requires
|
||||||
from khoj.utils.rawconfig import (
|
from khoj.utils.rawconfig import (
|
||||||
|
@ -16,9 +20,7 @@ from khoj.utils.rawconfig import (
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.utils import constants, state
|
from khoj.utils import constants, state
|
||||||
from database.adapters import EmbeddingsAdapters, get_user_github_config, get_user_notion_config
|
from database.adapters import EmbeddingsAdapters, get_user_github_config, get_user_notion_config
|
||||||
from database.models import KhojUser, LocalOrgConfig, LocalMarkdownConfig, LocalPdfConfig, LocalPlaintextConfig
|
from database.models import LocalOrgConfig, LocalMarkdownConfig, LocalPdfConfig, LocalPlaintextConfig
|
||||||
|
|
||||||
import json
|
|
||||||
|
|
||||||
|
|
||||||
# Initialize Router
|
# Initialize Router
|
||||||
|
@ -30,15 +32,41 @@ VALID_TEXT_CONTENT_TYPES = ["org", "markdown", "pdf", "plaintext"]
|
||||||
|
|
||||||
# Create Routes
|
# Create Routes
|
||||||
@web_client.get("/", response_class=FileResponse)
|
@web_client.get("/", response_class=FileResponse)
|
||||||
|
@requires(["authenticated"], redirect="login_page")
|
||||||
def index(request: Request):
|
def index(request: Request):
|
||||||
return templates.TemplateResponse("index.html", context={"request": request, "demo": state.demo})
|
return templates.TemplateResponse("index.html", context={"request": request, "demo": state.demo})
|
||||||
|
|
||||||
|
|
||||||
|
@web_client.post("/", response_class=FileResponse)
|
||||||
|
@requires(["authenticated"], redirect="login_page")
|
||||||
|
def index_post(request: Request):
|
||||||
|
return templates.TemplateResponse("index.html", context={"request": request, "demo": state.demo})
|
||||||
|
|
||||||
|
|
||||||
@web_client.get("/chat", response_class=FileResponse)
|
@web_client.get("/chat", response_class=FileResponse)
|
||||||
|
@requires(["authenticated"], redirect="login_page")
|
||||||
def chat_page(request: Request):
|
def chat_page(request: Request):
|
||||||
return templates.TemplateResponse("chat.html", context={"request": request, "demo": state.demo})
|
return templates.TemplateResponse("chat.html", context={"request": request, "demo": state.demo})
|
||||||
|
|
||||||
|
|
||||||
|
@web_client.get("/login", response_class=FileResponse)
|
||||||
|
def login_page(request: Request):
|
||||||
|
if request.user.is_authenticated:
|
||||||
|
next_url = request.query_params.get("next", "/")
|
||||||
|
return RedirectResponse(url=next_url)
|
||||||
|
google_client_id = os.environ.get("GOOGLE_CLIENT_ID")
|
||||||
|
redirect_uri = request.url_for("auth")
|
||||||
|
return templates.TemplateResponse(
|
||||||
|
"login.html",
|
||||||
|
context={
|
||||||
|
"request": request,
|
||||||
|
"demo": state.demo,
|
||||||
|
"google_client_id": google_client_id,
|
||||||
|
"redirect_uri": redirect_uri,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def map_config_to_object(content_type: str):
|
def map_config_to_object(content_type: str):
|
||||||
if content_type == "org":
|
if content_type == "org":
|
||||||
return LocalOrgConfig
|
return LocalOrgConfig
|
||||||
|
@ -53,6 +81,7 @@ def map_config_to_object(content_type: str):
|
||||||
if not state.demo:
|
if not state.demo:
|
||||||
|
|
||||||
@web_client.get("/config", response_class=HTMLResponse)
|
@web_client.get("/config", response_class=HTMLResponse)
|
||||||
|
@requires(["authenticated"], redirect="login_page")
|
||||||
def config_page(request: Request):
|
def config_page(request: Request):
|
||||||
user = request.user.object if request.user.is_authenticated else None
|
user = request.user.object if request.user.is_authenticated else None
|
||||||
enabled_content = set(EmbeddingsAdapters.get_unique_file_types(user).all())
|
enabled_content = set(EmbeddingsAdapters.get_unique_file_types(user).all())
|
||||||
|
@ -97,11 +126,12 @@ if not state.demo:
|
||||||
"request": request,
|
"request": request,
|
||||||
"current_config": current_config,
|
"current_config": current_config,
|
||||||
"current_model_state": successfully_configured,
|
"current_model_state": successfully_configured,
|
||||||
|
"anonymous_mode": state.anonymous_mode,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@web_client.get("/config/content_type/github", response_class=HTMLResponse)
|
@web_client.get("/config/content_type/github", response_class=HTMLResponse)
|
||||||
@requires(["authenticated"])
|
@requires(["authenticated"], redirect="login_page")
|
||||||
def github_config_page(request: Request):
|
def github_config_page(request: Request):
|
||||||
user = request.user.object if request.user.is_authenticated else None
|
user = request.user.object if request.user.is_authenticated else None
|
||||||
current_github_config = get_user_github_config(user)
|
current_github_config = get_user_github_config(user)
|
||||||
|
@ -130,6 +160,7 @@ if not state.demo:
|
||||||
)
|
)
|
||||||
|
|
||||||
@web_client.get("/config/content_type/notion", response_class=HTMLResponse)
|
@web_client.get("/config/content_type/notion", response_class=HTMLResponse)
|
||||||
|
@requires(["authenticated"], redirect="login_page")
|
||||||
def notion_config_page(request: Request):
|
def notion_config_page(request: Request):
|
||||||
user = request.user.object if request.user.is_authenticated else None
|
user = request.user.object if request.user.is_authenticated else None
|
||||||
current_notion_config = get_user_notion_config(user)
|
current_notion_config = get_user_notion_config(user)
|
||||||
|
@ -145,6 +176,7 @@ if not state.demo:
|
||||||
)
|
)
|
||||||
|
|
||||||
@web_client.get("/config/content_type/{content_type}", response_class=HTMLResponse)
|
@web_client.get("/config/content_type/{content_type}", response_class=HTMLResponse)
|
||||||
|
@requires(["authenticated"], redirect="login_page")
|
||||||
def content_config_page(request: Request, content_type: str):
|
def content_config_page(request: Request, content_type: str):
|
||||||
if content_type not in VALID_TEXT_CONTENT_TYPES:
|
if content_type not in VALID_TEXT_CONTENT_TYPES:
|
||||||
return templates.TemplateResponse("config.html", context={"request": request})
|
return templates.TemplateResponse("config.html", context={"request": request})
|
||||||
|
|
|
@ -25,7 +25,6 @@ from khoj.utils.rawconfig import (
|
||||||
OfflineChatProcessorConfig,
|
OfflineChatProcessorConfig,
|
||||||
OpenAIProcessorConfig,
|
OpenAIProcessorConfig,
|
||||||
ProcessorConfig,
|
ProcessorConfig,
|
||||||
TextContentConfig,
|
|
||||||
ImageContentConfig,
|
ImageContentConfig,
|
||||||
SearchConfig,
|
SearchConfig,
|
||||||
TextSearchConfig,
|
TextSearchConfig,
|
||||||
|
@ -38,7 +37,6 @@ from database.models import (
|
||||||
LocalOrgConfig,
|
LocalOrgConfig,
|
||||||
LocalMarkdownConfig,
|
LocalMarkdownConfig,
|
||||||
LocalPlaintextConfig,
|
LocalPlaintextConfig,
|
||||||
LocalPdfConfig,
|
|
||||||
GithubConfig,
|
GithubConfig,
|
||||||
KhojUser,
|
KhojUser,
|
||||||
GithubRepoConfig,
|
GithubRepoConfig,
|
||||||
|
@ -95,6 +93,19 @@ def default_user():
|
||||||
return UserFactory()
|
return UserFactory()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.django_db
|
||||||
|
@pytest.fixture
|
||||||
|
def default_user2():
|
||||||
|
if KhojUser.objects.filter(username="default").exists():
|
||||||
|
return KhojUser.objects.get(username="default")
|
||||||
|
|
||||||
|
return UserFactory(
|
||||||
|
username="default",
|
||||||
|
email="default@example.com",
|
||||||
|
password="default",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def search_models(search_config: SearchConfig):
|
def search_models(search_config: SearchConfig):
|
||||||
search_models = SearchModels()
|
search_models = SearchModels()
|
||||||
|
|
|
@ -7,6 +7,7 @@ import pytest
|
||||||
# External Packages
|
# External Packages
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
import pytest
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.configure import configure_routes, configure_search_types
|
from khoj.configure import configure_routes, configure_search_types
|
||||||
|
@ -115,16 +116,11 @@ def test_get_configured_types_via_api(client, sample_org_data):
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
@pytest.mark.django_db(transaction=True)
|
@pytest.mark.django_db(transaction=True)
|
||||||
def test_get_api_config_types(client, search_config: SearchConfig, sample_org_data):
|
def test_get_api_config_types(client, search_config: SearchConfig, sample_org_data, default_user2: KhojUser):
|
||||||
# Arrange
|
# Arrange
|
||||||
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False)
|
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False, user=default_user2)
|
||||||
|
|
||||||
# Act
|
|
||||||
response = client.get(f"/api/config/types")
|
response = client.get(f"/api/config/types")
|
||||||
|
assert response.json() == ["all", "org", "image"]
|
||||||
# Assert
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json() == ["all", "org", "markdown", "image"]
|
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
|
@ -150,6 +146,7 @@ def test_get_configured_types_with_no_content_config(fastapi_app: FastAPI):
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
|
@pytest.mark.django_db(transaction=True)
|
||||||
def test_image_search(client, content_config: ContentConfig, search_config: SearchConfig):
|
def test_image_search(client, content_config: ContentConfig, search_config: SearchConfig):
|
||||||
# Arrange
|
# Arrange
|
||||||
search_models.image_search = image_search.initialize_model(search_config.image)
|
search_models.image_search = image_search.initialize_model(search_config.image)
|
||||||
|
@ -177,9 +174,9 @@ def test_image_search(client, content_config: ContentConfig, search_config: Sear
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
@pytest.mark.django_db(transaction=True)
|
@pytest.mark.django_db(transaction=True)
|
||||||
def test_notes_search(client, search_config: SearchConfig, sample_org_data):
|
def test_notes_search(client, search_config: SearchConfig, sample_org_data, default_user2: KhojUser):
|
||||||
# Arrange
|
# Arrange
|
||||||
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False)
|
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False, user=default_user2)
|
||||||
user_query = quote("How to git install application?")
|
user_query = quote("How to git install application?")
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
|
@ -195,13 +192,14 @@ def test_notes_search(client, search_config: SearchConfig, sample_org_data):
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
@pytest.mark.django_db(transaction=True)
|
@pytest.mark.django_db(transaction=True)
|
||||||
def test_notes_search_with_only_filters(
|
def test_notes_search_with_only_filters(
|
||||||
client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data
|
client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data, default_user2: KhojUser
|
||||||
):
|
):
|
||||||
# Arrange
|
# Arrange
|
||||||
text_search.setup(
|
text_search.setup(
|
||||||
OrgToJsonl,
|
OrgToJsonl,
|
||||||
sample_org_data,
|
sample_org_data,
|
||||||
regenerate=False,
|
regenerate=False,
|
||||||
|
user=default_user2,
|
||||||
)
|
)
|
||||||
user_query = quote('+"Emacs" file:"*.org"')
|
user_query = quote('+"Emacs" file:"*.org"')
|
||||||
|
|
||||||
|
@ -217,9 +215,9 @@ def test_notes_search_with_only_filters(
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
@pytest.mark.django_db(transaction=True)
|
@pytest.mark.django_db(transaction=True)
|
||||||
def test_notes_search_with_include_filter(client, sample_org_data):
|
def test_notes_search_with_include_filter(client, sample_org_data, default_user2: KhojUser):
|
||||||
# Arrange
|
# Arrange
|
||||||
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False)
|
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False, user=default_user2)
|
||||||
user_query = quote('How to git install application? +"Emacs"')
|
user_query = quote('How to git install application? +"Emacs"')
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
|
@ -234,12 +232,13 @@ def test_notes_search_with_include_filter(client, sample_org_data):
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
@pytest.mark.django_db(transaction=True)
|
@pytest.mark.django_db(transaction=True)
|
||||||
def test_notes_search_with_exclude_filter(client, sample_org_data):
|
def test_notes_search_with_exclude_filter(client, sample_org_data, default_user2: KhojUser):
|
||||||
# Arrange
|
# Arrange
|
||||||
text_search.setup(
|
text_search.setup(
|
||||||
OrgToJsonl,
|
OrgToJsonl,
|
||||||
sample_org_data,
|
sample_org_data,
|
||||||
regenerate=False,
|
regenerate=False,
|
||||||
|
user=default_user2,
|
||||||
)
|
)
|
||||||
user_query = quote('How to git install application? -"clone"')
|
user_query = quote('How to git install application? -"clone"')
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue