Add support for a first-party client app to call into Khoj (Part 1) (#601)

* Add support for a first-party client app
- Based on a client ID and client secret, allow a first-party app to call into the Khoj backend with a phone number identifier
- Add migration to add phone numbers to the KhojUser object
* Add plus in front of country code when registering a phone number.
- Decrease free tier limit to 5 (from 10)
- Return a response object when handling stripe webhooks
* Fix telemetry method which references authenticated user's client app
* Add better error handling for null phone numbers, simplify logic of authenticating user
* Pull the client_secret in the API call from the authorization header
* Add a migration merge to resolve phone number and other changes
This commit is contained in:
sabaimran 2024-01-18 05:54:14 -08:00 committed by GitHub
parent 9dfe1bb003
commit 039ed78253
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 204 additions and 18 deletions

View file

@ -76,6 +76,8 @@ dependencies = [
"rapidocr-onnxruntime == 1.3.8",
"stripe == 7.3.0",
"openai-whisper >= 20231117",
"django-phonenumber-field == 7.3.0",
"phonenumbers == 8.13.27",
]
dynamic = ["version"]

View file

@ -62,6 +62,7 @@ INSTALLED_APPS = [
"django.contrib.sessions",
"django.contrib.messages",
"django.contrib.staticfiles",
"phonenumber_field",
]
MIDDLEWARE = [

View file

@ -9,6 +9,7 @@ import openai
import requests
import schedule
from django.utils.timezone import make_aware
from fastapi import Response
from starlette.authentication import (
AuthCredentials,
AuthenticationBackend,
@ -20,27 +21,32 @@ from starlette.middleware.sessions import SessionMiddleware
from starlette.requests import HTTPConnection
from khoj.database.adapters import (
ClientApplicationAdapters,
ConversationAdapters,
SubscriptionState,
aget_or_create_user_by_phone_number,
aget_user_by_phone_number,
aget_user_subscription_state,
get_all_users,
get_or_create_search_models,
)
from khoj.database.models import KhojUser, Subscription
from khoj.database.models import ClientApplication, KhojUser, Subscription
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
from khoj.routers.indexer import configure_content, configure_search, load_content
from khoj.utils import constants, state
from khoj.utils.config import SearchType
from khoj.utils.fs_syncer import collect_files
from khoj.utils.helpers import is_none_or_empty
from khoj.utils.rawconfig import FullConfig
logger = logging.getLogger(__name__)
class AuthenticatedKhojUser(SimpleUser):
def __init__(self, user):
def __init__(self, user, client_app: Optional[ClientApplication] = None):
self.object = user
super().__init__(user.email)
self.client_app = client_app
super().__init__(user.username)
class UserAuthenticationBackend(AuthenticationBackend):
@ -108,6 +114,53 @@ class UserAuthenticationBackend(AuthenticationBackend):
if subscribed:
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user)
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user)
# Get query params for client_id and client_secret
client_id = request.query_params.get("client_id")
if client_id:
# Get the client secret, which is passed in the Authorization header
client_secret = request.headers["Authorization"].split("Bearer ")[1]
if not client_secret:
return Response(
status_code=401,
content="Please provide a client secret in the Authorization header with a client_id query param.",
)
# Get the client application
client_application = await ClientApplicationAdapters.aget_client_application_by_id(client_id, client_secret)
if client_application is None:
return AuthCredentials(), UnauthenticatedUser()
# Get the identifier used for the user
phone_number = request.query_params.get("phone_number")
if is_none_or_empty(phone_number):
return AuthCredentials(), UnauthenticatedUser()
if not phone_number.startswith("+"):
phone_number = f"+{phone_number}"
create_if_not_exists = request.query_params.get("create_if_not_exists")
if create_if_not_exists:
user = await aget_or_create_user_by_phone_number(phone_number)
else:
user = await aget_user_by_phone_number(phone_number)
if user is None:
return AuthCredentials(), UnauthenticatedUser()
if not state.billing_enabled:
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user, client_application)
subscription_state = await aget_user_subscription_state(user)
subscribed = (
subscription_state == SubscriptionState.SUBSCRIBED.value
or subscription_state == SubscriptionState.TRIAL.value
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
)
if subscribed:
return (
AuthCredentials(["authenticated", "premium"]),
AuthenticatedKhojUser(user),
)
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user, client_application)
if state.anonymous_mode:
user = await self.khojuser_manager.filter(username="default").prefetch_related("subscription").afirst()
if user:

View file

