mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 15:38:55 +01:00
[Multi-User Part 2]: Add login pages and gate access to application behind login wall (#503)
- Make most routes conditional on authentication *if anonymous mode is not enabled*. If anonymous mode is enabled, it scaffolds a default user and uses that for all application interactions. - Add a basic login page and add routes for redirecting the user if logged in
This commit is contained in:
parent
216acf545f
commit
a8a82d274a
13 changed files with 327 additions and 59 deletions
|
@ -68,6 +68,8 @@ dependencies = [
|
|||
"httpx == 0.25.0",
|
||||
"pgvector == 0.2.3",
|
||||
"psycopg2-binary == 2.9.9",
|
||||
"google-auth == 2.23.3",
|
||||
"python-multipart == 0.0.6",
|
||||
]
|
||||
dynamic = ["version"]
|
||||
|
||||
|
|
|
@ -73,8 +73,7 @@ async def create_google_user(token: dict) -> KhojUser:
|
|||
|
||||
|
||||
async def get_user_by_token(token: dict) -> KhojUser:
|
||||
user_info = token.get("userinfo")
|
||||
google_user = await GoogleUser.objects.filter(sub=user_info.get("sub")).select_related("user").afirst()
|
||||
google_user = await GoogleUser.objects.filter(sub=token.get("sub")).select_related("user").afirst()
|
||||
if not google_user:
|
||||
return None
|
||||
return google_user.user
|
||||
|
|
|
@ -67,7 +67,7 @@ class UserAuthenticationBackend(AuthenticationBackend):
|
|||
user = await self.khojuser_manager.filter(email=current_user.get("email")).afirst()
|
||||
if user:
|
||||
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
|
||||
elif not state.anonymous_mode:
|
||||
elif state.anonymous_mode:
|
||||
user = await self.khojuser_manager.filter(username="default").afirst()
|
||||
if user:
|
||||
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
|
||||
|
@ -77,11 +77,6 @@ class UserAuthenticationBackend(AuthenticationBackend):
|
|||
|
||||
def initialize_server(config: Optional[FullConfig]):
|
||||
if config is None:
|
||||
logger.error(
|
||||
f"🚨 Exiting as Khoj is not configured.\nConfigure it via http://{state.host}:{state.port}/config or by editing {state.config_file}."
|
||||
)
|
||||
sys.exit(1)
|
||||
elif config is None:
|
||||
logger.warning(
|
||||
f"🚨 Khoj is not configured.\nConfigure it via http://{state.host}:{state.port}/config, plugins or by editing {state.config_file}."
|
||||
)
|
||||
|
@ -230,6 +225,8 @@ def configure_conversation_processor(
|
|||
conversation_logfile=conversation_logfile,
|
||||
openai=(conversation_config.openai if (conversation_config is not None) else None),
|
||||
offline_chat=conversation_config.offline_chat if conversation_config else OfflineChatProcessorConfig(),
|
||||
max_prompt_size=conversation_config.max_prompt_size if conversation_config else None,
|
||||
tokenizer=conversation_config.tokenizer if conversation_config else None,
|
||||
)
|
||||
)
|
||||
else:
|
||||
|
|
|
@ -211,6 +211,14 @@
|
|||
grid-gap: 24px;
|
||||
}
|
||||
|
||||
button#logout {
|
||||
font-size: 16px;
|
||||
cursor: pointer;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
height: 32px;
|
||||
}
|
||||
|
||||
@media screen and (max-width: 600px) {
|
||||
.section-cards {
|
||||
grid-template-columns: 1fr;
|
||||
|
@ -242,6 +250,10 @@
|
|||
width: 320px;
|
||||
}
|
||||
|
||||
div.khoj-header-wrapper{
|
||||
grid-template-columns: auto;
|
||||
}
|
||||
|
||||
}
|
||||
</style>
|
||||
</html>
|
||||
|
|
|
@ -297,6 +297,11 @@
|
|||
<div class="finalize-buttons">
|
||||
<button id="reinitialize" type="submit" title="Regenerate index from scratch">🔄 Reinitialize</button>
|
||||
</div>
|
||||
{% if anonymous_mode == False %}
|
||||
<div class="finalize-buttons">
|
||||
<button id="logout" class="logout" onclick="window.location.href='/auth/logout'">Logout</button>
|
||||
</div>
|
||||
{% endif %}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
|
197
src/khoj/interface/web/login.html
Normal file
197
src/khoj/interface/web/login.html
Normal file
|
@ -0,0 +1,197 @@
|
|||
<html>
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0 maximum-scale=1.0">
|
||||
<title>Khoj - Search</title>
|
||||
|
||||
<link rel="icon" type="image/png" sizes="128x128" href="/static/assets/icons/favicon-128x128.png">
|
||||
<link rel="manifest" href="/static/khoj.webmanifest">
|
||||
<link rel="stylesheet" href="/static/assets/khoj.css">
|
||||
</head>
|
||||
|
||||
<body>
|
||||
{% if demo %}
|
||||
<!-- Banner linking to https://khoj.dev -->
|
||||
<div class="khoj-banner-container">
|
||||
<a class="khoj-banner" href="https://khoj.dev" target="_blank">
|
||||
<p id="khoj-banner" class="khoj-banner">
|
||||
Enroll in Khoj cloud to get your own assistant
|
||||
</p>
|
||||
</a>
|
||||
<input type="text" id="khoj-banner-email" placeholder="email" class="khoj-banner-email"></input>
|
||||
<button id="khoj-banner-submit" class="khoj-banner-button">Submit</button>
|
||||
</div>
|
||||
{% endif %}
|
||||
<!--Add Header Logo and Nav Pane-->
|
||||
<div class="khoj-header">
|
||||
{% if demo %}
|
||||
<a class="khoj-logo" href="https://khoj.dev" target="_blank">
|
||||
<img class="khoj-logo" src="/static/assets/icons/khoj-logo-sideways-500.png" alt="Khoj"></img>
|
||||
</a>
|
||||
{% else %}
|
||||
<a class="khoj-logo" href="/">
|
||||
<img class="khoj-logo" src="/static/assets/icons/khoj-logo-sideways-500.png" alt="Khoj"></img>
|
||||
</a>
|
||||
{% endif %}
|
||||
</div>
|
||||
|
||||
<!-- Sign Up button for Google OAuth -->
|
||||
<div id="login-modal">
|
||||
<h1 class="login-modal-title">Become superhuman with your personal knowledge base copilot</h1>
|
||||
<div
|
||||
class="g_id_signin"
|
||||
data-shape="circle"
|
||||
data-text="continue_with"
|
||||
data-logo_alignment="center"
|
||||
data-size="large"
|
||||
data-width="300"
|
||||
data-type="standard">
|
||||
</div>
|
||||
</div>
|
||||
<div id="g_id_onload"
|
||||
data-client_id="{{ google_client_id }}"
|
||||
data-ux_mode="redirect"
|
||||
data-login_uri="{{ redirect_uri }}">
|
||||
</div>
|
||||
|
||||
|
||||
</body>
|
||||
|
||||
<style>
|
||||
@media only screen and (max-width: 600px) {
|
||||
body {
|
||||
display: grid;
|
||||
grid-template-columns: 1fr;
|
||||
grid-template-rows: 1fr auto auto auto minmax(80px, 100%);
|
||||
font-size: small!important;
|
||||
}
|
||||
body > * {
|
||||
grid-column: 1;
|
||||
}
|
||||
}
|
||||
@media only screen and (min-width: 600px) {
|
||||
body {
|
||||
display: grid;
|
||||
grid-template-columns: 1fr min(70vw, 100%) 1fr;
|
||||
grid-template-rows: 1fr auto auto auto minmax(80px, 100%);
|
||||
padding-top: 60vw;
|
||||
}
|
||||
body > * {
|
||||
grid-column: 2;
|
||||
}
|
||||
}
|
||||
body {
|
||||
padding: 0px;
|
||||
margin: 0px;
|
||||
background: #fff;
|
||||
color: #475569;
|
||||
font-family: roboto, karma, segoe ui, sans-serif;
|
||||
font-size: 20px;
|
||||
font-weight: 300;
|
||||
line-height: 1.5em;
|
||||
}
|
||||
body > * {
|
||||
padding: 10px;
|
||||
margin: 10px;
|
||||
}
|
||||
|
||||
@keyframes gradient {
|
||||
0% {
|
||||
background-position: 0% 50%;
|
||||
}
|
||||
50% {
|
||||
background-position: 100% 50%;
|
||||
}
|
||||
100% {
|
||||
background-position: 0% 50%;
|
||||
}
|
||||
}
|
||||
|
||||
a.khoj-logo {
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
button#khoj-banner-submit,
|
||||
input#khoj-banner-email {
|
||||
padding: 10px;
|
||||
border-radius: 5px;
|
||||
border: 1px solid #475569;
|
||||
background: #f9fafc;
|
||||
}
|
||||
|
||||
button#khoj-banner-submit:hover,
|
||||
input#khoj-banner-email:hover {
|
||||
box-shadow: 0 0 11px #aaa;
|
||||
}
|
||||
|
||||
div#login-modal {
|
||||
display: grid;
|
||||
grid-template-columns: 1fr;
|
||||
grid-template-rows: 1fr auto auto auto;
|
||||
gap: 10px;
|
||||
padding: 10px;
|
||||
margin: 10px;
|
||||
background: #fff;
|
||||
border-radius: 5px;
|
||||
border: 1px solid #475569;
|
||||
box-shadow: 0 0 11px #aaa;
|
||||
margin-left: 25%;
|
||||
margin-right: 25%;
|
||||
}
|
||||
|
||||
div.g_id_signin {
|
||||
margin: 0 auto;
|
||||
display: block;
|
||||
}
|
||||
|
||||
h1.login-modal-title {
|
||||
text-align: center;
|
||||
line-height: 28px;
|
||||
font-size: x-large;
|
||||
}
|
||||
|
||||
@media only screen and (max-width: 600px) {
|
||||
a.khoj-banner {
|
||||
display: block;
|
||||
}
|
||||
p.khoj-banner {
|
||||
padding: 0;
|
||||
}
|
||||
div#login-modal {
|
||||
margin-left: 10%;
|
||||
margin-right: 10%;
|
||||
}
|
||||
}
|
||||
|
||||
</style>
|
||||
<script>
|
||||
var khojBannerSubmit = document.getElementById("khoj-banner-submit");
|
||||
khojBannerSubmit?.addEventListener("click", function(event) {
|
||||
event.preventDefault();
|
||||
var email = document.getElementById("khoj-banner-email").value;
|
||||
fetch("https://app.khoj.dev/beta/users/", {
|
||||
method: "POST",
|
||||
body: JSON.stringify({
|
||||
email: email
|
||||
}),
|
||||
headers: {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
}).then(function(response) {
|
||||
return response.json();
|
||||
}).then(function(data) {
|
||||
console.log(data);
|
||||
if (data.user != null) {
|
||||
document.getElementById("khoj-banner").innerHTML = "Thanks for signing up. We'll be in touch soon! 🚀";
|
||||
document.getElementById("khoj-banner-submit").remove();
|
||||
} else {
|
||||
document.getElementById("khoj-banner").innerHTML = "There was an error signing up. Please contact team@khoj.dev";
|
||||
}
|
||||
}).catch(function(error) {
|
||||
console.log(error);
|
||||
document.getElementById("khoj-banner").innerHTML = "There was an error signing up. Please contact team@khoj.dev";
|
||||
});
|
||||
});
|
||||
</script>
|
||||
<script src="https://accounts.google.com/gsi/client" async defer></script>
|
||||
</html>
|
|
@ -14,13 +14,13 @@ warnings.filterwarnings("ignore", message=r"legacy way to download files from th
|
|||
|
||||
# External Packages
|
||||
import uvicorn
|
||||
import django
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
import schedule
|
||||
import django
|
||||
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from rich.logging import RichHandler
|
||||
import schedule
|
||||
|
||||
from django.core.asgi import get_asgi_application
|
||||
from django.core.management import call_command
|
||||
|
||||
|
@ -34,13 +34,6 @@ call_command("migrate", "--noinput")
|
|||
# Initialize Django Static Files
|
||||
call_command("collectstatic", "--noinput")
|
||||
|
||||
# Initialize Django
|
||||
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "app.settings")
|
||||
django.setup()
|
||||
|
||||
# Initialize Django Database
|
||||
call_command("migrate", "--noinput")
|
||||
|
||||
# Initialize the Application Server
|
||||
app = FastAPI()
|
||||
|
||||
|
|
|
@ -127,6 +127,7 @@ if not state.demo:
|
|||
state.processor_config = configure_processor(state.config.processor)
|
||||
|
||||
@api.get("/config/data", response_model=FullConfig)
|
||||
@requires(["authenticated"], redirect="login_page")
|
||||
def get_config_data(request: Request):
|
||||
user = request.user.object if request.user.is_authenticated else None
|
||||
enabled_content = EmbeddingsAdapters.get_unique_file_types(user)
|
||||
|
@ -134,6 +135,7 @@ if not state.demo:
|
|||
return state.config
|
||||
|
||||
@api.post("/config/data")
|
||||
@requires(["authenticated"], redirect="login_page")
|
||||
async def set_config_data(
|
||||
request: Request,
|
||||
updated_config: FullConfig,
|
||||
|
@ -166,7 +168,7 @@ if not state.demo:
|
|||
return state.config
|
||||
|
||||
@api.post("/config/data/content_type/github", status_code=200)
|
||||
@requires("authenticated")
|
||||
@requires(["authenticated"], redirect="login_page")
|
||||
async def set_content_config_github_data(
|
||||
request: Request,
|
||||
updated_config: Union[GithubContentConfig, None],
|
||||
|
@ -193,6 +195,7 @@ if not state.demo:
|
|||
return {"status": "ok"}
|
||||
|
||||
@api.post("/config/data/content_type/notion", status_code=200)
|
||||
@requires(["authenticated"], redirect="login_page")
|
||||
async def set_content_config_notion_data(
|
||||
request: Request,
|
||||
updated_config: Union[NotionContentConfig, None],
|
||||
|
@ -218,6 +221,7 @@ if not state.demo:
|
|||
return {"status": "ok"}
|
||||
|
||||
@api.post("/delete/config/data/content_type/{content_type}", status_code=200)
|
||||
@requires(["authenticated"], redirect="login_page")
|
||||
async def remove_content_config_data(
|
||||
request: Request,
|
||||
content_type: str,
|
||||
|
@ -272,7 +276,7 @@ if not state.demo:
|
|||
return {"status": "error", "message": str(e)}
|
||||
|
||||
@api.post("/config/data/content_type/{content_type}", status_code=200)
|
||||
# @requires("authenticated")
|
||||
@requires(["authenticated"], redirect="login_page")
|
||||
async def set_content_config_data(
|
||||
request: Request,
|
||||
content_type: str,
|
||||
|
@ -378,6 +382,7 @@ def get_default_config_data():
|
|||
|
||||
|
||||
@api.get("/config/types", response_model=List[str])
|
||||
@requires(["authenticated"], redirect="login_page")
|
||||
def get_config_types(
|
||||
request: Request,
|
||||
):
|
||||
|
@ -399,6 +404,7 @@ def get_config_types(
|
|||
|
||||
|
||||
@api.get("/search", response_model=List[SearchResponse])
|
||||
@requires(["authenticated"], redirect="login_page")
|
||||
async def search(
|
||||
q: str,
|
||||
request: Request,
|
||||
|
@ -532,6 +538,7 @@ async def search(
|
|||
|
||||
|
||||
@api.get("/update")
|
||||
@requires(["authenticated"], redirect="login_page")
|
||||
def update(
|
||||
request: Request,
|
||||
t: Optional[SearchType] = None,
|
||||
|
@ -577,6 +584,7 @@ def update(
|
|||
|
||||
|
||||
@api.get("/chat/history")
|
||||
@requires(["authenticated"], redirect="login_page")
|
||||
def chat_history(
|
||||
request: Request,
|
||||
client: Optional[str] = None,
|
||||
|
@ -605,6 +613,7 @@ def chat_history(
|
|||
|
||||
|
||||
@api.get("/chat/options", response_class=Response)
|
||||
@requires(["authenticated"], redirect="login_page")
|
||||
async def chat_options(
|
||||
request: Request,
|
||||
client: Optional[str] = None,
|
||||
|
@ -629,6 +638,7 @@ async def chat_options(
|
|||
|
||||
|
||||
@api.get("/chat", response_class=Response)
|
||||
@requires(["authenticated"], redirect="login_page")
|
||||
async def chat(
|
||||
request: Request,
|
||||
q: str,
|
||||
|
|
|
@ -4,9 +4,12 @@ import os
|
|||
from fastapi import APIRouter
|
||||
from starlette.config import Config
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import HTMLResponse, RedirectResponse
|
||||
from starlette.responses import HTMLResponse, RedirectResponse, Response
|
||||
from authlib.integrations.starlette_client import OAuth, OAuthError
|
||||
|
||||
from google.oauth2 import id_token
|
||||
from google.auth.transport import requests as google_requests
|
||||
|
||||
from database.adapters import get_or_create_user
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -24,32 +27,40 @@ else:
|
|||
oauth.register(name="google", server_metadata_url=CONF_URL, client_kwargs={"scope": "openid email profile"})
|
||||
|
||||
|
||||
@auth_router.get("/")
|
||||
async def homepage(request: Request):
|
||||
user = request.session.get("user")
|
||||
if user:
|
||||
data = json.dumps(user)
|
||||
html = f"<pre>{data}</pre>" '<a href="/logout">logout</a>'
|
||||
return HTMLResponse(html)
|
||||
return HTMLResponse('<a href="/login">login</a>')
|
||||
|
||||
|
||||
@auth_router.get("/login")
|
||||
async def login_get(request: Request):
|
||||
redirect_uri = request.url_for("auth")
|
||||
return await oauth.google.authorize_redirect(request, redirect_uri)
|
||||
|
||||
|
||||
@auth_router.post("/login")
|
||||
async def login(request: Request):
|
||||
redirect_uri = request.url_for("auth")
|
||||
return await oauth.google.authorize_redirect(request, redirect_uri)
|
||||
|
||||
|
||||
@auth_router.get("/redirect")
|
||||
@auth_router.post("/redirect")
|
||||
async def auth(request: Request):
|
||||
form = await request.form()
|
||||
credential = form.get("credential")
|
||||
|
||||
csrf_token_cookie = request.cookies.get("g_csrf_token")
|
||||
if not csrf_token_cookie:
|
||||
return Response("Missing CSRF token", status_code=400)
|
||||
csrf_token_body = form.get("g_csrf_token")
|
||||
if not csrf_token_body:
|
||||
return Response("Missing CSRF token", status_code=400)
|
||||
if csrf_token_cookie != csrf_token_body:
|
||||
return Response("Invalid CSRF token", status_code=400)
|
||||
|
||||
try:
|
||||
token = await oauth.google.authorize_access_token(request)
|
||||
idinfo = id_token.verify_oauth2_token(credential, google_requests.Request(), os.environ["GOOGLE_CLIENT_ID"])
|
||||
except OAuthError as error:
|
||||
return HTMLResponse(f"<h1>{error.error}</h1>")
|
||||
khoj_user = await get_or_create_user(token)
|
||||
user = token.get("userinfo")
|
||||
if user:
|
||||
request.session["user"] = dict(user)
|
||||
khoj_user = await get_or_create_user(idinfo)
|
||||
if khoj_user:
|
||||
request.session["user"] = dict(idinfo)
|
||||
|
||||
return RedirectResponse(url="/")
|
||||
|
||||
|
||||
|
|
|
@ -141,6 +141,7 @@ async def update(
|
|||
)
|
||||
logger.info(f"Finished processing batch indexing request")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process batch indexing request: {e}", exc_info=True)
|
||||
logger.error(
|
||||
f"🚨 Failed to {force} update {t} content index triggered via API call by {client} client: {e}",
|
||||
exc_info=True,
|
||||
|
@ -156,7 +157,6 @@ async def update(
|
|||
host=host,
|
||||
)
|
||||
|
||||
logger.info(f"📪 Content index updated via API call by {client} client")
|
||||
return Response(content="OK", status_code=200)
|
||||
|
||||
|
||||
|
|
|
@ -1,7 +1,11 @@
|
|||
# System Packages
|
||||
import json
|
||||
import os
|
||||
|
||||
# External Packages
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Request
|
||||
from fastapi.responses import HTMLResponse, FileResponse
|
||||
from fastapi.responses import HTMLResponse, FileResponse, RedirectResponse
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from starlette.authentication import requires
|
||||
from khoj.utils.rawconfig import (
|
||||
|
@ -16,9 +20,7 @@ from khoj.utils.rawconfig import (
|
|||
# Internal Packages
|
||||
from khoj.utils import constants, state
|
||||
from database.adapters import EmbeddingsAdapters, get_user_github_config, get_user_notion_config
|
||||
from database.models import KhojUser, LocalOrgConfig, LocalMarkdownConfig, LocalPdfConfig, LocalPlaintextConfig
|
||||
|
||||
import json
|
||||
from database.models import LocalOrgConfig, LocalMarkdownConfig, LocalPdfConfig, LocalPlaintextConfig
|
||||
|
||||
|
||||
# Initialize Router
|
||||
|
@ -30,15 +32,41 @@ VALID_TEXT_CONTENT_TYPES = ["org", "markdown", "pdf", "plaintext"]
|
|||
|
||||
# Create Routes
|
||||
@web_client.get("/", response_class=FileResponse)
|
||||
@requires(["authenticated"], redirect="login_page")
|
||||
def index(request: Request):
|
||||
return templates.TemplateResponse("index.html", context={"request": request, "demo": state.demo})
|
||||
|
||||
|
||||
@web_client.post("/", response_class=FileResponse)
|
||||
@requires(["authenticated"], redirect="login_page")
|
||||
def index_post(request: Request):
|
||||
return templates.TemplateResponse("index.html", context={"request": request, "demo": state.demo})
|
||||
|
||||
|
||||
@web_client.get("/chat", response_class=FileResponse)
|
||||
@requires(["authenticated"], redirect="login_page")
|
||||
def chat_page(request: Request):
|
||||
return templates.TemplateResponse("chat.html", context={"request": request, "demo": state.demo})
|
||||
|
||||
|
||||
@web_client.get("/login", response_class=FileResponse)
|
||||
def login_page(request: Request):
|
||||
if request.user.is_authenticated:
|
||||
next_url = request.query_params.get("next", "/")
|
||||
return RedirectResponse(url=next_url)
|
||||
google_client_id = os.environ.get("GOOGLE_CLIENT_ID")
|
||||
redirect_uri = request.url_for("auth")
|
||||
return templates.TemplateResponse(
|
||||
"login.html",
|
||||
context={
|
||||
"request": request,
|
||||
"demo": state.demo,
|
||||
"google_client_id": google_client_id,
|
||||
"redirect_uri": redirect_uri,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def map_config_to_object(content_type: str):
|
||||
if content_type == "org":
|
||||
return LocalOrgConfig
|
||||
|
@ -53,6 +81,7 @@ def map_config_to_object(content_type: str):
|
|||
if not state.demo:
|
||||
|
||||
@web_client.get("/config", response_class=HTMLResponse)
|
||||
@requires(["authenticated"], redirect="login_page")
|
||||
def config_page(request: Request):
|
||||
user = request.user.object if request.user.is_authenticated else None
|
||||
enabled_content = set(EmbeddingsAdapters.get_unique_file_types(user).all())
|
||||
|
@ -97,11 +126,12 @@ if not state.demo:
|
|||
"request": request,
|
||||
"current_config": current_config,
|
||||
"current_model_state": successfully_configured,
|
||||
"anonymous_mode": state.anonymous_mode,
|
||||
},
|
||||
)
|
||||
|
||||
@web_client.get("/config/content_type/github", response_class=HTMLResponse)
|
||||
@requires(["authenticated"])
|
||||
@requires(["authenticated"], redirect="login_page")
|
||||
def github_config_page(request: Request):
|
||||
user = request.user.object if request.user.is_authenticated else None
|
||||
current_github_config = get_user_github_config(user)
|
||||
|
@ -130,6 +160,7 @@ if not state.demo:
|
|||
)
|
||||
|
||||
@web_client.get("/config/content_type/notion", response_class=HTMLResponse)
|
||||
@requires(["authenticated"], redirect="login_page")
|
||||
def notion_config_page(request: Request):
|
||||
user = request.user.object if request.user.is_authenticated else None
|
||||
current_notion_config = get_user_notion_config(user)
|
||||
|
@ -145,6 +176,7 @@ if not state.demo:
|
|||
)
|
||||
|
||||
@web_client.get("/config/content_type/{content_type}", response_class=HTMLResponse)
|
||||
@requires(["authenticated"], redirect="login_page")
|
||||
def content_config_page(request: Request, content_type: str):
|
||||
if content_type not in VALID_TEXT_CONTENT_TYPES:
|
||||
return templates.TemplateResponse("config.html", context={"request": request})
|
||||
|
|
|
@ -25,7 +25,6 @@ from khoj.utils.rawconfig import (
|
|||
OfflineChatProcessorConfig,
|
||||
OpenAIProcessorConfig,
|
||||
ProcessorConfig,
|
||||
TextContentConfig,
|
||||
ImageContentConfig,
|
||||
SearchConfig,
|
||||
TextSearchConfig,
|
||||
|
@ -38,7 +37,6 @@ from database.models import (
|
|||
LocalOrgConfig,
|
||||
LocalMarkdownConfig,
|
||||
LocalPlaintextConfig,
|
||||
LocalPdfConfig,
|
||||
GithubConfig,
|
||||
KhojUser,
|
||||
GithubRepoConfig,
|
||||
|
@ -95,6 +93,19 @@ def default_user():
|
|||
return UserFactory()
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.fixture
|
||||
def default_user2():
|
||||
if KhojUser.objects.filter(username="default").exists():
|
||||
return KhojUser.objects.get(username="default")
|
||||
|
||||
return UserFactory(
|
||||
username="default",
|
||||
email="default@example.com",
|
||||
password="default",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def search_models(search_config: SearchConfig):
|
||||
search_models = SearchModels()
|
||||
|
|
|
@ -7,6 +7,7 @@ import pytest
|
|||
# External Packages
|
||||
from fastapi.testclient import TestClient
|
||||
from fastapi import FastAPI
|
||||
import pytest
|
||||
|
||||
# Internal Packages
|
||||
from khoj.configure import configure_routes, configure_search_types
|
||||
|
@ -115,16 +116,11 @@ def test_get_configured_types_via_api(client, sample_org_data):
|
|||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_get_api_config_types(client, search_config: SearchConfig, sample_org_data):
|
||||
def test_get_api_config_types(client, search_config: SearchConfig, sample_org_data, default_user2: KhojUser):
|
||||
# Arrange
|
||||
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False)
|
||||
|
||||
# Act
|
||||
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False, user=default_user2)
|
||||
response = client.get(f"/api/config/types")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
assert response.json() == ["all", "org", "markdown", "image"]
|
||||
assert response.json() == ["all", "org", "image"]
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
|
@ -150,6 +146,7 @@ def test_get_configured_types_with_no_content_config(fastapi_app: FastAPI):
|
|||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_image_search(client, content_config: ContentConfig, search_config: SearchConfig):
|
||||
# Arrange
|
||||
search_models.image_search = image_search.initialize_model(search_config.image)
|
||||
|
@ -177,9 +174,9 @@ def test_image_search(client, content_config: ContentConfig, search_config: Sear
|
|||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_notes_search(client, search_config: SearchConfig, sample_org_data):
|
||||
def test_notes_search(client, search_config: SearchConfig, sample_org_data, default_user2: KhojUser):
|
||||
# Arrange
|
||||
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False)
|
||||
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False, user=default_user2)
|
||||
user_query = quote("How to git install application?")
|
||||
|
||||
# Act
|
||||
|
@ -195,13 +192,14 @@ def test_notes_search(client, search_config: SearchConfig, sample_org_data):
|
|||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_notes_search_with_only_filters(
|
||||
client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data
|
||||
client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data, default_user2: KhojUser
|
||||
):
|
||||
# Arrange
|
||||
text_search.setup(
|
||||
OrgToJsonl,
|
||||
sample_org_data,
|
||||
regenerate=False,
|
||||
user=default_user2,
|
||||
)
|
||||
user_query = quote('+"Emacs" file:"*.org"')
|
||||
|
||||
|
@ -217,9 +215,9 @@ def test_notes_search_with_only_filters(
|
|||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_notes_search_with_include_filter(client, sample_org_data):
|
||||
def test_notes_search_with_include_filter(client, sample_org_data, default_user2: KhojUser):
|
||||
# Arrange
|
||||
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False)
|
||||
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False, user=default_user2)
|
||||
user_query = quote('How to git install application? +"Emacs"')
|
||||
|
||||
# Act
|
||||
|
@ -234,12 +232,13 @@ def test_notes_search_with_include_filter(client, sample_org_data):
|
|||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_notes_search_with_exclude_filter(client, sample_org_data):
|
||||
def test_notes_search_with_exclude_filter(client, sample_org_data, default_user2: KhojUser):
|
||||
# Arrange
|
||||
text_search.setup(
|
||||
OrgToJsonl,
|
||||
sample_org_data,
|
||||
regenerate=False,
|
||||
user=default_user2,
|
||||
)
|
||||
user_query = quote('How to git install application? -"clone"')
|
||||
|
||||
|
|
Loading…
Reference in a new issue