[Multi-User Part 2]: Add login pages and gate access to application behind login wall (#503)

- Make most routes conditional on authentication *if anonymous mode is not enabled*. If anonymous mode is enabled, it scaffolds a default user and uses that for all application interactions.
- Add a basic login page and add routes for redirecting the user if logged in
This commit is contained in:
sabaimran 2023-10-26 10:17:29 -07:00 committed by GitHub
parent 216acf545f
commit a8a82d274a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 327 additions and 59 deletions

View file

@ -68,6 +68,8 @@ dependencies = [
"httpx == 0.25.0", "httpx == 0.25.0",
"pgvector == 0.2.3", "pgvector == 0.2.3",
"psycopg2-binary == 2.9.9", "psycopg2-binary == 2.9.9",
"google-auth == 2.23.3",
"python-multipart == 0.0.6",
] ]
dynamic = ["version"] dynamic = ["version"]

View file

@ -73,8 +73,7 @@ async def create_google_user(token: dict) -> KhojUser:
async def get_user_by_token(token: dict) -> KhojUser: async def get_user_by_token(token: dict) -> KhojUser:
user_info = token.get("userinfo") google_user = await GoogleUser.objects.filter(sub=token.get("sub")).select_related("user").afirst()
google_user = await GoogleUser.objects.filter(sub=user_info.get("sub")).select_related("user").afirst()
if not google_user: if not google_user:
return None return None
return google_user.user return google_user.user

View file

@ -67,7 +67,7 @@ class UserAuthenticationBackend(AuthenticationBackend):
user = await self.khojuser_manager.filter(email=current_user.get("email")).afirst() user = await self.khojuser_manager.filter(email=current_user.get("email")).afirst()
if user: if user:
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user) return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
elif not state.anonymous_mode: elif state.anonymous_mode:
user = await self.khojuser_manager.filter(username="default").afirst() user = await self.khojuser_manager.filter(username="default").afirst()
if user: if user:
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user) return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
@ -77,11 +77,6 @@ class UserAuthenticationBackend(AuthenticationBackend):
def initialize_server(config: Optional[FullConfig]): def initialize_server(config: Optional[FullConfig]):
if config is None: if config is None:
logger.error(
f"🚨 Exiting as Khoj is not configured.\nConfigure it via http://{state.host}:{state.port}/config or by editing {state.config_file}."
)
sys.exit(1)
elif config is None:
logger.warning( logger.warning(
f"🚨 Khoj is not configured.\nConfigure it via http://{state.host}:{state.port}/config, plugins or by editing {state.config_file}." f"🚨 Khoj is not configured.\nConfigure it via http://{state.host}:{state.port}/config, plugins or by editing {state.config_file}."
) )
@ -230,6 +225,8 @@ def configure_conversation_processor(
conversation_logfile=conversation_logfile, conversation_logfile=conversation_logfile,
openai=(conversation_config.openai if (conversation_config is not None) else None), openai=(conversation_config.openai if (conversation_config is not None) else None),
offline_chat=conversation_config.offline_chat if conversation_config else OfflineChatProcessorConfig(), offline_chat=conversation_config.offline_chat if conversation_config else OfflineChatProcessorConfig(),
max_prompt_size=conversation_config.max_prompt_size if conversation_config else None,
tokenizer=conversation_config.tokenizer if conversation_config else None,
) )
) )
else: else:

View file

@ -211,6 +211,14 @@
grid-gap: 24px; grid-gap: 24px;
} }
button#logout {
font-size: 16px;
cursor: pointer;
margin: 0;
padding: 0;
height: 32px;
}
@media screen and (max-width: 600px) { @media screen and (max-width: 600px) {
.section-cards { .section-cards {
grid-template-columns: 1fr; grid-template-columns: 1fr;
@ -242,6 +250,10 @@
width: 320px; width: 320px;
} }
div.khoj-header-wrapper{
grid-template-columns: auto;
}
} }
</style> </style>
</html> </html>

View file

