Move production dependencies to prod python packages group

This will reduce khoj dependencies to install for self-hosting users

- Move auth production dependencies to prod python packages group
  - Only enable authentication API router if not in anonymous mode
  - Improve error with requirements to enable authentication when not in
    anonymous mode
This commit is contained in:
Debanjum Singh Solanky 2024-02-14 15:20:27 +05:30
parent d7dbb715ef
commit cf4a524988
6 changed files with 47 additions and 25 deletions

View file

@ -13,7 +13,7 @@ COPY pyproject.toml .
COPY README.md . COPY README.md .
ARG VERSION=0.0.0 ARG VERSION=0.0.0
RUN sed -i "s/dynamic = \\[\"version\"\\]/version = \"$VERSION\"/" pyproject.toml && \ RUN sed -i "s/dynamic = \\[\"version\"\\]/version = \"$VERSION\"/" pyproject.toml && \
TMPDIR=/home/cache/ pip install --cache-dir=/home/cache/ -e . TMPDIR=/home/cache/ pip install --cache-dir=/home/cache/ -e .[prod]
# Copy Source Code # Copy Source Code
COPY . . COPY . .

View file

@ -69,17 +69,13 @@ dependencies = [
"httpx == 0.25.0", "httpx == 0.25.0",
"pgvector == 0.2.4", "pgvector == 0.2.4",
"psycopg2-binary == 2.9.9", "psycopg2-binary == 2.9.9",
"google-auth == 2.23.3",
"python-multipart == 0.0.6",
"gunicorn == 21.2.0", "gunicorn == 21.2.0",
"lxml == 4.9.3", "lxml == 4.9.3",
"tzdata == 2023.3", "tzdata == 2023.3",
"rapidocr-onnxruntime == 1.3.8", "rapidocr-onnxruntime == 1.3.8",
"stripe == 7.3.0",
"openai-whisper >= 20231117", "openai-whisper >= 20231117",
"django-phonenumber-field == 7.3.0", "django-phonenumber-field == 7.3.0",
"phonenumbers == 8.13.27", "phonenumbers == 8.13.27",
"twilio == 8.11"
] ]
dynamic = ["version"] dynamic = ["version"]
@ -93,6 +89,11 @@ Releases = "https://github.com/khoj-ai/khoj/releases"
khoj = "khoj.main:run" khoj = "khoj.main:run"
[project.optional-dependencies] [project.optional-dependencies]
prod = [
"google-auth == 2.23.3",
"stripe == 7.3.0",
"twilio == 8.11",
]
test = [ test = [
"pytest >= 7.1.2", "pytest >= 7.1.2",
"freezegun >= 1.2.0", "freezegun >= 1.2.0",
@ -103,6 +104,7 @@ test = [
] ]
dev = [ dev = [
"khoj-assistant[test]", "khoj-assistant[test]",
"khoj-assistant[prod]",
"mypy >= 1.0.1", "mypy >= 1.0.1",
"black >= 23.1.0", "black >= 23.1.0",
"pre-commit >= 3.0.4", "pre-commit >= 3.0.4",

View file

@ -73,6 +73,7 @@ class UserAuthenticationBackend(AuthenticationBackend):
Subscription.objects.create(user=default_user, type="standard", renewal_date=renewal_date) Subscription.objects.create(user=default_user, type="standard", renewal_date=renewal_date)
async def authenticate(self, request: HTTPConnection): async def authenticate(self, request: HTTPConnection):
# Request from Web client
current_user = request.session.get("user") current_user = request.session.get("user")
if current_user and current_user.get("email"): if current_user and current_user.get("email"):
user = ( user = (
@ -93,6 +94,8 @@ class UserAuthenticationBackend(AuthenticationBackend):
if subscribed: if subscribed:
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user) return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user)
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user) return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
# Request from Desktop, Emacs, Obsidian clients
if len(request.headers.get("Authorization", "").split("Bearer ")) == 2: if len(request.headers.get("Authorization", "").split("Bearer ")) == 2:
# Get bearer token from header # Get bearer token from header
bearer_token = request.headers["Authorization"].split("Bearer ")[1] bearer_token = request.headers["Authorization"].split("Bearer ")[1]
@ -116,7 +119,8 @@ class UserAuthenticationBackend(AuthenticationBackend):
if subscribed: if subscribed:
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user) return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user)
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user) return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user)
# Get query params for client_id and client_secret
# Request from Whatsapp client
client_id = request.query_params.get("client_id") client_id = request.query_params.get("client_id")
if client_id: if client_id:
# Get the client secret, which is passed in the Authorization header # Get the client secret, which is passed in the Authorization header
@ -163,6 +167,8 @@ class UserAuthenticationBackend(AuthenticationBackend):
AuthenticatedKhojUser(user, client_application), AuthenticatedKhojUser(user, client_application),
) )
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user, client_application) return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user, client_application)
# No auth required if server in anonymous mode
if state.anonymous_mode: if state.anonymous_mode:
user = await self.khojuser_manager.filter(username="default").prefetch_related("subscription").afirst() user = await self.khojuser_manager.filter(username="default").prefetch_related("subscription").afirst()
if user: if user:
@ -258,28 +264,32 @@ def configure_routes(app):
from khoj.routers.api import api from khoj.routers.api import api
from khoj.routers.api_chat import api_chat from khoj.routers.api_chat import api_chat
from khoj.routers.api_config import api_config from khoj.routers.api_config import api_config
from khoj.routers.auth import auth_router
from khoj.routers.indexer import indexer from khoj.routers.indexer import indexer
from khoj.routers.web_client import web_client from khoj.routers.web_client import web_client
app.include_router(api, prefix="/api") app.include_router(api, prefix="/api")
app.include_router(api_chat, prefix="/api/chat")
app.include_router(api_config, prefix="/api/config") app.include_router(api_config, prefix="/api/config")
app.include_router(indexer, prefix="/api/v1/index") app.include_router(indexer, prefix="/api/v1/index")
app.include_router(web_client) app.include_router(web_client)
app.include_router(auth_router, prefix="/auth")
app.include_router(api_chat, prefix="/api/chat") if not state.anonymous_mode:
from khoj.routers.auth import auth_router
app.include_router(auth_router, prefix="/auth")
logger.info("🔑 Enabled Authentication")
if state.billing_enabled: if state.billing_enabled:
from khoj.routers.subscription import subscription_router from khoj.routers.subscription import subscription_router
logger.info("💳 Enabled Billing")
app.include_router(subscription_router, prefix="/api/subscription") app.include_router(subscription_router, prefix="/api/subscription")
logger.info("💳 Enabled Billing")
if is_twilio_enabled(): if is_twilio_enabled():
logger.info("📞 Enabled Twilio")
from khoj.routers.api_phone import api_phone from khoj.routers.api_phone import api_phone
app.include_router(api_phone, prefix="/api/config/phone") app.include_router(api_phone, prefix="/api/config/phone")
logger.info("📞 Enabled Twilio")
def configure_middleware(app): def configure_middleware(app):

