From c125995d94a9f973cf1aed9500418389d5bdf592 Mon Sep 17 00:00:00 2001 From: sabaimran <65192171+sabaimran@users.noreply.github.com> Date: Sat, 14 Oct 2023 19:39:13 -0700 Subject: [PATCH] [Multi-User]: Part 0 - Add support for logging in with Google (#487) * Add concept of user authentication to the request session via GoogleUser --- .gitignore | 1 + Dockerfile | 2 + pyproject.toml | 4 + src/app/__init__.py | 0 src/{khoj => app}/main.py | 36 +++-- src/app/settings.py | 129 ++++++++++++++++++ src/app/urls.py | 25 ++++ src/app/wsgi.py | 16 +++ src/database/__init__.py | 0 src/database/adapters/__init__.py | 78 +++++++++++ src/database/admin.py | 8 ++ src/database/apps.py | 6 + src/database/migrations/0001_khojuser.py | 98 +++++++++++++ src/database/migrations/0002_googleuser.py | 32 +++++ .../0003_user_khoj_configurations_and_more.py | 79 +++++++++++ src/database/migrations/__init__.py | 0 src/database/models/__init__.py | 53 +++++++ src/khoj/configure.py | 47 ++++++- src/khoj/interface/web/index.html | 6 +- src/khoj/routers/auth.py | 59 ++++++++ src/khoj/utils/fs_syncer.py | 4 + src/manage.py | 22 +++ tests/conftest.py | 12 +- tests/test_client.py | 2 +- 24 files changed, 702 insertions(+), 17 deletions(-) create mode 100644 src/app/__init__.py rename src/{khoj => app}/main.py (78%) create mode 100644 src/app/settings.py create mode 100644 src/app/urls.py create mode 100644 src/app/wsgi.py create mode 100644 src/database/__init__.py create mode 100644 src/database/adapters/__init__.py create mode 100644 src/database/admin.py create mode 100644 src/database/apps.py create mode 100644 src/database/migrations/0001_khojuser.py create mode 100644 src/database/migrations/0002_googleuser.py create mode 100644 src/database/migrations/0003_user_khoj_configurations_and_more.py create mode 100644 src/database/migrations/__init__.py create mode 100644 src/database/models/__init__.py create mode 100644 src/khoj/routers/auth.py create mode 100755 src/manage.py diff --git a/.gitignore b/.gitignore index 8e99392c..e3e93428 100644 --- a/.gitignore +++ b/.gitignore @@ -22,6 +22,7 @@ khoj_assistant.egg-info /config/khoj*.yml .pytest_cache khoj.log +static # Obsidian plugin artifacts # --- diff --git a/Dockerfile b/Dockerfile index bdf9647f..af271537 100644 --- a/Dockerfile +++ b/Dockerfile @@ -5,6 +5,8 @@ LABEL org.opencontainers.image.source https://github.com/khoj-ai/khoj # Install System Dependencies RUN apt update -y && apt -y install python3-pip git +WORKDIR /app + # Install Application COPY . . RUN sed -i 's/dynamic = \["version"\]/version = "0.0.0"/' pyproject.toml && \ diff --git a/pyproject.toml b/pyproject.toml index a52fc9b6..12be01cc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,8 +59,12 @@ dependencies = [ "bs4 >= 0.0.1", "anyio == 3.7.1", "pymupdf >= 1.23.3", + "django == 4.2.5", + "authlib == 1.2.1", "gpt4all == 1.0.12; platform_system == 'Linux' and platform_machine == 'x86_64'", "gpt4all == 1.0.12; platform_system == 'Windows' or platform_system == 'Darwin'", + "itsdangerous == 2.1.2", + "httpx == 0.25.0", ] dynamic = ["version"] diff --git a/src/app/__init__.py b/src/app/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/khoj/main.py b/src/app/main.py similarity index 78% rename from src/khoj/main.py rename to src/app/main.py index 6710ed05..0049f157 100644 --- a/src/khoj/main.py +++ b/src/app/main.py @@ -3,11 +3,6 @@ import os import sys import locale -if sys.stdout is None: - sys.stdout = open(os.devnull, "w") -if sys.stderr is None: - sys.stderr = open(os.devnull, "w") - import logging import threading import warnings @@ -19,18 +14,33 @@ warnings.filterwarnings("ignore", message=r"legacy way to download files from th # External Packages import uvicorn -from fastapi import FastAPI -from rich.logging import RichHandler +import django import schedule +from fastapi import FastAPI +from fastapi.staticfiles import StaticFiles +from rich.logging import RichHandler +from django.core.asgi import get_asgi_application +from django.core.management import call_command + # Internal Packages -from khoj.configure import configure_routes, initialize_server +from khoj.configure import configure_routes, initialize_server, configure_middleware from khoj.utils import state from khoj.utils.cli import cli +# Initialize Django +os.environ.setdefault("DJANGO_SETTINGS_MODULE", "app.settings") +django.setup() + +# Initialize Django Database +call_command("migrate", "--noinput") + # Initialize the Application Server app = FastAPI() +# Get Django Application +django_app = get_asgi_application() + # Set Locale locale.setlocale(locale.LC_ALL, "") @@ -72,7 +82,15 @@ def run(): # Start Server configure_routes(app) - initialize_server(args.config, required=False) + + # Mount Django and Static Files + app.mount("/django", django_app, name="django") + app.mount("/static", StaticFiles(directory="static"), name="static") + + # Configure Middleware + configure_middleware(app) + + initialize_server(args.config) start_server(app, host=args.host, port=args.port, socket=args.socket) diff --git a/src/app/settings.py b/src/app/settings.py new file mode 100644 index 00000000..74c496a7 --- /dev/null +++ b/src/app/settings.py @@ -0,0 +1,129 @@ +""" +Django settings for app project. + +Generated by 'django-admin startproject' using Django 4.2.5. + +For more information on this file, see +https://docs.djangoproject.com/en/4.2/topics/settings/ + +For the full list of settings and their values, see +https://docs.djangoproject.com/en/4.2/ref/settings/ +""" + +from pathlib import Path +import os + +# Build paths inside the project like this: BASE_DIR / 'subdir'. +BASE_DIR = Path(__file__).resolve().parent.parent + + +# Quick-start development settings - unsuitable for production +# See https://docs.djangoproject.com/en/4.2/howto/deployment/checklist/ + +# SECURITY WARNING: keep the secret key used in production secret! +SECRET_KEY = os.getenv("DJANGO_SECRET_KEY") + +# SECURITY WARNING: don't run with debug turned on in production! +DEBUG = True + +ALLOWED_HOSTS = [] + + +# Application definition + +INSTALLED_APPS = [ + "django.contrib.auth", + "django.contrib.contenttypes", + "database.apps.DatabaseConfig", + "django.contrib.admin", + "django.contrib.sessions", + "django.contrib.messages", + "django.contrib.staticfiles", +] + +MIDDLEWARE = [ + "django.middleware.security.SecurityMiddleware", + "django.contrib.sessions.middleware.SessionMiddleware", + "django.middleware.common.CommonMiddleware", + "django.middleware.csrf.CsrfViewMiddleware", + "django.contrib.auth.middleware.AuthenticationMiddleware", + "django.contrib.messages.middleware.MessageMiddleware", + "django.middleware.clickjacking.XFrameOptionsMiddleware", +] + +ROOT_URLCONF = "app.urls" + +TEMPLATES = [ + { + "BACKEND": "django.template.backends.django.DjangoTemplates", + "APP_DIRS": True, + "DIRS": [os.path.join(BASE_DIR, "templates"), os.path.join(BASE_DIR, "templates", "account")], + "OPTIONS": { + "context_processors": [ + "django.template.context_processors.debug", + "django.template.context_processors.request", + "django.contrib.auth.context_processors.auth", + "django.contrib.messages.context_processors.messages", + ], + }, + }, +] + +WSGI_APPLICATION = "app.wsgi.application" + + +# Database +# https://docs.djangoproject.com/en/4.2/ref/settings/#databases + +DATABASES = { + "default": { + "ENGINE": "django.db.backends.sqlite3", + "NAME": BASE_DIR / "db.sqlite3", + } +} + +# User Settings +AUTH_USER_MODEL = "database.KhojUser" + +# Password validation +# https://docs.djangoproject.com/en/4.2/ref/settings/#auth-password-validators + +AUTH_PASSWORD_VALIDATORS = [ + { + "NAME": "django.contrib.auth.password_validation.UserAttributeSimilarityValidator", + }, + { + "NAME": "django.contrib.auth.password_validation.MinimumLengthValidator", + }, + { + "NAME": "django.contrib.auth.password_validation.CommonPasswordValidator", + }, + { + "NAME": "django.contrib.auth.password_validation.NumericPasswordValidator", + }, +] + + +# Internationalization +# https://docs.djangoproject.com/en/4.2/topics/i18n/ + +LANGUAGE_CODE = "en-us" + +TIME_ZONE = "UTC" + +USE_I18N = True + +USE_TZ = True + + +# Static files (CSS, JavaScript, Images) +# https://docs.djangoproject.com/en/4.2/howto/static-files/ + +STATIC_ROOT = os.path.join(BASE_DIR, "static") +STATICFILES_DIRS = [os.path.join(BASE_DIR, "khoj/interface/web")] +STATIC_URL = "/static/" + +# Default primary key field type +# https://docs.djangoproject.com/en/4.2/ref/settings/#default-auto-field + +DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField" diff --git a/src/app/urls.py b/src/app/urls.py new file mode 100644 index 00000000..fbd67a4e --- /dev/null +++ b/src/app/urls.py @@ -0,0 +1,25 @@ +""" +URL configuration for app project. + +The `urlpatterns` list routes URLs to views. For more information please see: + https://docs.djangoproject.com/en/4.2/topics/http/urls/ +Examples: +Function views + 1. Add an import: from my_app import views + 2. Add a URL to urlpatterns: path('', views.home, name='home') +Class-based views + 1. Add an import: from other_app.views import Home + 2. Add a URL to urlpatterns: path('', Home.as_view(), name='home') +Including another URLconf + 1. Import the include() function: from django.urls import include, path + 2. Add a URL to urlpatterns: path('blog/', include('blog.urls')) +""" +from django.contrib import admin +from django.urls import path, include +from django.contrib.staticfiles.urls import staticfiles_urlpatterns + +urlpatterns = [ + path("admin/", admin.site.urls), +] + +urlpatterns += staticfiles_urlpatterns() diff --git a/src/app/wsgi.py b/src/app/wsgi.py new file mode 100644 index 00000000..cbdf4342 --- /dev/null +++ b/src/app/wsgi.py @@ -0,0 +1,16 @@ +""" +WSGI config for app project. + +It exposes the WSGI callable as a module-level variable named ``application``. + +For more information on this file, see +https://docs.djangoproject.com/en/4.2/howto/deployment/wsgi/ +""" + +import os + +from django.core.wsgi import get_wsgi_application + +os.environ.setdefault("DJANGO_SETTINGS_MODULE", "app.settings") + +application = get_wsgi_application() diff --git a/src/database/__init__.py b/src/database/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/database/adapters/__init__.py b/src/database/adapters/__init__.py new file mode 100644 index 00000000..a72323ae --- /dev/null +++ b/src/database/adapters/__init__.py @@ -0,0 +1,78 @@ +from typing import Type, TypeVar +import uuid + +from django.db import models +from django.contrib.sessions.backends.db import SessionStore + +# Import sync_to_async from Django Channels +from asgiref.sync import sync_to_async + +from fastapi import HTTPException + +from database.models import KhojUser, GoogleUser, NotionConfig + +ModelType = TypeVar("ModelType", bound=models.Model) + + +async def retrieve_object(model_class: Type[ModelType], id: int) -> ModelType: + instance = await model_class.objects.filter(id=id).afirst() + if not instance: + raise HTTPException(status_code=404, detail=f"{model_class.__name__} not found") + return instance + + +async def set_notion_config(token: str, user: KhojUser): + notion_config = await NotionConfig.objects.filter(user=user).afirst() + if not notion_config: + notion_config = await NotionConfig.objects.acreate(token=token, user=user) + else: + notion_config.token = token + await notion_config.asave() + return notion_config + + +async def get_or_create_user(token: dict) -> KhojUser: + user = await get_user_by_token(token) + if not user: + user = await create_google_user(token) + return user + + +async def create_google_user(token: dict) -> KhojUser: + user_info = token.get("userinfo") + user = await KhojUser.objects.acreate( + username=user_info.get("email"), email=user_info.get("email"), uuid=uuid.uuid4() + ) + await user.asave() + await GoogleUser.objects.acreate( + sub=user_info.get("sub"), + azp=user_info.get("azp"), + email=user_info.get("email"), + name=user_info.get("name"), + given_name=user_info.get("given_name"), + family_name=user_info.get("family_name"), + picture=user_info.get("picture"), + locale=user_info.get("locale"), + user=user, + ) + + return user + + +async def get_user_by_token(token: dict) -> KhojUser: + user_info = token.get("userinfo") + google_user = await GoogleUser.objects.filter(sub=user_info.get("sub")).select_related("user").afirst() + if not google_user: + return None + return google_user.user + + +async def retrieve_user(session_id: str) -> KhojUser: + session = SessionStore(session_key=session_id) + if not await sync_to_async(session.exists)(session_key=session_id): + raise HTTPException(status_code=401, detail="Invalid session") + session_data = await sync_to_async(session.load)() + user = await KhojUser.objects.filter(id=session_data.get("_auth_user_id")).afirst() + if not user: + raise HTTPException(status_code=401, detail="Invalid user") + return user diff --git a/src/database/admin.py b/src/database/admin.py new file mode 100644 index 00000000..d09b0ea6 --- /dev/null +++ b/src/database/admin.py @@ -0,0 +1,8 @@ +from django.contrib import admin +from django.contrib.auth.admin import UserAdmin + +# Register your models here. + +from database.models import KhojUser + +admin.site.register(KhojUser, UserAdmin) diff --git a/src/database/apps.py b/src/database/apps.py new file mode 100644 index 00000000..a3b71b13 --- /dev/null +++ b/src/database/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class DatabaseConfig(AppConfig): + default_auto_field = "django.db.models.BigAutoField" + name = "database" diff --git a/src/database/migrations/0001_khojuser.py b/src/database/migrations/0001_khojuser.py new file mode 100644 index 00000000..f1420575 --- /dev/null +++ b/src/database/migrations/0001_khojuser.py @@ -0,0 +1,98 @@ +# Generated by Django 4.2.5 on 2023-09-14 19:00 + +import django.contrib.auth.models +import django.contrib.auth.validators +from django.db import migrations, models +import django.utils.timezone + + +class Migration(migrations.Migration): + initial = True + + dependencies = [ + ("auth", "0012_alter_user_first_name_max_length"), + ] + + run_before = [ + ("admin", "0001_initial"), + ] + + operations = [ + migrations.CreateModel( + name="KhojUser", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("password", models.CharField(max_length=128, verbose_name="password")), + ("last_login", models.DateTimeField(blank=True, null=True, verbose_name="last login")), + ( + "is_superuser", + models.BooleanField( + default=False, + help_text="Designates that this user has all permissions without explicitly assigning them.", + verbose_name="superuser status", + ), + ), + ( + "username", + models.CharField( + error_messages={"unique": "A user with that username already exists."}, + help_text="Required. 150 characters or fewer. Letters, digits and @/./+/-/_ only.", + max_length=150, + unique=True, + validators=[django.contrib.auth.validators.UnicodeUsernameValidator()], + verbose_name="username", + ), + ), + ("first_name", models.CharField(blank=True, max_length=150, verbose_name="first name")), + ("last_name", models.CharField(blank=True, max_length=150, verbose_name="last name")), + ("email", models.EmailField(blank=True, max_length=254, verbose_name="email address")), + ( + "is_staff", + models.BooleanField( + default=False, + help_text="Designates whether the user can log into this admin site.", + verbose_name="staff status", + ), + ), + ( + "is_active", + models.BooleanField( + default=True, + help_text="Designates whether this user should be treated as active. Unselect this instead of deleting accounts.", + verbose_name="active", + ), + ), + ("date_joined", models.DateTimeField(default=django.utils.timezone.now, verbose_name="date joined")), + ( + "groups", + models.ManyToManyField( + blank=True, + help_text="The groups this user belongs to. A user will get all permissions granted to each of their groups.", + related_name="user_set", + related_query_name="user", + to="auth.group", + verbose_name="groups", + ), + ), + ( + "user_permissions", + models.ManyToManyField( + blank=True, + help_text="Specific permissions for this user.", + related_name="user_set", + related_query_name="user", + to="auth.permission", + verbose_name="user permissions", + ), + ), + ], + options={ + "verbose_name": "user", + "verbose_name_plural": "users", + "abstract": False, + }, + managers=[ + ("objects", django.contrib.auth.models.UserManager()), + ], + ), + ] diff --git a/src/database/migrations/0002_googleuser.py b/src/database/migrations/0002_googleuser.py new file mode 100644 index 00000000..478770d6 --- /dev/null +++ b/src/database/migrations/0002_googleuser.py @@ -0,0 +1,32 @@ +# Generated by Django 4.2.4 on 2023-09-18 23:24 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0001_khojuser"), + ] + + operations = [ + migrations.CreateModel( + name="GoogleUser", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("sub", models.CharField(max_length=200)), + ("azp", models.CharField(max_length=200)), + ("email", models.CharField(max_length=200)), + ("name", models.CharField(max_length=200)), + ("given_name", models.CharField(max_length=200)), + ("family_name", models.CharField(max_length=200)), + ("picture", models.CharField(max_length=200)), + ("locale", models.CharField(max_length=200)), + ( + "user", + models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL), + ), + ], + ), + ] diff --git a/src/database/migrations/0003_user_khoj_configurations_and_more.py b/src/database/migrations/0003_user_khoj_configurations_and_more.py new file mode 100644 index 00000000..537ba4c4 --- /dev/null +++ b/src/database/migrations/0003_user_khoj_configurations_and_more.py @@ -0,0 +1,79 @@ +# Generated by Django 4.2.5 on 2023-09-27 17:52 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion +import uuid + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0002_googleuser"), + ] + + operations = [ + migrations.CreateModel( + name="Configuration", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ], + ), + migrations.CreateModel( + name="ConversationProcessorConfig", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("conversation", models.JSONField()), + ("enable_offline_chat", models.BooleanField(default=False)), + ], + ), + migrations.CreateModel( + name="GithubConfig", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("pat_token", models.CharField(max_length=200)), + ("compressed_jsonl", models.CharField(max_length=300)), + ("embeddings_file", models.CharField(max_length=300)), + ( + "config", + models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to="database.configuration"), + ), + ], + ), + migrations.AddField( + model_name="khojuser", + name="uuid", + field=models.UUIDField(verbose_name=models.UUIDField(default=uuid.uuid4, editable=False)), + preserve_default=False, + ), + migrations.CreateModel( + name="NotionConfig", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("token", models.CharField(max_length=200)), + ("compressed_jsonl", models.CharField(max_length=300)), + ("embeddings_file", models.CharField(max_length=300)), + ( + "config", + models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to="database.configuration"), + ), + ], + ), + migrations.CreateModel( + name="GithubRepoConfig", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("name", models.CharField(max_length=200)), + ("owner", models.CharField(max_length=200)), + ("branch", models.CharField(max_length=200)), + ( + "github_config", + models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to="database.githubconfig"), + ), + ], + ), + migrations.AddField( + model_name="configuration", + name="user", + field=models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL), + ), + ] diff --git a/src/database/migrations/__init__.py b/src/database/migrations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/database/models/__init__.py b/src/database/models/__init__.py new file mode 100644 index 00000000..6536671b --- /dev/null +++ b/src/database/models/__init__.py @@ -0,0 +1,53 @@ +import uuid + +from django.db import models +from django.contrib.auth.models import AbstractUser + + +class KhojUser(AbstractUser): + uuid = models.UUIDField(models.UUIDField(default=uuid.uuid4, editable=False)) + + +class GoogleUser(models.Model): + user = models.OneToOneField(KhojUser, on_delete=models.CASCADE) + sub = models.CharField(max_length=200) + azp = models.CharField(max_length=200) + email = models.CharField(max_length=200) + name = models.CharField(max_length=200) + given_name = models.CharField(max_length=200) + family_name = models.CharField(max_length=200) + picture = models.CharField(max_length=200) + locale = models.CharField(max_length=200) + + def __str__(self): + return self.name + + +class Configuration(models.Model): + user = models.OneToOneField(KhojUser, on_delete=models.CASCADE) + + +class NotionConfig(models.Model): + token = models.CharField(max_length=200) + compressed_jsonl = models.CharField(max_length=300) + embeddings_file = models.CharField(max_length=300) + config = models.OneToOneField(Configuration, on_delete=models.CASCADE) + + +class GithubConfig(models.Model): + pat_token = models.CharField(max_length=200) + compressed_jsonl = models.CharField(max_length=300) + embeddings_file = models.CharField(max_length=300) + config = models.OneToOneField(Configuration, on_delete=models.CASCADE) + + +class GithubRepoConfig(models.Model): + name = models.CharField(max_length=200) + owner = models.CharField(max_length=200) + branch = models.CharField(max_length=200) + github_config = models.ForeignKey(GithubConfig, on_delete=models.CASCADE) + + +class ConversationProcessorConfig(models.Model): + conversation = models.JSONField() + enable_offline_chat = models.BooleanField(default=False) diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 7e6cc409..e0a06601 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -5,10 +5,19 @@ import json from enum import Enum from typing import Optional import requests +import os # External Packages import schedule -from fastapi.staticfiles import StaticFiles +from starlette.middleware.sessions import SessionMiddleware +from starlette.middleware.authentication import AuthenticationMiddleware + +from starlette.authentication import ( + AuthCredentials, + AuthenticationBackend, + SimpleUser, + UnauthenticatedUser, +) # Internal Packages from khoj.utils import constants, state @@ -26,8 +35,32 @@ from khoj.routers.indexer import configure_content, load_content, configure_sear logger = logging.getLogger(__name__) -def initialize_server(config: Optional[FullConfig], required=False): - if config is None and required: +class AuthenticatedKhojUser(SimpleUser): + def __init__(self, user): + self.object = user + super().__init__(user.email) + + +class UserAuthenticationBackend(AuthenticationBackend): + def __init__( + self, + ): + from database.models import KhojUser + + self.khojuser_manager = KhojUser.objects + super().__init__() + + async def authenticate(self, request): + current_user = request.session.get("user") + if current_user and current_user.get("email"): + user = await self.khojuser_manager.filter(email=current_user.get("email")).afirst() + if user: + return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user) + return AuthCredentials(), UnauthenticatedUser() + + +def initialize_server(config: Optional[FullConfig]): + if config is None: logger.error( f"🚨 Exiting as Khoj is not configured.\nConfigure it via http://{state.host}:{state.port}/config or by editing {state.config_file}." ) @@ -99,12 +132,18 @@ def configure_routes(app): from khoj.routers.api_beta import api_beta from khoj.routers.web_client import web_client from khoj.routers.indexer import indexer + from khoj.routers.auth import auth_router - app.mount("/static", StaticFiles(directory=constants.web_directory), name="static") app.include_router(api, prefix="/api") app.include_router(api_beta, prefix="/api/beta") app.include_router(indexer, prefix="/v1/indexer") app.include_router(web_client) + app.include_router(auth_router, prefix="/auth") + + +def configure_middleware(app): + app.add_middleware(AuthenticationMiddleware, backend=UserAuthenticationBackend()) + app.add_middleware(SessionMiddleware, secret_key=os.environ.get("KHOJ_DJANGO_SECRET_KEY", "!secret")) if not state.demo: diff --git a/src/khoj/interface/web/index.html b/src/khoj/interface/web/index.html index cb2bae49..581ed9b8 100644 --- a/src/khoj/interface/web/index.html +++ b/src/khoj/interface/web/index.html @@ -170,7 +170,11 @@ // Execute Search and Render Results url = createRequestUrl(query, type, results_count || 5, rerank); - fetch(url) + fetch(url, { + headers: { + "X-CSRFToken": csrfToken + } + }) .then(response => response.json()) .then(data => { console.log(data); diff --git a/src/khoj/routers/auth.py b/src/khoj/routers/auth.py new file mode 100644 index 00000000..41bc8396 --- /dev/null +++ b/src/khoj/routers/auth.py @@ -0,0 +1,59 @@ +import logging +import json +import os +from fastapi import APIRouter +from starlette.config import Config +from starlette.requests import Request +from starlette.responses import HTMLResponse, RedirectResponse +from authlib.integrations.starlette_client import OAuth, OAuthError + +from database.adapters import get_or_create_user + +logger = logging.getLogger(__name__) + +auth_router = APIRouter() + +if not os.environ.get("GOOGLE_CLIENT_ID") or not os.environ.get("GOOGLE_CLIENT_SECRET"): + logger.info("Please set GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET environment variables to use Google OAuth") +else: + config = Config(environ=os.environ) + + oauth = OAuth(config) + + CONF_URL = "https://accounts.google.com/.well-known/openid-configuration" + oauth.register(name="google", server_metadata_url=CONF_URL, client_kwargs={"scope": "openid email profile"}) + + +@auth_router.get("/") +async def homepage(request: Request): + user = request.session.get("user") + if user: + data = json.dumps(user) + html = f"
{data}
" 'logout' + return HTMLResponse(html) + return HTMLResponse('login') + + +@auth_router.get("/login") +async def login(request: Request): + redirect_uri = request.url_for("auth") + return await oauth.google.authorize_redirect(request, redirect_uri) + + +@auth_router.get("/redirect") +async def auth(request: Request): + try: + token = await oauth.google.authorize_access_token(request) + except OAuthError as error: + return HTMLResponse(f"

{error.error}

") + khoj_user = await get_or_create_user(token) + user = token.get("userinfo") + if user: + request.session["user"] = dict(user) + return RedirectResponse(url="/") + + +@auth_router.get("/logout") +async def logout(request: Request): + request.session.pop("user", None) + return RedirectResponse(url="/") diff --git a/src/khoj/utils/fs_syncer.py b/src/khoj/utils/fs_syncer.py index d303d39b..6a777bd7 100644 --- a/src/khoj/utils/fs_syncer.py +++ b/src/khoj/utils/fs_syncer.py @@ -13,6 +13,10 @@ logger = logging.getLogger(__name__) def collect_files(config: ContentConfig, search_type: Optional[SearchType] = SearchType.All): files = {} + + if config is None: + return files + if search_type == SearchType.All or search_type == SearchType.Org: files["org"] = get_org_files(config.org) if config.org else {} if search_type == SearchType.All or search_type == SearchType.Markdown: diff --git a/src/manage.py b/src/manage.py new file mode 100755 index 00000000..1a64b14a --- /dev/null +++ b/src/manage.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python +"""Django's command-line utility for administrative tasks.""" +import os +import sys + + +def main(): + """Run administrative tasks.""" + os.environ.setdefault("DJANGO_SETTINGS_MODULE", "app.settings") + try: + from django.core.management import execute_from_command_line + except ImportError as exc: + raise ImportError( + "Couldn't import Django. Are you sure it's installed and " + "available on your PYTHONPATH environment variable? Did you " + "forget to activate a virtual environment?" + ) from exc + execute_from_command_line(sys.argv) + + +if __name__ == "__main__": + main() diff --git a/tests/conftest.py b/tests/conftest.py index d851341d..7c1878a2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,14 +4,16 @@ from copy import deepcopy from fastapi.testclient import TestClient from pathlib import Path import pytest +from fastapi.staticfiles import StaticFiles # Internal Packages -from khoj.main import app -from khoj.configure import configure_processor, configure_routes, configure_search_types +from app.main import app +from khoj.configure import configure_processor, configure_routes, configure_search_types, configure_middleware from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl from khoj.processor.plaintext.plaintext_to_jsonl import PlaintextToJsonl from khoj.search_type import image_search, text_search from khoj.utils.config import SearchModels +from khoj.utils.constants import web_directory from khoj.utils.helpers import resolve_absolute_path from khoj.utils.rawconfig import ( ContentConfig, @@ -231,6 +233,8 @@ def chat_client(md_content_config: ContentConfig, search_config: SearchConfig, p state.processor_config = configure_processor(processor_config) configure_routes(app) + configure_middleware(app) + app.mount("/static", StaticFiles(directory=web_directory), name="static") return TestClient(app) @@ -264,6 +268,8 @@ def client(content_config: ContentConfig, search_config: SearchConfig, processor state.processor_config = configure_processor(processor_config) configure_routes(app) + configure_middleware(app) + app.mount("/static", StaticFiles(directory=web_directory), name="static") return TestClient(app) @@ -292,6 +298,8 @@ def client_offline_chat( state.processor_config = configure_processor(processor_config_offline_chat) configure_routes(app) + configure_middleware(app) + app.mount("/static", StaticFiles(directory=web_directory), name="static") return TestClient(app) diff --git a/tests/test_client.py b/tests/test_client.py index d2497f73..784c765c 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -8,7 +8,7 @@ from urllib.parse import quote from fastapi.testclient import TestClient # Internal Packages -from khoj.main import app +from app.main import app from khoj.configure import configure_routes, configure_search_types from khoj.utils import state from khoj.utils.state import search_models, content_index, config