@ -297,6 +297,11 @@
<div class="finalize-buttons"> <div class="finalize-buttons">
<button id="reinitialize" type="submit" title="Regenerate index from scratch">🔄 Reinitialize</button> <button id="reinitialize" type="submit" title="Regenerate index from scratch">🔄 Reinitialize</button>
</div> </div>
{% if anonymous_mode == False %}
<div class="finalize-buttons">
<button id="logout" class="logout" onclick="window.location.href='/auth/logout'">Logout</button>
</div>
{% endif %}
</div> </div>
</div> </div>
</div> </div>

View file

@ -0,0 +1,197 @@
<html>
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0 maximum-scale=1.0">
<title>Khoj - Search</title>
<link rel="icon" type="image/png" sizes="128x128" href="/static/assets/icons/favicon-128x128.png">
<link rel="manifest" href="/static/khoj.webmanifest">
<link rel="stylesheet" href="/static/assets/khoj.css">
</head>
<body>
{% if demo %}
<!-- Banner linking to https://khoj.dev -->
<div class="khoj-banner-container">
<a class="khoj-banner" href="https://khoj.dev" target="_blank">
<p id="khoj-banner" class="khoj-banner">
Enroll in Khoj cloud to get your own assistant
</p>
</a>
<input type="text" id="khoj-banner-email" placeholder="email" class="khoj-banner-email"></input>
<button id="khoj-banner-submit" class="khoj-banner-button">Submit</button>
</div>
{% endif %}
<!--Add Header Logo and Nav Pane-->
<div class="khoj-header">
{% if demo %}
<a class="khoj-logo" href="https://khoj.dev" target="_blank">
<img class="khoj-logo" src="/static/assets/icons/khoj-logo-sideways-500.png" alt="Khoj"></img>
</a>
{% else %}
<a class="khoj-logo" href="/">
<img class="khoj-logo" src="/static/assets/icons/khoj-logo-sideways-500.png" alt="Khoj"></img>
</a>
{% endif %}
</div>
<!-- Sign Up button for Google OAuth -->
<div id="login-modal">
<h1 class="login-modal-title">Become superhuman with your personal knowledge base copilot</h1>
<div
class="g_id_signin"
data-shape="circle"
data-text="continue_with"
data-logo_alignment="center"
data-size="large"
data-width="300"
data-type="standard">
</div>
</div>
<div id="g_id_onload"
data-client_id="{{ google_client_id }}"
data-ux_mode="redirect"
data-login_uri="{{ redirect_uri }}">
</div>
</body>
<style>
@media only screen and (max-width: 600px) {
body {
display: grid;
grid-template-columns: 1fr;
grid-template-rows: 1fr auto auto auto minmax(80px, 100%);
font-size: small!important;
}
body > * {
grid-column: 1;
}
}
@media only screen and (min-width: 600px) {
body {
display: grid;
grid-template-columns: 1fr min(70vw, 100%) 1fr;
grid-template-rows: 1fr auto auto auto minmax(80px, 100%);
padding-top: 60vw;
}
body > * {
grid-column: 2;
}
}
body {
padding: 0px;
margin: 0px;
background: #fff;
color: #475569;
font-family: roboto, karma, segoe ui, sans-serif;
font-size: 20px;
font-weight: 300;
line-height: 1.5em;
}
body > * {
padding: 10px;
margin: 10px;
}
@keyframes gradient {
0% {
background-position: 0% 50%;
}
50% {
background-position: 100% 50%;
}
100% {
background-position: 0% 50%;
}
}
a.khoj-logo {
text-align: center;
}
button#khoj-banner-submit,
input#khoj-banner-email {
padding: 10px;
border-radius: 5px;
border: 1px solid #475569;
background: #f9fafc;
}
button#khoj-banner-submit:hover,
input#khoj-banner-email:hover {
box-shadow: 0 0 11px #aaa;
}
div#login-modal {
display: grid;
grid-template-columns: 1fr;
grid-template-rows: 1fr auto auto auto;
gap: 10px;
padding: 10px;
margin: 10px;
background: #fff;
border-radius: 5px;
border: 1px solid #475569;
box-shadow: 0 0 11px #aaa;
margin-left: 25%;
margin-right: 25%;
}
div.g_id_signin {
margin: 0 auto;
display: block;
}
h1.login-modal-title {
text-align: center;
line-height: 28px;
font-size: x-large;
}
@media only screen and (max-width: 600px) {
a.khoj-banner {
display: block;
}
p.khoj-banner {
padding: 0;
}
div#login-modal {
margin-left: 10%;
margin-right: 10%;
}
}
</style>
<script>
var khojBannerSubmit = document.getElementById("khoj-banner-submit");
khojBannerSubmit?.addEventListener("click", function(event) {
event.preventDefault();
var email = document.getElementById("khoj-banner-email").value;
fetch("https://app.khoj.dev/beta/users/", {
method: "POST",
body: JSON.stringify({
email: email
}),
headers: {
"Content-Type": "application/json"
}
}).then(function(response) {
return response.json();
}).then(function(data) {
console.log(data);
if (data.user != null) {
document.getElementById("khoj-banner").innerHTML = "Thanks for signing up. We'll be in touch soon! 🚀";
document.getElementById("khoj-banner-submit").remove();
} else {
document.getElementById("khoj-banner").innerHTML = "There was an error signing up. Please contact team@khoj.dev";
}
}).catch(function(error) {
console.log(error);
document.getElementById("khoj-banner").innerHTML = "There was an error signing up. Please contact team@khoj.dev";
});
});
</script>
<script src="https://accounts.google.com/gsi/client" async defer></script>
</html>

