diff --git a/pyproject.toml b/pyproject.toml index 8732d47a..f9ef020c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,6 +68,8 @@ dependencies = [ "httpx == 0.25.0", "pgvector == 0.2.3", "psycopg2-binary == 2.9.9", + "google-auth == 2.23.3", + "python-multipart == 0.0.6", ] dynamic = ["version"] diff --git a/src/database/adapters/__init__.py b/src/database/adapters/__init__.py index a7c1c6f9..fc4f23b1 100644 --- a/src/database/adapters/__init__.py +++ b/src/database/adapters/__init__.py @@ -73,8 +73,7 @@ async def create_google_user(token: dict) -> KhojUser: async def get_user_by_token(token: dict) -> KhojUser: - user_info = token.get("userinfo") - google_user = await GoogleUser.objects.filter(sub=user_info.get("sub")).select_related("user").afirst() + google_user = await GoogleUser.objects.filter(sub=token.get("sub")).select_related("user").afirst() if not google_user: return None return google_user.user diff --git a/src/khoj/configure.py b/src/khoj/configure.py index f65b1056..76b2e9f4 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -67,7 +67,7 @@ class UserAuthenticationBackend(AuthenticationBackend): user = await self.khojuser_manager.filter(email=current_user.get("email")).afirst() if user: return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user) - elif not state.anonymous_mode: + elif state.anonymous_mode: user = await self.khojuser_manager.filter(username="default").afirst() if user: return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user) @@ -77,11 +77,6 @@ class UserAuthenticationBackend(AuthenticationBackend): def initialize_server(config: Optional[FullConfig]): if config is None: - logger.error( - f"🚨 Exiting as Khoj is not configured.\nConfigure it via http://{state.host}:{state.port}/config or by editing {state.config_file}." - ) - sys.exit(1) - elif config is None: logger.warning( f"🚨 Khoj is not configured.\nConfigure it via http://{state.host}:{state.port}/config, plugins or by editing {state.config_file}." ) @@ -230,6 +225,8 @@ def configure_conversation_processor( conversation_logfile=conversation_logfile, openai=(conversation_config.openai if (conversation_config is not None) else None), offline_chat=conversation_config.offline_chat if conversation_config else OfflineChatProcessorConfig(), + max_prompt_size=conversation_config.max_prompt_size if conversation_config else None, + tokenizer=conversation_config.tokenizer if conversation_config else None, ) ) else: diff --git a/src/khoj/interface/web/base_config.html b/src/khoj/interface/web/base_config.html index 5b643d58..15c3f678 100644 --- a/src/khoj/interface/web/base_config.html +++ b/src/khoj/interface/web/base_config.html @@ -211,6 +211,14 @@ grid-gap: 24px; } + button#logout { + font-size: 16px; + cursor: pointer; + margin: 0; + padding: 0; + height: 32px; + } + @media screen and (max-width: 600px) { .section-cards { grid-template-columns: 1fr; @@ -242,6 +250,10 @@ width: 320px; } + div.khoj-header-wrapper{ + grid-template-columns: auto; + } + } diff --git a/src/khoj/interface/web/config.html b/src/khoj/interface/web/config.html index 6e3a0223..6c69c056 100644 --- a/src/khoj/interface/web/config.html +++ b/src/khoj/interface/web/config.html @@ -297,6 +297,11 @@
+ {% if anonymous_mode == False %} +
+ +
+ {% endif %} diff --git a/src/khoj/interface/web/login.html b/src/khoj/interface/web/login.html new file mode 100644 index 00000000..550991ed --- /dev/null +++ b/src/khoj/interface/web/login.html @@ -0,0 +1,197 @@ + + + + + Khoj - Search + + + + + + + + {% if demo %} + +
+ +

+ Enroll in Khoj cloud to get your own assistant +

+
+ + +
+ {% endif %} + +
+ {% if demo %} + + {% else %} + + {% endif %} +
+ + +
+