@ -17,6 +17,7 @@ from torch import Tensor
from khoj.database.models import (
ChatModelOptions,
ClientApplication,
Conversation,
Entry,
GithubConfig,
@ -40,7 +41,7 @@ from khoj.search_filter.file_filter import FileFilter
from khoj.search_filter.word_filter import WordFilter
from khoj.utils import state
from khoj.utils.config import GPT4AllProcessorModel
from khoj.utils.helpers import generate_random_name
from khoj.utils.helpers import generate_random_name, is_none_or_empty
class SubscriptionState(Enum):
@ -85,6 +86,28 @@ async def get_or_create_user(token: dict) -> KhojUser:
return user
async def aget_or_create_user_by_phone_number(phone_number: str) -> KhojUser:
    """Fetch the user registered under ``phone_number``, creating one if absent.

    Args:
        phone_number: Phone number identifying the user.

    Returns:
        The matching or newly created ``KhojUser``. ``None`` is returned when
        ``is_none_or_empty`` reports the phone number as missing.
    """
    if is_none_or_empty(phone_number):
        return None

    # Prefer an existing account; only fall through to creation on a miss.
    existing_user = await aget_user_by_phone_number(phone_number)
    if existing_user:
        return existing_user
    return await acreate_user_by_phone_number(phone_number)
async def acreate_user_by_phone_number(phone_number: str) -> KhojUser:
    """Create (or update) the user identified by ``phone_number`` with a trial subscription.

    The phone number doubles as the username of the account.

    Args:
        phone_number: Phone number identifying the user.

    Returns:
        The created or updated ``KhojUser``. ``None`` is returned when
        ``is_none_or_empty`` reports the phone number as missing.
    """
    if is_none_or_empty(phone_number):
        return None

    user, _ = await KhojUser.objects.filter(phone_number=phone_number).aupdate_or_create(
        defaults={"username": phone_number, "phone_number": phone_number}
    )
    # Kept for parity with get_or_create_user_by_email, which also saves explicitly.
    await user.asave()

    # aupdate_or_create may hand back a user that already existed (direct call,
    # or two concurrent first requests for the same number). Unconditionally
    # creating a subscription would then add a second row for that user, or
    # fail outright if the relation is one-to-one (model not visible here).
    # get-or-create keeps this step idempotent.
    await Subscription.objects.aget_or_create(user=user, defaults={"type": "trial"})
    return user
async def get_or_create_user_by_email(email: str) -> KhojUser:
user, _ = await KhojUser.objects.filter(email=email).aupdate_or_create(defaults={"username": email, "email": email})
await user.asave()
@ -187,6 +210,12 @@ async def get_user_by_token(token: dict) -> KhojUser:
return google_user.user
async def aget_user_by_phone_number(phone_number: str) -> KhojUser:
    """Look up the user registered under ``phone_number``.

    Args:
        phone_number: Phone number identifying the user.

    Returns:
        The first matching ``KhojUser`` with its ``subscription`` relation
        prefetched, or ``None`` when the phone number is missing or no user
        matches.
    """
    if is_none_or_empty(phone_number):
        return None

    matching_users = KhojUser.objects.filter(phone_number=phone_number).prefetch_related("subscription")
    return await matching_users.afirst()
async def retrieve_user(session_id: str) -> KhojUser:
session = SessionStore(session_key=session_id)
if not await sync_to_async(session.exists)(session_key=session_id):
@ -270,6 +299,12 @@ async def aset_user_search_model(user: KhojUser, search_model_config_id: int):
return new_config
class ClientApplicationAdapters:
    """Database accessors for ``ClientApplication`` records."""

    @staticmethod
    async def aget_client_application_by_id(client_id: str, client_secret: str):
        """Return the client application matching both credentials.

        Args:
            client_id: Identifier of the client application.
            client_secret: Secret that must match the stored one.

        Returns:
            The first ``ClientApplication`` whose ``client_id`` and
            ``client_secret`` both match, or ``None`` if there is none.
        """
        credentials = {"client_id": client_id, "client_secret": client_secret}
        return await ClientApplication.objects.filter(**credentials).afirst()
class ConversationAdapters:
@staticmethod
def get_conversation_by_user(user: KhojUser):
@ -279,11 +314,11 @@ class ConversationAdapters:
return Conversation.objects.create(user=user)
@staticmethod
async def aget_conversation_by_user(user: KhojUser):
conversation = Conversation.objects.filter(user=user)
async def aget_conversation_by_user(user: KhojUser, client_application: ClientApplication = None):
conversation = Conversation.objects.filter(user=user, client=client_application)
if await conversation.aexists():
return await conversation.afirst()
return await Conversation.objects.acreate(user=user)
return await Conversation.objects.acreate(user=user, client=client_application)
@staticmethod
async def adelete_conversation_by_user(user: KhojUser):

View file

@ -7,6 +7,7 @@ from django.http import HttpResponse
from khoj.database.models import (
ChatModelOptions,
ClientApplication,
Conversation,
KhojUser,
OfflineChatProcessorConversationConfig,
@ -19,10 +20,24 @@ from khoj.database.models import (
UserSearchModelConfig,
)
# Register your models here.
class KhojUserAdmin(UserAdmin):
    """Admin configuration for KhojUser that surfaces the phone number field."""

    # Columns shown in the user change list.
    list_display = (
        "id",
        "email",
        "username",
        "is_active",
        "is_staff",
        "is_superuser",
        "phone_number",
    )
    # Fields matched by the admin search box.
    search_fields = ("email", "username", "phone_number")
    filter_horizontal = ("groups", "user_permissions")

    # Put the phone number section ahead of the stock UserAdmin sections.
    fieldsets = (
        ("Personal info", {"fields": ("phone_number",)}),
        *UserAdmin.fieldsets,
    )
admin.site.register(KhojUser, UserAdmin)
admin.site.register(KhojUser, KhojUserAdmin)
admin.site.register(ChatModelOptions)
admin.site.register(SpeechToTextModelOptions)
@ -33,6 +48,7 @@ admin.site.register(Subscription)
admin.site.register(ReflectiveQuestion)
admin.site.register(UserSearchModelConfig)
admin.site.register(TextToImageModelConfig)
admin.site.register(ClientApplication)
@admin.register(Conversation)

View file

@ -0,0 +1,46 @@
# Generated by Django 4.2.7 on 2024-01-04 12:22
import django.db.models.deletion
import phonenumber_field.modelfields
from django.db import migrations, models
class Migration(migrations.Migration):
    # Schema migration for first-party client app support: creates the
    # ClientApplication table, adds an optional phone number to KhojUser and
    # links each Conversation to an optional client application.

    dependencies = [
        ("database", "0024_alter_entry_embeddings"),
    ]

    operations = [
        # New table holding the name and credentials of a client application.
        migrations.CreateModel(
            name="ClientApplication",
            fields=[
                ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
                ("created_at", models.DateTimeField(auto_now_add=True)),
                ("updated_at", models.DateTimeField(auto_now=True)),
                ("name", models.CharField(max_length=200)),
                ("client_id", models.CharField(max_length=200)),
                ("client_secret", models.CharField(max_length=200)),
            ],
            options={
                "abstract": False,
            },
        ),
        # Nullable with a None default, so existing users need no phone number.
        migrations.AddField(
            model_name="khojuser",
            name="phone_number",
            field=phonenumber_field.modelfields.PhoneNumberField(
                blank=True, default=None, max_length=128, null=True, region=None
            ),
        ),
        # Nullable link from a conversation to a client application.
        # CASCADE: deleting a client application deletes its conversations.
        migrations.AddField(
            model_name="conversation",
            name="client",
            field=models.ForeignKey(
                blank=True,
                default=None,
                null=True,
                on_delete=django.db.models.deletion.CASCADE,
                to="database.clientapplication",
            ),
        ),
    ]

View file

@ -0,0 +1,13 @@
# Generated by Django 4.2.7 on 2024-01-18 13:24
from typing import List
from django.db import migrations
class Migration(migrations.Migration):
    # Merge migration: joins the two 0025/0026 branches of the "database" app
    # into a single history. It depends on both and applies no schema changes.

    dependencies = [
        ("database", "0025_clientapplication_khojuser_phone_number_and_more"),
        ("database", "0026_searchmodelconfig_cross_encoder_inference_endpoint_and_more"),
    ]

    # Intentionally empty: nothing to change, only the dependency graph matters.
    operations: List[str] = []

View file

@ -3,6 +3,7 @@ import uuid
from django.contrib.auth.models import AbstractUser
from django.db import models
from pgvector.django import VectorField
from phonenumber_field.modelfields import PhoneNumberField
class BaseModel(models.Model):
@ -13,8 +14,18 @@ class BaseModel(models.Model):
abstract = True
class ClientApplication(BaseModel):
    # A client application identified by a client_id / client_secret pair.

    # Human-readable label; also used as the string representation.
    name = models.CharField(max_length=200)
    # NOTE(review): client_id carries no uniqueness constraint - confirm intended.
    client_id = models.CharField(max_length=200)
    # NOTE(review): kept in a plain CharField - consider hashing secrets at rest.
    client_secret = models.CharField(max_length=200)

    def __str__(self) -> str:
        # Show the application by its name.
        return self.name
class KhojUser(AbstractUser):
uuid = models.UUIDField(models.UUIDField(default=uuid.uuid4, editable=False))
phone_number = PhoneNumberField(null=True, default=None, blank=True)
def save(self, *args, **kwargs):
if not self.uuid:
@ -165,6 +176,7 @@ class UserSearchModelConfig(BaseModel):
class Conversation(BaseModel):
    # A user's conversation log, optionally scoped to a client application.

    # Owner of the conversation; deleting the user deletes the conversation.
    user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
    # JSON payload of the conversation; defaults to an empty dict.
    conversation_log = models.JSONField(default=dict)
    # Optional client application; None when no client app is associated.
    # CASCADE: deleting the client application deletes its conversations.
    client = models.ForeignKey(ClientApplication, on_delete=models.CASCADE, default=None, null=True, blank=True)
class ReflectiveQuestion(BaseModel):

View file

@ -359,8 +359,8 @@ async def chat(
n: Optional[int] = 5,
d: Optional[float] = 0.18,
stream: Optional[bool] = False,
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=60, window=60)),
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)),
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60)),
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24)),
) -> Response:
user: KhojUser = request.user.object
q = unquote(q)
@ -372,7 +372,7 @@ async def chat(
q = q.replace(f"/{conversation_command.value}", "").strip()
meta_log = (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log
meta_log = (await ConversationAdapters.aget_conversation_by_user(user, request.user.client_app)).conversation_log
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
request, common, meta_log, q, (n or 5), (d or math.inf), conversation_command
@ -392,7 +392,11 @@ async def chat(
elif conversation_command == ConversationCommand.Notes and not await EntryAdapters.auser_has_entries(user):
no_entries_found_format = no_entries_found.format()
return StreamingResponse(iter([no_entries_found_format]), media_type="text/event-stream", status_code=200)
if stream:
return StreamingResponse(iter([no_entries_found_format]), media_type="text/event-stream", status_code=200)
else:
response_obj = {"response": no_entries_found_format}
return Response(content=json.dumps(response_obj), media_type="text/plain", status_code=200)
elif conversation_command == ConversationCommand.Online:
try:

View file

@ -13,7 +13,12 @@ from fastapi import Depends, Header, HTTPException, Request, UploadFile
from starlette.authentication import has_required_scope
from khoj.database.adapters import ConversationAdapters, EntryAdapters
from khoj.database.models import KhojUser, Subscription, TextToImageModelConfig
from khoj.database.models import (
ClientApplication,
KhojUser,
Subscription,
TextToImageModelConfig,
)
from khoj.processor.conversation import prompts
from khoj.processor.conversation.offline.chat_model import (
converse_offline,
@ -74,6 +79,7 @@ def update_telemetry_state(
metadata: Optional[dict] = None,
):
user: KhojUser = request.user.object if request.user.is_authenticated else None
client_app: ClientApplication = request.user.client_app if request.user.is_authenticated else None
subscription: Subscription = user.subscription if user and hasattr(user, "subscription") else None
user_state = {
"client_host": request.client.host if request.client else None,
@ -83,6 +89,7 @@ def update_telemetry_state(
"server_id": str(user.uuid) if user else None,
"subscription_type": subscription.type if subscription else None,
"is_recurring": subscription.is_recurring if subscription else None,
"client_id": str(client_app.name) if client_app else None,
}
if metadata:
@ -113,10 +120,6 @@ def get_conversation_command(query: str, any_references: bool = False) -> Conver
return ConversationCommand.Default
async def construct_conversation_logs(user: KhojUser):
return (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log
async def agenerate_chat_response(*args):
    """Run ``generate_chat_response`` in the module's executor and await its result.

    Args:
        *args: Positional arguments forwarded unchanged to
            ``generate_chat_response``.

    Returns:
        Whatever ``generate_chat_response`` returns.
    """
    # Inside a coroutine a loop is always running; get_running_loop() is the
    # recommended accessor and never creates a new loop implicitly.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(executor, generate_chat_response, *args)

View file

@ -5,6 +5,7 @@ from datetime import datetime, timezone
import stripe
from asgiref.sync import sync_to_async
from fastapi import APIRouter, Request
from fastapi.responses import Response
from starlette.authentication import requires
from khoj.database import adapters