View file

@ -2,10 +2,7 @@ import logging
import os import os
from typing import Optional from typing import Optional
from authlib.integrations.starlette_client import OAuth, OAuthError
from fastapi import APIRouter from fastapi import APIRouter
from google.auth.transport import requests as google_requests
from google.oauth2 import id_token
from starlette.authentication import requires from starlette.authentication import requires
from starlette.config import Config from starlette.config import Config
from starlette.requests import Request from starlette.requests import Request
@ -17,7 +14,6 @@ from khoj.database.adapters import (
get_khoj_tokens, get_khoj_tokens,
get_or_create_user, get_or_create_user,
) )
from khoj.database.models import KhojApiUser
from khoj.routers.helpers import update_telemetry_state from khoj.routers.helpers import update_telemetry_state
from khoj.utils import state from khoj.utils import state
@ -25,11 +21,23 @@ logger = logging.getLogger(__name__)
auth_router = APIRouter() auth_router = APIRouter()
if not state.anonymous_mode and not (os.environ.get("GOOGLE_CLIENT_ID") and os.environ.get("GOOGLE_CLIENT_SECRET")):
logger.warning( if not state.anonymous_mode:
"🚨 Use --anonymous-mode flag to disable Google OAuth or set GOOGLE_CLIENT_ID, GOOGLE_CLIENT_SECRET environment variables to enable it" missing_requirements = []
) from authlib.integrations.starlette_client import OAuth, OAuthError
else:
try:
from google.auth.transport import requests as google_requests
from google.oauth2 import id_token
except ImportError:
missing_requirements += ["Install the Khoj production package with `pip install khoj-assistant[prod]`"]
if not os.environ.get("GOOGLE_CLIENT_ID") or not os.environ.get("GOOGLE_CLIENT_SECRET"):
missing_requirements += ["Set your GOOGLE_CLIENT_ID, GOOGLE_CLIENT_SECRET as environment variables"]
if missing_requirements:
requirements_string = "\n - " + "\n - ".join(missing_requirements)
error_msg = f"🚨 Start Khoj with --anonymous-mode flag or to enable authentication:{requirements_string}"
logger.error(error_msg)
config = Config(environ=os.environ) config = Config(environ=os.environ)
oauth = OAuth(config) oauth = OAuth(config)

View file

@ -2,16 +2,18 @@ import logging
import os import os
from datetime import datetime, timezone from datetime import datetime, timezone
import stripe
from asgiref.sync import sync_to_async from asgiref.sync import sync_to_async
from fastapi import APIRouter, Request from fastapi import APIRouter, Request
from fastapi.responses import Response
from starlette.authentication import requires from starlette.authentication import requires
from khoj.database import adapters from khoj.database import adapters
from khoj.utils import state
# Stripe integration for Khoj Cloud Subscription # Stripe integration for Khoj Cloud Subscription
stripe.api_key = os.getenv("STRIPE_API_KEY") if state.billing_enabled:
import stripe
stripe.api_key = os.getenv("STRIPE_API_KEY")
endpoint_secret = os.getenv("STRIPE_SIGNING_SECRET") endpoint_secret = os.getenv("STRIPE_SIGNING_SECRET")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
subscription_router = APIRouter() subscription_router = APIRouter()

View file

@ -1,8 +1,6 @@
import logging import logging
import os import os
from twilio.rest import Client
from khoj.database.models import KhojUser from khoj.database.models import KhojUser
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -13,6 +11,8 @@ verification_service_sid = os.getenv("TWILIO_VERIFICATION_SID")
twilio_enabled = account_sid is not None and auth_token is not None and verification_service_sid is not None twilio_enabled = account_sid is not None and auth_token is not None and verification_service_sid is not None
if twilio_enabled: if twilio_enabled:
from twilio.rest import Client
client = Client(account_sid, auth_token) client = Client(account_sid, auth_token)