Become superhuman with your personal knowledge base copilot

+
+
+
+
+
+ + + + + + + + diff --git a/src/khoj/main.py b/src/khoj/main.py index a713cc97..804e71e5 100644 --- a/src/khoj/main.py +++ b/src/khoj/main.py @@ -14,13 +14,13 @@ warnings.filterwarnings("ignore", message=r"legacy way to download files from th # External Packages import uvicorn +import django from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware -import schedule -import django - from fastapi.staticfiles import StaticFiles from rich.logging import RichHandler +import schedule + from django.core.asgi import get_asgi_application from django.core.management import call_command @@ -34,13 +34,6 @@ call_command("migrate", "--noinput") # Initialize Django Static Files call_command("collectstatic", "--noinput") -# Initialize Django -os.environ.setdefault("DJANGO_SETTINGS_MODULE", "app.settings") -django.setup() - -# Initialize Django Database -call_command("migrate", "--noinput") - # Initialize the Application Server app = FastAPI() diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 7c3e3392..d041fd76 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -127,6 +127,7 @@ if not state.demo: state.processor_config = configure_processor(state.config.processor) @api.get("/config/data", response_model=FullConfig) + @requires(["authenticated"], redirect="login_page") def get_config_data(request: Request): user = request.user.object if request.user.is_authenticated else None enabled_content = EmbeddingsAdapters.get_unique_file_types(user) @@ -134,6 +135,7 @@ if not state.demo: return state.config @api.post("/config/data") + @requires(["authenticated"], redirect="login_page") async def set_config_data( request: Request, updated_config: FullConfig, @@ -166,7 +168,7 @@ if not state.demo: return state.config @api.post("/config/data/content_type/github", status_code=200) - @requires("authenticated") + @requires(["authenticated"], redirect="login_page") async def set_content_config_github_data( request: Request, updated_config: Union[GithubContentConfig, None], @@ -193,6 +195,7 @@ if not state.demo: return {"status": "ok"} @api.post("/config/data/content_type/notion", status_code=200) + @requires(["authenticated"], redirect="login_page") async def set_content_config_notion_data( request: Request, updated_config: Union[NotionContentConfig, None], @@ -218,6 +221,7 @@ if not state.demo: return {"status": "ok"} @api.post("/delete/config/data/content_type/{content_type}", status_code=200) + @requires(["authenticated"], redirect="login_page") async def remove_content_config_data( request: Request, content_type: str, @@ -272,7 +276,7 @@ if not state.demo: return {"status": "error", "message": str(e)} @api.post("/config/data/content_type/{content_type}", status_code=200) - # @requires("authenticated") + @requires(["authenticated"], redirect="login_page") async def set_content_config_data( request: Request, content_type: str, @@ -378,6 +382,7 @@ def get_default_config_data(): @api.get("/config/types", response_model=List[str]) +@requires(["authenticated"], redirect="login_page") def get_config_types( request: Request, ): @@ -399,6 +404,7 @@ def get_config_types( @api.get("/search", response_model=List[SearchResponse]) +@requires(["authenticated"], redirect="login_page") async def search( q: str, request: Request, @@ -532,6 +538,7 @@ async def search( @api.get("/update") +@requires(["authenticated"], redirect="login_page") def update( request: Request, t: Optional[SearchType] = None, @@ -577,6 +584,7 @@ def update( @api.get("/chat/history") +@requires(["authenticated"], redirect="login_page") def chat_history( request: Request, client: Optional[str] = None, @@ -605,6 +613,7 @@ def chat_history( @api.get("/chat/options", response_class=Response) +@requires(["authenticated"], redirect="login_page") async def chat_options( request: Request, client: Optional[str] = None, @@ -629,6 +638,7 @@ async def chat_options( @api.get("/chat", response_class=Response) +@requires(["authenticated"], redirect="login_page") async def chat( request: Request, q: str, diff --git a/src/khoj/routers/auth.py b/src/khoj/routers/auth.py index 41bc8396..8c767d8f 100644 --- a/src/khoj/routers/auth.py +++ b/src/khoj/routers/auth.py @@ -4,9 +4,12 @@ import os from fastapi import APIRouter from starlette.config import Config from starlette.requests import Request -from starlette.responses import HTMLResponse, RedirectResponse +from starlette.responses import HTMLResponse, RedirectResponse, Response from authlib.integrations.starlette_client import OAuth, OAuthError +from google.oauth2 import id_token +from google.auth.transport import requests as google_requests + from database.adapters import get_or_create_user logger = logging.getLogger(__name__) @@ -24,32 +27,40 @@ else: oauth.register(name="google", server_metadata_url=CONF_URL, client_kwargs={"scope": "openid email profile"}) -@auth_router.get("/") -async def homepage(request: Request): - user = request.session.get("user") - if user: - data = json.dumps(user) - html = f"
{data}
" 'logout' - return HTMLResponse(html) - return HTMLResponse('login') - - @auth_router.get("/login") +async def login_get(request: Request): + redirect_uri = request.url_for("auth") + return await oauth.google.authorize_redirect(request, redirect_uri) + + +@auth_router.post("/login") async def login(request: Request): redirect_uri = request.url_for("auth") return await oauth.google.authorize_redirect(request, redirect_uri) -@auth_router.get("/redirect") +@auth_router.post("/redirect") async def auth(request: Request): + form = await request.form() + credential = form.get("credential") + + csrf_token_cookie = request.cookies.get("g_csrf_token") + if not csrf_token_cookie: + return Response("Missing CSRF token", status_code=400) + csrf_token_body = form.get("g_csrf_token") + if not csrf_token_body: + return Response("Missing CSRF token", status_code=400) + if csrf_token_cookie != csrf_token_body: + return Response("Invalid CSRF token", status_code=400) + try: - token = await oauth.google.authorize_access_token(request) + idinfo = id_token.verify_oauth2_token(credential, google_requests.Request(), os.environ["GOOGLE_CLIENT_ID"]) except OAuthError as error: return HTMLResponse(f"

{error.error}

") - khoj_user = await get_or_create_user(token) - user = token.get("userinfo") - if user: - request.session["user"] = dict(user) + khoj_user = await get_or_create_user(idinfo) + if khoj_user: + request.session["user"] = dict(idinfo) + return RedirectResponse(url="/") diff --git a/src/khoj/routers/indexer.py b/src/khoj/routers/indexer.py index c2ef04ff..1e73c439 100644 --- a/src/khoj/routers/indexer.py +++ b/src/khoj/routers/indexer.py @@ -141,6 +141,7 @@ async def update( ) logger.info(f"Finished processing batch indexing request") except Exception as e: + logger.error(f"Failed to process batch indexing request: {e}", exc_info=True) logger.error( f"🚨 Failed to {force} update {t} content index triggered via API call by {client} client: {e}", exc_info=True, @@ -156,7 +157,6 @@ async def update( host=host, ) - logger.info(f"📪 Content index updated via API call by {client} client") return Response(content="OK", status_code=200) diff --git a/src/khoj/routers/web_client.py b/src/khoj/routers/web_client.py index 6c79e061..4122c6d0 100644 --- a/src/khoj/routers/web_client.py +++ b/src/khoj/routers/web_client.py @@ -1,7 +1,11 @@ +# System Packages +import json +import os + # External Packages from fastapi import APIRouter from fastapi import Request -from fastapi.responses import HTMLResponse, FileResponse +from fastapi.responses import HTMLResponse, FileResponse, RedirectResponse from fastapi.templating import Jinja2Templates from starlette.authentication import requires from khoj.utils.rawconfig import ( @@ -16,9 +20,7 @@ from khoj.utils.rawconfig import ( # Internal Packages from khoj.utils import constants, state from database.adapters import EmbeddingsAdapters, get_user_github_config, get_user_notion_config -from database.models import KhojUser, LocalOrgConfig, LocalMarkdownConfig, LocalPdfConfig, LocalPlaintextConfig - -import json +from database.models import LocalOrgConfig, LocalMarkdownConfig, LocalPdfConfig, LocalPlaintextConfig # Initialize Router @@ -30,15 +32,41 @@ VALID_TEXT_CONTENT_TYPES = ["org", "markdown", "pdf", "plaintext"] # Create Routes @web_client.get("/", response_class=FileResponse) +@requires(["authenticated"], redirect="login_page") def index(request: Request): return templates.TemplateResponse("index.html", context={"request": request, "demo": state.demo}) +@web_client.post("/", response_class=FileResponse) +@requires(["authenticated"], redirect="login_page") +def index_post(request: Request): + return templates.TemplateResponse("index.html", context={"request": request, "demo": state.demo}) + + @web_client.get("/chat", response_class=FileResponse) +@requires(["authenticated"], redirect="login_page") def chat_page(request: Request): return templates.TemplateResponse("chat.html", context={"request": request, "demo": state.demo}) +@web_client.get("/login", response_class=FileResponse) +def login_page(request: Request): + if request.user.is_authenticated: + next_url = request.query_params.get("next", "/") + return RedirectResponse(url=next_url) + google_client_id = os.environ.get("GOOGLE_CLIENT_ID") + redirect_uri = request.url_for("auth") + return templates.TemplateResponse( + "login.html", + context={ + "request": request, + "demo": state.demo, + "google_client_id": google_client_id, + "redirect_uri": redirect_uri, + }, + ) + + def map_config_to_object(content_type: str): if content_type == "org": return LocalOrgConfig @@ -53,6 +81,7 @@ def map_config_to_object(content_type: str): if not state.demo: @web_client.get("/config", response_class=HTMLResponse) + @requires(["authenticated"], redirect="login_page") def config_page(request: Request): user = request.user.object if request.user.is_authenticated else None enabled_content = set(EmbeddingsAdapters.get_unique_file_types(user).all()) @@ -97,11 +126,12 @@ if not state.demo: "request": request, "current_config": current_config, "current_model_state": successfully_configured, + "anonymous_mode": state.anonymous_mode, }, ) @web_client.get("/config/content_type/github", response_class=HTMLResponse) - @requires(["authenticated"]) + @requires(["authenticated"], redirect="login_page") def github_config_page(request: Request): user = request.user.object if request.user.is_authenticated else None current_github_config = get_user_github_config(user) @@ -130,6 +160,7 @@ if not state.demo: ) @web_client.get("/config/content_type/notion", response_class=HTMLResponse) + @requires(["authenticated"], redirect="login_page") def notion_config_page(request: Request): user = request.user.object if request.user.is_authenticated else None current_notion_config = get_user_notion_config(user) @@ -145,6 +176,7 @@ if not state.demo: ) @web_client.get("/config/content_type/{content_type}", response_class=HTMLResponse) + @requires(["authenticated"], redirect="login_page") def content_config_page(request: Request, content_type: str): if content_type not in VALID_TEXT_CONTENT_TYPES: return templates.TemplateResponse("config.html", context={"request": request}) diff --git a/tests/conftest.py b/tests/conftest.py index 5f515ef1..ee4b9e57 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -25,7 +25,6 @@ from khoj.utils.rawconfig import ( OfflineChatProcessorConfig, OpenAIProcessorConfig, ProcessorConfig, - TextContentConfig, ImageContentConfig, SearchConfig, TextSearchConfig, @@ -38,7 +37,6 @@ from database.models import ( LocalOrgConfig, LocalMarkdownConfig, LocalPlaintextConfig, - LocalPdfConfig, GithubConfig, KhojUser, GithubRepoConfig, @@ -95,6 +93,19 @@ def default_user(): return UserFactory() +@pytest.mark.django_db +@pytest.fixture +def default_user2(): + if KhojUser.objects.filter(username="default").exists(): + return KhojUser.objects.get(username="default") + + return UserFactory( + username="default", + email="default@example.com", + password="default", + ) + + @pytest.fixture(scope="session") def search_models(search_config: SearchConfig): search_models = SearchModels() diff --git a/tests/test_client.py b/tests/test_client.py index f63b968c..b77ba07d 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -7,6 +7,7 @@ import pytest # External Packages from fastapi.testclient import TestClient from fastapi import FastAPI +import pytest # Internal Packages from khoj.configure import configure_routes, configure_search_types @@ -115,16 +116,11 @@ def test_get_configured_types_via_api(client, sample_org_data): # ---------------------------------------------------------------------------------------------------- @pytest.mark.django_db(transaction=True) -def test_get_api_config_types(client, search_config: SearchConfig, sample_org_data): +def test_get_api_config_types(client, search_config: SearchConfig, sample_org_data, default_user2: KhojUser): # Arrange - text_search.setup(OrgToJsonl, sample_org_data, regenerate=False) - - # Act + text_search.setup(OrgToJsonl, sample_org_data, regenerate=False, user=default_user2) response = client.get(f"/api/config/types") - - # Assert - assert response.status_code == 200 - assert response.json() == ["all", "org", "markdown", "image"] + assert response.json() == ["all", "org", "image"] # ---------------------------------------------------------------------------------------------------- @@ -150,6 +146,7 @@ def test_get_configured_types_with_no_content_config(fastapi_app: FastAPI): # ---------------------------------------------------------------------------------------------------- +@pytest.mark.django_db(transaction=True) def test_image_search(client, content_config: ContentConfig, search_config: SearchConfig): # Arrange search_models.image_search = image_search.initialize_model(search_config.image) @@ -177,9 +174,9 @@ def test_image_search(client, content_config: ContentConfig, search_config: Sear # ---------------------------------------------------------------------------------------------------- @pytest.mark.django_db(transaction=True) -def test_notes_search(client, search_config: SearchConfig, sample_org_data): +def test_notes_search(client, search_config: SearchConfig, sample_org_data, default_user2: KhojUser): # Arrange - text_search.setup(OrgToJsonl, sample_org_data, regenerate=False) + text_search.setup(OrgToJsonl, sample_org_data, regenerate=False, user=default_user2) user_query = quote("How to git install application?") # Act @@ -195,13 +192,14 @@ def test_notes_search(client, search_config: SearchConfig, sample_org_data): # ---------------------------------------------------------------------------------------------------- @pytest.mark.django_db(transaction=True) def test_notes_search_with_only_filters( - client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data + client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data, default_user2: KhojUser ): # Arrange text_search.setup( OrgToJsonl, sample_org_data, regenerate=False, + user=default_user2, ) user_query = quote('+"Emacs" file:"*.org"') @@ -217,9 +215,9 @@ def test_notes_search_with_only_filters( # ---------------------------------------------------------------------------------------------------- @pytest.mark.django_db(transaction=True) -def test_notes_search_with_include_filter(client, sample_org_data): +def test_notes_search_with_include_filter(client, sample_org_data, default_user2: KhojUser): # Arrange - text_search.setup(OrgToJsonl, sample_org_data, regenerate=False) + text_search.setup(OrgToJsonl, sample_org_data, regenerate=False, user=default_user2) user_query = quote('How to git install application? +"Emacs"') # Act @@ -234,12 +232,13 @@ def test_notes_search_with_include_filter(client, sample_org_data): # ---------------------------------------------------------------------------------------------------- @pytest.mark.django_db(transaction=True) -def test_notes_search_with_exclude_filter(client, sample_org_data): +def test_notes_search_with_exclude_filter(client, sample_org_data, default_user2: KhojUser): # Arrange text_search.setup( OrgToJsonl, sample_org_data, regenerate=False, + user=default_user2, ) user_query = quote('How to git install application? -"clone"')