mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 15:38:55 +01:00
[Multi-User]: Part 0 - Add support for logging in with Google (#487)
* Add concept of user authentication to the request session via GoogleUser
This commit is contained in:
parent
4a5ed7f06c
commit
c125995d94
24 changed files with 702 additions and 17 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -22,6 +22,7 @@ khoj_assistant.egg-info
|
||||||
/config/khoj*.yml
|
/config/khoj*.yml
|
||||||
.pytest_cache
|
.pytest_cache
|
||||||
khoj.log
|
khoj.log
|
||||||
|
static
|
||||||
|
|
||||||
# Obsidian plugin artifacts
|
# Obsidian plugin artifacts
|
||||||
# ---
|
# ---
|
||||||
|
|
|
@ -5,6 +5,8 @@ LABEL org.opencontainers.image.source https://github.com/khoj-ai/khoj
|
||||||
# Install System Dependencies
|
# Install System Dependencies
|
||||||
RUN apt update -y && apt -y install python3-pip git
|
RUN apt update -y && apt -y install python3-pip git
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
# Install Application
|
# Install Application
|
||||||
COPY . .
|
COPY . .
|
||||||
RUN sed -i 's/dynamic = \["version"\]/version = "0.0.0"/' pyproject.toml && \
|
RUN sed -i 's/dynamic = \["version"\]/version = "0.0.0"/' pyproject.toml && \
|
||||||
|
|
|
@ -59,8 +59,12 @@ dependencies = [
|
||||||
"bs4 >= 0.0.1",
|
"bs4 >= 0.0.1",
|
||||||
"anyio == 3.7.1",
|
"anyio == 3.7.1",
|
||||||
"pymupdf >= 1.23.3",
|
"pymupdf >= 1.23.3",
|
||||||
|
"django == 4.2.5",
|
||||||
|
"authlib == 1.2.1",
|
||||||
"gpt4all == 1.0.12; platform_system == 'Linux' and platform_machine == 'x86_64'",
|
"gpt4all == 1.0.12; platform_system == 'Linux' and platform_machine == 'x86_64'",
|
||||||
"gpt4all == 1.0.12; platform_system == 'Windows' or platform_system == 'Darwin'",
|
"gpt4all == 1.0.12; platform_system == 'Windows' or platform_system == 'Darwin'",
|
||||||
|
"itsdangerous == 2.1.2",
|
||||||
|
"httpx == 0.25.0",
|
||||||
]
|
]
|
||||||
dynamic = ["version"]
|
dynamic = ["version"]
|
||||||
|
|
||||||
|
|
0
src/app/__init__.py
Normal file
0
src/app/__init__.py
Normal file
|
@ -3,11 +3,6 @@ import os
|
||||||
import sys
|
import sys
|
||||||
import locale
|
import locale
|
||||||
|
|
||||||
if sys.stdout is None:
|
|
||||||
sys.stdout = open(os.devnull, "w")
|
|
||||||
if sys.stderr is None:
|
|
||||||
sys.stderr = open(os.devnull, "w")
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
import warnings
|
import warnings
|
||||||
|
@ -19,18 +14,33 @@ warnings.filterwarnings("ignore", message=r"legacy way to download files from th
|
||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import FastAPI
|
import django
|
||||||
from rich.logging import RichHandler
|
|
||||||
import schedule
|
import schedule
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
from rich.logging import RichHandler
|
||||||
|
from django.core.asgi import get_asgi_application
|
||||||
|
from django.core.management import call_command
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.configure import configure_routes, initialize_server
|
from khoj.configure import configure_routes, initialize_server, configure_middleware
|
||||||
from khoj.utils import state
|
from khoj.utils import state
|
||||||
from khoj.utils.cli import cli
|
from khoj.utils.cli import cli
|
||||||
|
|
||||||
|
# Initialize Django
|
||||||
|
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "app.settings")
|
||||||
|
django.setup()
|
||||||
|
|
||||||
|
# Initialize Django Database
|
||||||
|
call_command("migrate", "--noinput")
|
||||||
|
|
||||||
# Initialize the Application Server
|
# Initialize the Application Server
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
|
# Get Django Application
|
||||||
|
django_app = get_asgi_application()
|
||||||
|
|
||||||
# Set Locale
|
# Set Locale
|
||||||
locale.setlocale(locale.LC_ALL, "")
|
locale.setlocale(locale.LC_ALL, "")
|
||||||
|
|
||||||
|
@ -72,7 +82,15 @@ def run():
|
||||||
|
|
||||||
# Start Server
|
# Start Server
|
||||||
configure_routes(app)
|
configure_routes(app)
|
||||||
initialize_server(args.config, required=False)
|
|
||||||
|
# Mount Django and Static Files
|
||||||
|
app.mount("/django", django_app, name="django")
|
||||||
|
app.mount("/static", StaticFiles(directory="static"), name="static")
|
||||||
|
|
||||||
|
# Configure Middleware
|
||||||
|
configure_middleware(app)
|
||||||
|
|
||||||
|
initialize_server(args.config)
|
||||||
start_server(app, host=args.host, port=args.port, socket=args.socket)
|
start_server(app, host=args.host, port=args.port, socket=args.socket)
|
||||||
|
|
||||||
|
|
129
src/app/settings.py
Normal file
129
src/app/settings.py
Normal file
|
@ -0,0 +1,129 @@
|
||||||
|
"""
|
||||||
|
Django settings for app project.
|
||||||
|
|
||||||
|
Generated by 'django-admin startproject' using Django 4.2.5.
|
||||||
|
|
||||||
|
For more information on this file, see
|
||||||
|
https://docs.djangoproject.com/en/4.2/topics/settings/
|
||||||
|
|
||||||
|
For the full list of settings and their values, see
|
||||||
|
https://docs.djangoproject.com/en/4.2/ref/settings/
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Build paths inside the project like this: BASE_DIR / 'subdir'.
|
||||||
|
BASE_DIR = Path(__file__).resolve().parent.parent
|
||||||
|
|
||||||
|
|
||||||
|
# Quick-start development settings - unsuitable for production
|
||||||
|
# See https://docs.djangoproject.com/en/4.2/howto/deployment/checklist/
|
||||||
|
|
||||||
|
# SECURITY WARNING: keep the secret key used in production secret!
|
||||||
|
SECRET_KEY = os.getenv("DJANGO_SECRET_KEY")
|
||||||
|
|
||||||
|
# SECURITY WARNING: don't run with debug turned on in production!
|
||||||
|
DEBUG = True
|
||||||
|
|
||||||
|
ALLOWED_HOSTS = []
|
||||||
|
|
||||||
|
|
||||||
|
# Application definition
|
||||||
|
|
||||||
|
INSTALLED_APPS = [
|
||||||
|
"django.contrib.auth",
|
||||||
|
"django.contrib.contenttypes",
|
||||||
|
"database.apps.DatabaseConfig",
|
||||||
|
"django.contrib.admin",
|
||||||
|
"django.contrib.sessions",
|
||||||
|
"django.contrib.messages",
|
||||||
|
"django.contrib.staticfiles",
|
||||||
|
]
|
||||||
|
|
||||||
|
MIDDLEWARE = [
|
||||||
|
"django.middleware.security.SecurityMiddleware",
|
||||||
|
"django.contrib.sessions.middleware.SessionMiddleware",
|
||||||
|
"django.middleware.common.CommonMiddleware",
|
||||||
|
"django.middleware.csrf.CsrfViewMiddleware",
|
||||||
|
"django.contrib.auth.middleware.AuthenticationMiddleware",
|
||||||
|
"django.contrib.messages.middleware.MessageMiddleware",
|
||||||
|
"django.middleware.clickjacking.XFrameOptionsMiddleware",
|
||||||
|
]
|
||||||
|
|
||||||
|
ROOT_URLCONF = "app.urls"
|
||||||
|
|
||||||
|
TEMPLATES = [
|
||||||
|
{
|
||||||
|
"BACKEND": "django.template.backends.django.DjangoTemplates",
|
||||||
|
"APP_DIRS": True,
|
||||||
|
"DIRS": [os.path.join(BASE_DIR, "templates"), os.path.join(BASE_DIR, "templates", "account")],
|
||||||
|
"OPTIONS": {
|
||||||
|
"context_processors": [
|
||||||
|
"django.template.context_processors.debug",
|
||||||
|
"django.template.context_processors.request",
|
||||||
|
"django.contrib.auth.context_processors.auth",
|
||||||
|
"django.contrib.messages.context_processors.messages",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
WSGI_APPLICATION = "app.wsgi.application"
|
||||||
|
|
||||||
|
|
||||||
|
# Database
|
||||||
|
# https://docs.djangoproject.com/en/4.2/ref/settings/#databases
|
||||||
|
|
||||||
|
DATABASES = {
|
||||||
|
"default": {
|
||||||
|
"ENGINE": "django.db.backends.sqlite3",
|
||||||
|
"NAME": BASE_DIR / "db.sqlite3",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# User Settings
|
||||||
|
AUTH_USER_MODEL = "database.KhojUser"
|
||||||
|
|
||||||
|
# Password validation
|
||||||
|
# https://docs.djangoproject.com/en/4.2/ref/settings/#auth-password-validators
|
||||||
|
|
||||||
|
AUTH_PASSWORD_VALIDATORS = [
|
||||||
|
{
|
||||||
|
"NAME": "django.contrib.auth.password_validation.UserAttributeSimilarityValidator",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"NAME": "django.contrib.auth.password_validation.MinimumLengthValidator",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"NAME": "django.contrib.auth.password_validation.CommonPasswordValidator",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"NAME": "django.contrib.auth.password_validation.NumericPasswordValidator",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# Internationalization
|
||||||
|
# https://docs.djangoproject.com/en/4.2/topics/i18n/
|
||||||
|
|
||||||
|
LANGUAGE_CODE = "en-us"
|
||||||
|
|
||||||
|
TIME_ZONE = "UTC"
|
||||||
|
|
||||||
|
USE_I18N = True
|
||||||
|
|
||||||
|
USE_TZ = True
|
||||||
|
|
||||||
|
|
||||||
|
# Static files (CSS, JavaScript, Images)
|
||||||
|
# https://docs.djangoproject.com/en/4.2/howto/static-files/
|
||||||
|
|
||||||
|
STATIC_ROOT = os.path.join(BASE_DIR, "static")
|
||||||
|
STATICFILES_DIRS = [os.path.join(BASE_DIR, "khoj/interface/web")]
|
||||||
|
STATIC_URL = "/static/"
|
||||||
|
|
||||||
|
# Default primary key field type
|
||||||
|
# https://docs.djangoproject.com/en/4.2/ref/settings/#default-auto-field
|
||||||
|
|
||||||
|
DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField"
|
25
src/app/urls.py
Normal file
25
src/app/urls.py
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
"""
|
||||||
|
URL configuration for app project.
|
||||||
|
|
||||||
|
The `urlpatterns` list routes URLs to views. For more information please see:
|
||||||
|
https://docs.djangoproject.com/en/4.2/topics/http/urls/
|
||||||
|
Examples:
|
||||||
|
Function views
|
||||||
|
1. Add an import: from my_app import views
|
||||||
|
2. Add a URL to urlpatterns: path('', views.home, name='home')
|
||||||
|
Class-based views
|
||||||
|
1. Add an import: from other_app.views import Home
|
||||||
|
2. Add a URL to urlpatterns: path('', Home.as_view(), name='home')
|
||||||
|
Including another URLconf
|
||||||
|
1. Import the include() function: from django.urls import include, path
|
||||||
|
2. Add a URL to urlpatterns: path('blog/', include('blog.urls'))
|
||||||
|
"""
|
||||||
|
from django.contrib import admin
|
||||||
|
from django.urls import path, include
|
||||||
|
from django.contrib.staticfiles.urls import staticfiles_urlpatterns
|
||||||
|
|
||||||
|
urlpatterns = [
|
||||||
|
path("admin/", admin.site.urls),
|
||||||
|
]
|
||||||
|
|
||||||
|
urlpatterns += staticfiles_urlpatterns()
|
16
src/app/wsgi.py
Normal file
16
src/app/wsgi.py
Normal file
|
@ -0,0 +1,16 @@
|
||||||
|
"""
|
||||||
|
WSGI config for app project.
|
||||||
|
|
||||||
|
It exposes the WSGI callable as a module-level variable named ``application``.
|
||||||
|
|
||||||
|
For more information on this file, see
|
||||||
|
https://docs.djangoproject.com/en/4.2/howto/deployment/wsgi/
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from django.core.wsgi import get_wsgi_application
|
||||||
|
|
||||||
|
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "app.settings")
|
||||||
|
|
||||||
|
application = get_wsgi_application()
|
0
src/database/__init__.py
Normal file
0
src/database/__init__.py
Normal file
78
src/database/adapters/__init__.py
Normal file
78
src/database/adapters/__init__.py
Normal file
|
@ -0,0 +1,78 @@
|
||||||
|
from typing import Type, TypeVar
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from django.db import models
|
||||||
|
from django.contrib.sessions.backends.db import SessionStore
|
||||||
|
|
||||||
|
# Import sync_to_async from Django Channels
|
||||||
|
from asgiref.sync import sync_to_async
|
||||||
|
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
from database.models import KhojUser, GoogleUser, NotionConfig
|
||||||
|
|
||||||
|
ModelType = TypeVar("ModelType", bound=models.Model)
|
||||||
|
|
||||||
|
|
||||||
|
async def retrieve_object(model_class: Type[ModelType], id: int) -> ModelType:
|
||||||
|
instance = await model_class.objects.filter(id=id).afirst()
|
||||||
|
if not instance:
|
||||||
|
raise HTTPException(status_code=404, detail=f"{model_class.__name__} not found")
|
||||||
|
return instance
|
||||||
|
|
||||||
|
|
||||||
|
async def set_notion_config(token: str, user: KhojUser):
|
||||||
|
notion_config = await NotionConfig.objects.filter(user=user).afirst()
|
||||||
|
if not notion_config:
|
||||||
|
notion_config = await NotionConfig.objects.acreate(token=token, user=user)
|
||||||
|
else:
|
||||||
|
notion_config.token = token
|
||||||
|
await notion_config.asave()
|
||||||
|
return notion_config
|
||||||
|
|
||||||
|
|
||||||
|
async def get_or_create_user(token: dict) -> KhojUser:
|
||||||
|
user = await get_user_by_token(token)
|
||||||
|
if not user:
|
||||||
|
user = await create_google_user(token)
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
async def create_google_user(token: dict) -> KhojUser:
|
||||||
|
user_info = token.get("userinfo")
|
||||||
|
user = await KhojUser.objects.acreate(
|
||||||
|
username=user_info.get("email"), email=user_info.get("email"), uuid=uuid.uuid4()
|
||||||
|
)
|
||||||
|
await user.asave()
|
||||||
|
await GoogleUser.objects.acreate(
|
||||||
|
sub=user_info.get("sub"),
|
||||||
|
azp=user_info.get("azp"),
|
||||||
|
email=user_info.get("email"),
|
||||||
|
name=user_info.get("name"),
|
||||||
|
given_name=user_info.get("given_name"),
|
||||||
|
family_name=user_info.get("family_name"),
|
||||||
|
picture=user_info.get("picture"),
|
||||||
|
locale=user_info.get("locale"),
|
||||||
|
user=user,
|
||||||
|
)
|
||||||
|
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
async def get_user_by_token(token: dict) -> KhojUser:
|
||||||
|
user_info = token.get("userinfo")
|
||||||
|
google_user = await GoogleUser.objects.filter(sub=user_info.get("sub")).select_related("user").afirst()
|
||||||
|
if not google_user:
|
||||||
|
return None
|
||||||
|
return google_user.user
|
||||||
|
|
||||||
|
|
||||||
|
async def retrieve_user(session_id: str) -> KhojUser:
|
||||||
|
session = SessionStore(session_key=session_id)
|
||||||
|
if not await sync_to_async(session.exists)(session_key=session_id):
|
||||||
|
raise HTTPException(status_code=401, detail="Invalid session")
|
||||||
|
session_data = await sync_to_async(session.load)()
|
||||||
|
user = await KhojUser.objects.filter(id=session_data.get("_auth_user_id")).afirst()
|
||||||
|
if not user:
|
||||||
|
raise HTTPException(status_code=401, detail="Invalid user")
|
||||||
|
return user
|
8
src/database/admin.py
Normal file
8
src/database/admin.py
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
from django.contrib import admin
|
||||||
|
from django.contrib.auth.admin import UserAdmin
|
||||||
|
|
||||||
|
# Register your models here.
|
||||||
|
|
||||||
|
from database.models import KhojUser
|
||||||
|
|
||||||
|
admin.site.register(KhojUser, UserAdmin)
|
6
src/database/apps.py
Normal file
6
src/database/apps.py
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
from django.apps import AppConfig
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseConfig(AppConfig):
|
||||||
|
default_auto_field = "django.db.models.BigAutoField"
|
||||||
|
name = "database"
|
98
src/database/migrations/0001_khojuser.py
Normal file
98
src/database/migrations/0001_khojuser.py
Normal file
|
@ -0,0 +1,98 @@
|
||||||
|
# Generated by Django 4.2.5 on 2023-09-14 19:00
|
||||||
|
|
||||||
|
import django.contrib.auth.models
|
||||||
|
import django.contrib.auth.validators
|
||||||
|
from django.db import migrations, models
|
||||||
|
import django.utils.timezone
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
initial = True
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
("auth", "0012_alter_user_first_name_max_length"),
|
||||||
|
]
|
||||||
|
|
||||||
|
run_before = [
|
||||||
|
("admin", "0001_initial"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.CreateModel(
|
||||||
|
name="KhojUser",
|
||||||
|
fields=[
|
||||||
|
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||||
|
("password", models.CharField(max_length=128, verbose_name="password")),
|
||||||
|
("last_login", models.DateTimeField(blank=True, null=True, verbose_name="last login")),
|
||||||
|
(
|
||||||
|
"is_superuser",
|
||||||
|
models.BooleanField(
|
||||||
|
default=False,
|
||||||
|
help_text="Designates that this user has all permissions without explicitly assigning them.",
|
||||||
|
verbose_name="superuser status",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"username",
|
||||||
|
models.CharField(
|
||||||
|
error_messages={"unique": "A user with that username already exists."},
|
||||||
|
help_text="Required. 150 characters or fewer. Letters, digits and @/./+/-/_ only.",
|
||||||
|
max_length=150,
|
||||||
|
unique=True,
|
||||||
|
validators=[django.contrib.auth.validators.UnicodeUsernameValidator()],
|
||||||
|
verbose_name="username",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
("first_name", models.CharField(blank=True, max_length=150, verbose_name="first name")),
|
||||||
|
("last_name", models.CharField(blank=True, max_length=150, verbose_name="last name")),
|
||||||
|
("email", models.EmailField(blank=True, max_length=254, verbose_name="email address")),
|
||||||
|
(
|
||||||
|
"is_staff",
|
||||||
|
models.BooleanField(
|
||||||
|
default=False,
|
||||||
|
help_text="Designates whether the user can log into this admin site.",
|
||||||
|
verbose_name="staff status",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"is_active",
|
||||||
|
models.BooleanField(
|
||||||
|
default=True,
|
||||||
|
help_text="Designates whether this user should be treated as active. Unselect this instead of deleting accounts.",
|
||||||
|
verbose_name="active",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
("date_joined", models.DateTimeField(default=django.utils.timezone.now, verbose_name="date joined")),
|
||||||
|
(
|
||||||
|
"groups",
|
||||||
|
models.ManyToManyField(
|
||||||
|
blank=True,
|
||||||
|
help_text="The groups this user belongs to. A user will get all permissions granted to each of their groups.",
|
||||||
|
related_name="user_set",
|
||||||
|
related_query_name="user",
|
||||||
|
to="auth.group",
|
||||||
|
verbose_name="groups",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"user_permissions",
|
||||||
|
models.ManyToManyField(
|
||||||
|
blank=True,
|
||||||
|
help_text="Specific permissions for this user.",
|
||||||
|
related_name="user_set",
|
||||||
|
related_query_name="user",
|
||||||
|
to="auth.permission",
|
||||||
|
verbose_name="user permissions",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
options={
|
||||||
|
"verbose_name": "user",
|
||||||
|
"verbose_name_plural": "users",
|
||||||
|
"abstract": False,
|
||||||
|
},
|
||||||
|
managers=[
|
||||||
|
("objects", django.contrib.auth.models.UserManager()),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
]
|
32
src/database/migrations/0002_googleuser.py
Normal file
32
src/database/migrations/0002_googleuser.py
Normal file
|
@ -0,0 +1,32 @@
|
||||||
|
# Generated by Django 4.2.4 on 2023-09-18 23:24
|
||||||
|
|
||||||
|
from django.conf import settings
|
||||||
|
from django.db import migrations, models
|
||||||
|
import django.db.models.deletion
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("database", "0001_khojuser"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.CreateModel(
|
||||||
|
name="GoogleUser",
|
||||||
|
fields=[
|
||||||
|
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||||
|
("sub", models.CharField(max_length=200)),
|
||||||
|
("azp", models.CharField(max_length=200)),
|
||||||
|
("email", models.CharField(max_length=200)),
|
||||||
|
("name", models.CharField(max_length=200)),
|
||||||
|
("given_name", models.CharField(max_length=200)),
|
||||||
|
("family_name", models.CharField(max_length=200)),
|
||||||
|
("picture", models.CharField(max_length=200)),
|
||||||
|
("locale", models.CharField(max_length=200)),
|
||||||
|
(
|
||||||
|
"user",
|
||||||
|
models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
]
|
|
@ -0,0 +1,79 @@
|
||||||
|
# Generated by Django 4.2.5 on 2023-09-27 17:52
|
||||||
|
|
||||||
|
from django.conf import settings
|
||||||
|
from django.db import migrations, models
|
||||||
|
import django.db.models.deletion
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("database", "0002_googleuser"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.CreateModel(
|
||||||
|
name="Configuration",
|
||||||
|
fields=[
|
||||||
|
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
migrations.CreateModel(
|
||||||
|
name="ConversationProcessorConfig",
|
||||||
|
fields=[
|
||||||
|
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||||
|
("conversation", models.JSONField()),
|
||||||
|
("enable_offline_chat", models.BooleanField(default=False)),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
migrations.CreateModel(
|
||||||
|
name="GithubConfig",
|
||||||
|
fields=[
|
||||||
|
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||||
|
("pat_token", models.CharField(max_length=200)),
|
||||||
|
("compressed_jsonl", models.CharField(max_length=300)),
|
||||||
|
("embeddings_file", models.CharField(max_length=300)),
|
||||||
|
(
|
||||||
|
"config",
|
||||||
|
models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to="database.configuration"),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="khojuser",
|
||||||
|
name="uuid",
|
||||||
|
field=models.UUIDField(verbose_name=models.UUIDField(default=uuid.uuid4, editable=False)),
|
||||||
|
preserve_default=False,
|
||||||
|
),
|
||||||
|
migrations.CreateModel(
|
||||||
|
name="NotionConfig",
|
||||||
|
fields=[
|
||||||
|
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||||
|
("token", models.CharField(max_length=200)),
|
||||||
|
("compressed_jsonl", models.CharField(max_length=300)),
|
||||||
|
("embeddings_file", models.CharField(max_length=300)),
|
||||||
|
(
|
||||||
|
"config",
|
||||||
|
models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to="database.configuration"),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
migrations.CreateModel(
|
||||||
|
name="GithubRepoConfig",
|
||||||
|
fields=[
|
||||||
|
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||||
|
("name", models.CharField(max_length=200)),
|
||||||
|
("owner", models.CharField(max_length=200)),
|
||||||
|
("branch", models.CharField(max_length=200)),
|
||||||
|
(
|
||||||
|
"github_config",
|
||||||
|
models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to="database.githubconfig"),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="configuration",
|
||||||
|
name="user",
|
||||||
|
field=models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL),
|
||||||
|
),
|
||||||
|
]
|
0
src/database/migrations/__init__.py
Normal file
0
src/database/migrations/__init__.py
Normal file
53
src/database/models/__init__.py
Normal file
53
src/database/models/__init__.py
Normal file
|
@ -0,0 +1,53 @@
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from django.db import models
|
||||||
|
from django.contrib.auth.models import AbstractUser
|
||||||
|
|
||||||
|
|
||||||
|
class KhojUser(AbstractUser):
|
||||||
|
uuid = models.UUIDField(models.UUIDField(default=uuid.uuid4, editable=False))
|
||||||
|
|
||||||
|
|
||||||
|
class GoogleUser(models.Model):
|
||||||
|
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
|
||||||
|
sub = models.CharField(max_length=200)
|
||||||
|
azp = models.CharField(max_length=200)
|
||||||
|
email = models.CharField(max_length=200)
|
||||||
|
name = models.CharField(max_length=200)
|
||||||
|
given_name = models.CharField(max_length=200)
|
||||||
|
family_name = models.CharField(max_length=200)
|
||||||
|
picture = models.CharField(max_length=200)
|
||||||
|
locale = models.CharField(max_length=200)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.name
|
||||||
|
|
||||||
|
|
||||||
|
class Configuration(models.Model):
|
||||||
|
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
|
||||||
|
|
||||||
|
|
||||||
|
class NotionConfig(models.Model):
|
||||||
|
token = models.CharField(max_length=200)
|
||||||
|
compressed_jsonl = models.CharField(max_length=300)
|
||||||
|
embeddings_file = models.CharField(max_length=300)
|
||||||
|
config = models.OneToOneField(Configuration, on_delete=models.CASCADE)
|
||||||
|
|
||||||
|
|
||||||
|
class GithubConfig(models.Model):
|
||||||
|
pat_token = models.CharField(max_length=200)
|
||||||
|
compressed_jsonl = models.CharField(max_length=300)
|
||||||
|
embeddings_file = models.CharField(max_length=300)
|
||||||
|
config = models.OneToOneField(Configuration, on_delete=models.CASCADE)
|
||||||
|
|
||||||
|
|
||||||
|
class GithubRepoConfig(models.Model):
|
||||||
|
name = models.CharField(max_length=200)
|
||||||
|
owner = models.CharField(max_length=200)
|
||||||
|
branch = models.CharField(max_length=200)
|
||||||
|
github_config = models.ForeignKey(GithubConfig, on_delete=models.CASCADE)
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationProcessorConfig(models.Model):
|
||||||
|
conversation = models.JSONField()
|
||||||
|
enable_offline_chat = models.BooleanField(default=False)
|
|
@ -5,10 +5,19 @@ import json
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
import requests
|
import requests
|
||||||
|
import os
|
||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
import schedule
|
import schedule
|
||||||
from fastapi.staticfiles import StaticFiles
|
from starlette.middleware.sessions import SessionMiddleware
|
||||||
|
from starlette.middleware.authentication import AuthenticationMiddleware
|
||||||
|
|
||||||
|
from starlette.authentication import (
|
||||||
|
AuthCredentials,
|
||||||
|
AuthenticationBackend,
|
||||||
|
SimpleUser,
|
||||||
|
UnauthenticatedUser,
|
||||||
|
)
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.utils import constants, state
|
from khoj.utils import constants, state
|
||||||
|
@ -26,8 +35,32 @@ from khoj.routers.indexer import configure_content, load_content, configure_sear
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def initialize_server(config: Optional[FullConfig], required=False):
|
class AuthenticatedKhojUser(SimpleUser):
|
||||||
if config is None and required:
|
def __init__(self, user):
|
||||||
|
self.object = user
|
||||||
|
super().__init__(user.email)
|
||||||
|
|
||||||
|
|
||||||
|
class UserAuthenticationBackend(AuthenticationBackend):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
):
|
||||||
|
from database.models import KhojUser
|
||||||
|
|
||||||
|
self.khojuser_manager = KhojUser.objects
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
async def authenticate(self, request):
|
||||||
|
current_user = request.session.get("user")
|
||||||
|
if current_user and current_user.get("email"):
|
||||||
|
user = await self.khojuser_manager.filter(email=current_user.get("email")).afirst()
|
||||||
|
if user:
|
||||||
|
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
|
||||||
|
return AuthCredentials(), UnauthenticatedUser()
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_server(config: Optional[FullConfig]):
|
||||||
|
if config is None:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"🚨 Exiting as Khoj is not configured.\nConfigure it via http://{state.host}:{state.port}/config or by editing {state.config_file}."
|
f"🚨 Exiting as Khoj is not configured.\nConfigure it via http://{state.host}:{state.port}/config or by editing {state.config_file}."
|
||||||
)
|
)
|
||||||
|
@ -99,12 +132,18 @@ def configure_routes(app):
|
||||||
from khoj.routers.api_beta import api_beta
|
from khoj.routers.api_beta import api_beta
|
||||||
from khoj.routers.web_client import web_client
|
from khoj.routers.web_client import web_client
|
||||||
from khoj.routers.indexer import indexer
|
from khoj.routers.indexer import indexer
|
||||||
|
from khoj.routers.auth import auth_router
|
||||||
|
|
||||||
app.mount("/static", StaticFiles(directory=constants.web_directory), name="static")
|
|
||||||
app.include_router(api, prefix="/api")
|
app.include_router(api, prefix="/api")
|
||||||
app.include_router(api_beta, prefix="/api/beta")
|
app.include_router(api_beta, prefix="/api/beta")
|
||||||
app.include_router(indexer, prefix="/v1/indexer")
|
app.include_router(indexer, prefix="/v1/indexer")
|
||||||
app.include_router(web_client)
|
app.include_router(web_client)
|
||||||
|
app.include_router(auth_router, prefix="/auth")
|
||||||
|
|
||||||
|
|
||||||
|
def configure_middleware(app):
|
||||||
|
app.add_middleware(AuthenticationMiddleware, backend=UserAuthenticationBackend())
|
||||||
|
app.add_middleware(SessionMiddleware, secret_key=os.environ.get("KHOJ_DJANGO_SECRET_KEY", "!secret"))
|
||||||
|
|
||||||
|
|
||||||
if not state.demo:
|
if not state.demo:
|
||||||
|
|
|
@ -170,7 +170,11 @@
|
||||||
|
|
||||||
// Execute Search and Render Results
|
// Execute Search and Render Results
|
||||||
url = createRequestUrl(query, type, results_count || 5, rerank);
|
url = createRequestUrl(query, type, results_count || 5, rerank);
|
||||||
fetch(url)
|
fetch(url, {
|
||||||
|
headers: {
|
||||||
|
"X-CSRFToken": csrfToken
|
||||||
|
}
|
||||||
|
})
|
||||||
.then(response => response.json())
|
.then(response => response.json())
|
||||||
.then(data => {
|
.then(data => {
|
||||||
console.log(data);
|
console.log(data);
|
||||||
|
|
59
src/khoj/routers/auth.py
Normal file
59
src/khoj/routers/auth.py
Normal file
|
@ -0,0 +1,59 @@
|
||||||
|
import logging
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from fastapi import APIRouter
|
||||||
|
from starlette.config import Config
|
||||||
|
from starlette.requests import Request
|
||||||
|
from starlette.responses import HTMLResponse, RedirectResponse
|
||||||
|
from authlib.integrations.starlette_client import OAuth, OAuthError
|
||||||
|
|
||||||
|
from database.adapters import get_or_create_user
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
auth_router = APIRouter()
|
||||||
|
|
||||||
|
if not os.environ.get("GOOGLE_CLIENT_ID") or not os.environ.get("GOOGLE_CLIENT_SECRET"):
|
||||||
|
logger.info("Please set GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET environment variables to use Google OAuth")
|
||||||
|
else:
|
||||||
|
config = Config(environ=os.environ)
|
||||||
|
|
||||||
|
oauth = OAuth(config)
|
||||||
|
|
||||||
|
CONF_URL = "https://accounts.google.com/.well-known/openid-configuration"
|
||||||
|
oauth.register(name="google", server_metadata_url=CONF_URL, client_kwargs={"scope": "openid email profile"})
|
||||||
|
|
||||||
|
|
||||||
|
@auth_router.get("/")
|
||||||
|
async def homepage(request: Request):
|
||||||
|
user = request.session.get("user")
|
||||||
|
if user:
|
||||||
|
data = json.dumps(user)
|
||||||
|
html = f"<pre>{data}</pre>" '<a href="/logout">logout</a>'
|
||||||
|
return HTMLResponse(html)
|
||||||
|
return HTMLResponse('<a href="/login">login</a>')
|
||||||
|
|
||||||
|
|
||||||
|
@auth_router.get("/login")
|
||||||
|
async def login(request: Request):
|
||||||
|
redirect_uri = request.url_for("auth")
|
||||||
|
return await oauth.google.authorize_redirect(request, redirect_uri)
|
||||||
|
|
||||||
|
|
||||||
|
@auth_router.get("/redirect")
|
||||||
|
async def auth(request: Request):
|
||||||
|
try:
|
||||||
|
token = await oauth.google.authorize_access_token(request)
|
||||||
|
except OAuthError as error:
|
||||||
|
return HTMLResponse(f"<h1>{error.error}</h1>")
|
||||||
|
khoj_user = await get_or_create_user(token)
|
||||||
|
user = token.get("userinfo")
|
||||||
|
if user:
|
||||||
|
request.session["user"] = dict(user)
|
||||||
|
return RedirectResponse(url="/")
|
||||||
|
|
||||||
|
|
||||||
|
@auth_router.get("/logout")
|
||||||
|
async def logout(request: Request):
|
||||||
|
request.session.pop("user", None)
|
||||||
|
return RedirectResponse(url="/")
|
|
@ -13,6 +13,10 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def collect_files(config: ContentConfig, search_type: Optional[SearchType] = SearchType.All):
|
def collect_files(config: ContentConfig, search_type: Optional[SearchType] = SearchType.All):
|
||||||
files = {}
|
files = {}
|
||||||
|
|
||||||
|
if config is None:
|
||||||
|
return files
|
||||||
|
|
||||||
if search_type == SearchType.All or search_type == SearchType.Org:
|
if search_type == SearchType.All or search_type == SearchType.Org:
|
||||||
files["org"] = get_org_files(config.org) if config.org else {}
|
files["org"] = get_org_files(config.org) if config.org else {}
|
||||||
if search_type == SearchType.All or search_type == SearchType.Markdown:
|
if search_type == SearchType.All or search_type == SearchType.Markdown:
|
||||||
|
|
22
src/manage.py
Executable file
22
src/manage.py
Executable file
|
@ -0,0 +1,22 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
"""Django's command-line utility for administrative tasks."""
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Run administrative tasks."""
|
||||||
|
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "app.settings")
|
||||||
|
try:
|
||||||
|
from django.core.management import execute_from_command_line
|
||||||
|
except ImportError as exc:
|
||||||
|
raise ImportError(
|
||||||
|
"Couldn't import Django. Are you sure it's installed and "
|
||||||
|
"available on your PYTHONPATH environment variable? Did you "
|
||||||
|
"forget to activate a virtual environment?"
|
||||||
|
) from exc
|
||||||
|
execute_from_command_line(sys.argv)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@ -4,14 +4,16 @@ from copy import deepcopy
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import pytest
|
import pytest
|
||||||
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.main import app
|
from app.main import app
|
||||||
from khoj.configure import configure_processor, configure_routes, configure_search_types
|
from khoj.configure import configure_processor, configure_routes, configure_search_types, configure_middleware
|
||||||
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
|
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
|
||||||
from khoj.processor.plaintext.plaintext_to_jsonl import PlaintextToJsonl
|
from khoj.processor.plaintext.plaintext_to_jsonl import PlaintextToJsonl
|
||||||
from khoj.search_type import image_search, text_search
|
from khoj.search_type import image_search, text_search
|
||||||
from khoj.utils.config import SearchModels
|
from khoj.utils.config import SearchModels
|
||||||
|
from khoj.utils.constants import web_directory
|
||||||
from khoj.utils.helpers import resolve_absolute_path
|
from khoj.utils.helpers import resolve_absolute_path
|
||||||
from khoj.utils.rawconfig import (
|
from khoj.utils.rawconfig import (
|
||||||
ContentConfig,
|
ContentConfig,
|
||||||
|
@ -231,6 +233,8 @@ def chat_client(md_content_config: ContentConfig, search_config: SearchConfig, p
|
||||||
state.processor_config = configure_processor(processor_config)
|
state.processor_config = configure_processor(processor_config)
|
||||||
|
|
||||||
configure_routes(app)
|
configure_routes(app)
|
||||||
|
configure_middleware(app)
|
||||||
|
app.mount("/static", StaticFiles(directory=web_directory), name="static")
|
||||||
return TestClient(app)
|
return TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
@ -264,6 +268,8 @@ def client(content_config: ContentConfig, search_config: SearchConfig, processor
|
||||||
state.processor_config = configure_processor(processor_config)
|
state.processor_config = configure_processor(processor_config)
|
||||||
|
|
||||||
configure_routes(app)
|
configure_routes(app)
|
||||||
|
configure_middleware(app)
|
||||||
|
app.mount("/static", StaticFiles(directory=web_directory), name="static")
|
||||||
return TestClient(app)
|
return TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
@ -292,6 +298,8 @@ def client_offline_chat(
|
||||||
state.processor_config = configure_processor(processor_config_offline_chat)
|
state.processor_config = configure_processor(processor_config_offline_chat)
|
||||||
|
|
||||||
configure_routes(app)
|
configure_routes(app)
|
||||||
|
configure_middleware(app)
|
||||||
|
app.mount("/static", StaticFiles(directory=web_directory), name="static")
|
||||||
return TestClient(app)
|
return TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -8,7 +8,7 @@ from urllib.parse import quote
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.main import app
|
from app.main import app
|
||||||
from khoj.configure import configure_routes, configure_search_types
|
from khoj.configure import configure_routes, configure_search_types
|
||||||
from khoj.utils import state
|
from khoj.utils import state
|
||||||
from khoj.utils.state import search_models, content_index, config
|
from khoj.utils.state import search_models, content_index, config
|
||||||
|
|
Loading…
Reference in a new issue