View file

@ -14,13 +14,13 @@ warnings.filterwarnings("ignore", message=r"legacy way to download files from th
# External Packages # External Packages
import uvicorn import uvicorn
import django
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
import schedule
import django
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from rich.logging import RichHandler from rich.logging import RichHandler
import schedule
from django.core.asgi import get_asgi_application from django.core.asgi import get_asgi_application
from django.core.management import call_command from django.core.management import call_command
@ -34,13 +34,6 @@ call_command("migrate", "--noinput")
# Initialize Django Static Files # Initialize Django Static Files
call_command("collectstatic", "--noinput") call_command("collectstatic", "--noinput")
# Initialize Django
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "app.settings")
django.setup()
# Initialize Django Database
call_command("migrate", "--noinput")
# Initialize the Application Server # Initialize the Application Server
app = FastAPI() app = FastAPI()

View file

@ -127,6 +127,7 @@ if not state.demo:
state.processor_config = configure_processor(state.config.processor) state.processor_config = configure_processor(state.config.processor)
@api.get("/config/data", response_model=FullConfig) @api.get("/config/data", response_model=FullConfig)
@requires(["authenticated"], redirect="login_page")
def get_config_data(request: Request): def get_config_data(request: Request):
user = request.user.object if request.user.is_authenticated else None user = request.user.object if request.user.is_authenticated else None
enabled_content = EmbeddingsAdapters.get_unique_file_types(user) enabled_content = EmbeddingsAdapters.get_unique_file_types(user)
@ -134,6 +135,7 @@ if not state.demo:
return state.config return state.config
@api.post("/config/data") @api.post("/config/data")
@requires(["authenticated"], redirect="login_page")
async def set_config_data( async def set_config_data(
request: Request, request: Request,
updated_config: FullConfig, updated_config: FullConfig,
@ -166,7 +168,7 @@ if not state.demo:
return state.config return state.config
@api.post("/config/data/content_type/github", status_code=200) @api.post("/config/data/content_type/github", status_code=200)
@requires("authenticated") @requires(["authenticated"], redirect="login_page")
async def set_content_config_github_data( async def set_content_config_github_data(
request: Request, request: Request,
updated_config: Union[GithubContentConfig, None], updated_config: Union[GithubContentConfig, None],
@ -193,6 +195,7 @@ if not state.demo:
return {"status": "ok"} return {"status": "ok"}
@api.post("/config/data/content_type/notion", status_code=200) @api.post("/config/data/content_type/notion", status_code=200)
@requires(["authenticated"], redirect="login_page")
async def set_content_config_notion_data( async def set_content_config_notion_data(
request: Request, request: Request,
updated_config: Union[NotionContentConfig, None], updated_config: Union[NotionContentConfig, None],
@ -218,6 +221,7 @@ if not state.demo:
return {"status": "ok"} return {"status": "ok"}
@api.post("/delete/config/data/content_type/{content_type}", status_code=200) @api.post("/delete/config/data/content_type/{content_type}", status_code=200)
@requires(["authenticated"], redirect="login_page")
async def remove_content_config_data( async def remove_content_config_data(
request: Request, request: Request,
content_type: str, content_type: str,
@ -272,7 +276,7 @@ if not state.demo:
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}
@api.post("/config/data/content_type/{content_type}", status_code=200) @api.post("/config/data/content_type/{content_type}", status_code=200)
# @requires("authenticated") @requires(["authenticated"], redirect="login_page")
async def set_content_config_data( async def set_content_config_data(
request: Request, request: Request,
content_type: str, content_type: str,
@ -378,6 +382,7 @@ def get_default_config_data():
@api.get("/config/types", response_model=List[str]) @api.get("/config/types", response_model=List[str])
@requires(["authenticated"], redirect="login_page")
def get_config_types( def get_config_types(
request: Request, request: Request,
): ):
@ -399,6 +404,7 @@ def get_config_types(
@api.get("/search", response_model=List[SearchResponse]) @api.get("/search", response_model=List[SearchResponse])
@requires(["authenticated"], redirect="login_page")
async def search( async def search(
q: str, q: str,
request: Request, request: Request,
@ -532,6 +538,7 @@ async def search(
@api.get("/update") @api.get("/update")
@requires(["authenticated"], redirect="login_page")
def update( def update(
request: Request, request: Request,
t: Optional[SearchType] = None, t: Optional[SearchType] = None,
@ -577,6 +584,7 @@ def update(
@api.get("/chat/history") @api.get("/chat/history")
@requires(["authenticated"], redirect="login_page")
def chat_history( def chat_history(
request: Request, request: Request,
client: Optional[str] = None, client: Optional[str] = None,
@ -605,6 +613,7 @@ def chat_history(
@api.get("/chat/options", response_class=Response) @api.get("/chat/options", response_class=Response)
@requires(["authenticated"], redirect="login_page")
async def chat_options( async def chat_options(
request: Request, request: Request,
client: Optional[str] = None, client: Optional[str] = None,
@ -629,6 +638,7 @@ async def chat_options(
@api.get("/chat", response_class=Response) @api.get("/chat", response_class=Response)
@requires(["authenticated"], redirect="login_page")
async def chat( async def chat(
request: Request, request: Request,
q: str, q: str,

View file

@ -4,9 +4,12 @@ import os
from fastapi import APIRouter from fastapi import APIRouter
from starlette.config import Config from starlette.config import Config
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import HTMLResponse, RedirectResponse from starlette.responses import HTMLResponse, RedirectResponse, Response
from authlib.integrations.starlette_client import OAuth, OAuthError from authlib.integrations.starlette_client import OAuth, OAuthError
from google.oauth2 import id_token
from google.auth.transport import requests as google_requests
from database.adapters import get_or_create_user from database.adapters import get_or_create_user
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -24,32 +27,40 @@ else:
oauth.register(name="google", server_metadata_url=CONF_URL, client_kwargs={"scope": "openid email profile"}) oauth.register(name="google", server_metadata_url=CONF_URL, client_kwargs={"scope": "openid email profile"})
@auth_router.get("/")
async def homepage(request: Request):
user = request.session.get("user")
if user:
data = json.dumps(user)
html = f"<pre>{data}</pre>" '<a href="/logout">logout</a>'
return HTMLResponse(html)
return HTMLResponse('<a href="/login">login</a>')
@auth_router.get("/login") @auth_router.get("/login")
async def login_get(request: Request):
redirect_uri = request.url_for("auth")
return await oauth.google.authorize_redirect(request, redirect_uri)
@auth_router.post("/login")
async def login(request: Request): async def login(request: Request):
redirect_uri = request.url_for("auth") redirect_uri = request.url_for("auth")
return await oauth.google.authorize_redirect(request, redirect_uri) return await oauth.google.authorize_redirect(request, redirect_uri)
@auth_router.get("/redirect") @auth_router.post("/redirect")
async def auth(request: Request): async def auth(request: Request):
form = await request.form()
credential = form.get("credential")
csrf_token_cookie = request.cookies.get("g_csrf_token")
if not csrf_token_cookie:
return Response("Missing CSRF token", status_code=400)
csrf_token_body = form.get("g_csrf_token")
if not csrf_token_body:
return Response("Missing CSRF token", status_code=400)
if csrf_token_cookie != csrf_token_body:
return Response("Invalid CSRF token", status_code=400)
try: try:
token = await oauth.google.authorize_access_token(request) idinfo = id_token.verify_oauth2_token(credential, google_requests.Request(), os.environ["GOOGLE_CLIENT_ID"])
except OAuthError as error: except OAuthError as error:
return HTMLResponse(f"<h1>{error.error}</h1>") return HTMLResponse(f"<h1>{error.error}</h1>")
khoj_user = await get_or_create_user(token) khoj_user = await get_or_create_user(idinfo)
user = token.get("userinfo") if khoj_user:
if user: request.session["user"] = dict(idinfo)
request.session["user"] = dict(user)
return RedirectResponse(url="/") return RedirectResponse(url="/")

View file

@ -141,6 +141,7 @@ async def update(
) )
logger.info(f"Finished processing batch indexing request") logger.info(f"Finished processing batch indexing request")
except Exception as e: except Exception as e:
logger.error(f"Failed to process batch indexing request: {e}", exc_info=True)
logger.error( logger.error(
f"🚨 Failed to {force} update {t} content index triggered via API call by {client} client: {e}", f"🚨 Failed to {force} update {t} content index triggered via API call by {client} client: {e}",
exc_info=True, exc_info=True,
@ -156,7 +157,6 @@ async def update(
host=host, host=host,
) )
logger.info(f"📪 Content index updated via API call by {client} client")
return Response(content="OK", status_code=200) return Response(content="OK", status_code=200)

View file

@ -1,7 +1,11 @@
# System Packages
import json
import os
# External Packages # External Packages
from fastapi import APIRouter from fastapi import APIRouter
from fastapi import Request from fastapi import Request
from fastapi.responses import HTMLResponse, FileResponse from fastapi.responses import HTMLResponse, FileResponse, RedirectResponse
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
from starlette.authentication import requires from starlette.authentication import requires
from khoj.utils.rawconfig import ( from khoj.utils.rawconfig import (
@ -16,9 +20,7 @@ from khoj.utils.rawconfig import (
# Internal Packages # Internal Packages
from khoj.utils import constants, state from khoj.utils import constants, state
from database.adapters import EmbeddingsAdapters, get_user_github_config, get_user_notion_config from database.adapters import EmbeddingsAdapters, get_user_github_config, get_user_notion_config
from database.models import KhojUser, LocalOrgConfig, LocalMarkdownConfig, LocalPdfConfig, LocalPlaintextConfig from database.models import LocalOrgConfig, LocalMarkdownConfig, LocalPdfConfig, LocalPlaintextConfig
import json
# Initialize Router # Initialize Router
@ -30,15 +32,41 @@ VALID_TEXT_CONTENT_TYPES = ["org", "markdown", "pdf", "plaintext"]
# Create Routes # Create Routes
@web_client.get("/", response_class=FileResponse) @web_client.get("/", response_class=FileResponse)
@requires(["authenticated"], redirect="login_page")
def index(request: Request): def index(request: Request):
return templates.TemplateResponse("index.html", context={"request": request, "demo": state.demo}) return templates.TemplateResponse("index.html", context={"request": request, "demo": state.demo})
@web_client.post("/", response_class=FileResponse)
@requires(["authenticated"], redirect="login_page")
def index_post(request: Request):
return templates.TemplateResponse("index.html", context={"request": request, "demo": state.demo})
@web_client.get("/chat", response_class=FileResponse) @web_client.get("/chat", response_class=FileResponse)
@requires(["authenticated"], redirect="login_page")
def chat_page(request: Request): def chat_page(request: Request):
return templates.TemplateResponse("chat.html", context={"request": request, "demo": state.demo}) return templates.TemplateResponse("chat.html", context={"request": request, "demo": state.demo})
@web_client.get("/login", response_class=FileResponse)
def login_page(request: Request):
if request.user.is_authenticated:
next_url = request.query_params.get("next", "/")
return RedirectResponse(url=next_url)
google_client_id = os.environ.get("GOOGLE_CLIENT_ID")
redirect_uri = request.url_for("auth")
return templates.TemplateResponse(
"login.html",
context={
"request": request,
"demo": state.demo,
"google_client_id": google_client_id,
"redirect_uri": redirect_uri,
},
)
def map_config_to_object(content_type: str): def map_config_to_object(content_type: str):
if content_type == "org": if content_type == "org":
return LocalOrgConfig return LocalOrgConfig
@ -53,6 +81,7 @@ def map_config_to_object(content_type: str):
if not state.demo: if not state.demo:
@web_client.get("/config", response_class=HTMLResponse) @web_client.get("/config", response_class=HTMLResponse)
@requires(["authenticated"], redirect="login_page")
def config_page(request: Request): def config_page(request: Request):
user = request.user.object if request.user.is_authenticated else None user = request.user.object if request.user.is_authenticated else None
enabled_content = set(EmbeddingsAdapters.get_unique_file_types(user).all()) enabled_content = set(EmbeddingsAdapters.get_unique_file_types(user).all())
@ -97,11 +126,12 @@ if not state.demo:
"request": request, "request": request,
"current_config": current_config, "current_config": current_config,
"current_model_state": successfully_configured, "current_model_state": successfully_configured,
"anonymous_mode": state.anonymous_mode,
}, },
) )
@web_client.get("/config/content_type/github", response_class=HTMLResponse) @web_client.get("/config/content_type/github", response_class=HTMLResponse)
@requires(["authenticated"]) @requires(["authenticated"], redirect="login_page")
def github_config_page(request: Request): def github_config_page(request: Request):
user = request.user.object if request.user.is_authenticated else None user = request.user.object if request.user.is_authenticated else None
current_github_config = get_user_github_config(user) current_github_config = get_user_github_config(user)
@ -130,6 +160,7 @@ if not state.demo:
) )
@web_client.get("/config/content_type/notion", response_class=HTMLResponse) @web_client.get("/config/content_type/notion", response_class=HTMLResponse)
@requires(["authenticated"], redirect="login_page")
def notion_config_page(request: Request): def notion_config_page(request: Request):
user = request.user.object if request.user.is_authenticated else None user = request.user.object if request.user.is_authenticated else None
current_notion_config = get_user_notion_config(user) current_notion_config = get_user_notion_config(user)
@ -145,6 +176,7 @@ if not state.demo:
) )
@web_client.get("/config/content_type/{content_type}", response_class=HTMLResponse) @web_client.get("/config/content_type/{content_type}", response_class=HTMLResponse)
@requires(["authenticated"], redirect="login_page")
def content_config_page(request: Request, content_type: str): def content_config_page(request: Request, content_type: str):
if content_type not in VALID_TEXT_CONTENT_TYPES: if content_type not in VALID_TEXT_CONTENT_TYPES:
return templates.TemplateResponse("config.html", context={"request": request}) return templates.TemplateResponse("config.html", context={"request": request})

View file

@ -25,7 +25,6 @@ from khoj.utils.rawconfig import (
OfflineChatProcessorConfig, OfflineChatProcessorConfig,
OpenAIProcessorConfig, OpenAIProcessorConfig,
ProcessorConfig, ProcessorConfig,
TextContentConfig,
ImageContentConfig, ImageContentConfig,
SearchConfig, SearchConfig,
TextSearchConfig, TextSearchConfig,
@ -38,7 +37,6 @@ from database.models import (
LocalOrgConfig, LocalOrgConfig,
LocalMarkdownConfig, LocalMarkdownConfig,
LocalPlaintextConfig, LocalPlaintextConfig,
LocalPdfConfig,
GithubConfig, GithubConfig,
KhojUser, KhojUser,
GithubRepoConfig, GithubRepoConfig,
@ -95,6 +93,19 @@ def default_user():
return UserFactory() return UserFactory()
@pytest.mark.django_db
@pytest.fixture
def default_user2():
if KhojUser.objects.filter(username="default").exists():
return KhojUser.objects.get(username="default")
return UserFactory(
username="default",
email="default@example.com",
password="default",
)
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def search_models(search_config: SearchConfig): def search_models(search_config: SearchConfig):
search_models = SearchModels() search_models = SearchModels()

View file

@ -7,6 +7,7 @@ import pytest
# External Packages # External Packages
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from fastapi import FastAPI from fastapi import FastAPI
import pytest
# Internal Packages # Internal Packages
from khoj.configure import configure_routes, configure_search_types from khoj.configure import configure_routes, configure_search_types
@ -115,16 +116,11 @@ def test_get_configured_types_via_api(client, sample_org_data):
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True) @pytest.mark.django_db(transaction=True)
def test_get_api_config_types(client, search_config: SearchConfig, sample_org_data): def test_get_api_config_types(client, search_config: SearchConfig, sample_org_data, default_user2: KhojUser):
# Arrange # Arrange
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False) text_search.setup(OrgToJsonl, sample_org_data, regenerate=False, user=default_user2)
# Act
response = client.get(f"/api/config/types") response = client.get(f"/api/config/types")
assert response.json() == ["all", "org", "image"]
# Assert
assert response.status_code == 200
assert response.json() == ["all", "org", "markdown", "image"]
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@ -150,6 +146,7 @@ def test_get_configured_types_with_no_content_config(fastapi_app: FastAPI):
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_image_search(client, content_config: ContentConfig, search_config: SearchConfig): def test_image_search(client, content_config: ContentConfig, search_config: SearchConfig):
# Arrange # Arrange
search_models.image_search = image_search.initialize_model(search_config.image) search_models.image_search = image_search.initialize_model(search_config.image)
@ -177,9 +174,9 @@ def test_image_search(client, content_config: ContentConfig, search_config: Sear
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True) @pytest.mark.django_db(transaction=True)
def test_notes_search(client, search_config: SearchConfig, sample_org_data): def test_notes_search(client, search_config: SearchConfig, sample_org_data, default_user2: KhojUser):
# Arrange # Arrange
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False) text_search.setup(OrgToJsonl, sample_org_data, regenerate=False, user=default_user2)
user_query = quote("How to git install application?") user_query = quote("How to git install application?")
# Act # Act
@ -195,13 +192,14 @@ def test_notes_search(client, search_config: SearchConfig, sample_org_data):
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True) @pytest.mark.django_db(transaction=True)
def test_notes_search_with_only_filters( def test_notes_search_with_only_filters(
client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data, default_user2: KhojUser
): ):
# Arrange # Arrange
text_search.setup( text_search.setup(
OrgToJsonl, OrgToJsonl,
sample_org_data, sample_org_data,
regenerate=False, regenerate=False,
user=default_user2,
) )
user_query = quote('+"Emacs" file:"*.org"') user_query = quote('+"Emacs" file:"*.org"')
@ -217,9 +215,9 @@ def test_notes_search_with_only_filters(
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True) @pytest.mark.django_db(transaction=True)
def test_notes_search_with_include_filter(client, sample_org_data): def test_notes_search_with_include_filter(client, sample_org_data, default_user2: KhojUser):
# Arrange # Arrange
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False) text_search.setup(OrgToJsonl, sample_org_data, regenerate=False, user=default_user2)
user_query = quote('How to git install application? +"Emacs"') user_query = quote('How to git install application? +"Emacs"')
# Act # Act
@ -234,12 +232,13 @@ def test_notes_search_with_include_filter(client, sample_org_data):
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True) @pytest.mark.django_db(transaction=True)
def test_notes_search_with_exclude_filter(client, sample_org_data): def test_notes_search_with_exclude_filter(client, sample_org_data, default_user2: KhojUser):
# Arrange # Arrange
text_search.setup( text_search.setup(
OrgToJsonl, OrgToJsonl,
sample_org_data, sample_org_data,
regenerate=False, regenerate=False,
user=default_user2,
) )
user_query = quote('How to git install application? -"clone"') user_query = quote('How to git install application? -"clone"')