mirror of
https://github.com/khoj-ai/khoj.git
synced 2025-02-17 08:04:21 +00:00
Add support for a first-party client app to call into Khoj (Part 1) (#601)
* Add support for a first party client app - Based on a client id and client secret, allow a first party app to call into the Khoj backend with a phone number identifier - Add migration to add phone numbers to the KhojUser object * Add plus in front of country code when registering a phone number. - Decrease free tier limit to 5 (from 10) - Return a response object when handling stripe webhooks * Fix telemetry method which references authenticated user's client app * Add better error handling for null phone numbers, simplify logic of authenticating user * Pull the client_secret in the API call from the authorization header * Add a migration merge to resolve phone number and other changes
This commit is contained in:
parent
9dfe1bb003
commit
039ed78253
11 changed files with 204 additions and 18 deletions
|
@ -76,6 +76,8 @@ dependencies = [
|
|||
"rapidocr-onnxruntime == 1.3.8",
|
||||
"stripe == 7.3.0",
|
||||
"openai-whisper >= 20231117",
|
||||
"django-phonenumber-field == 7.3.0",
|
||||
"phonenumbers == 8.13.27",
|
||||
]
|
||||
dynamic = ["version"]
|
||||
|
||||
|
|
|
@ -62,6 +62,7 @@ INSTALLED_APPS = [
|
|||
"django.contrib.sessions",
|
||||
"django.contrib.messages",
|
||||
"django.contrib.staticfiles",
|
||||
"phonenumber_field",
|
||||
]
|
||||
|
||||
MIDDLEWARE = [
|
||||
|
|
|
@ -9,6 +9,7 @@ import openai
|
|||
import requests
|
||||
import schedule
|
||||
from django.utils.timezone import make_aware
|
||||
from fastapi import Response
|
||||
from starlette.authentication import (
|
||||
AuthCredentials,
|
||||
AuthenticationBackend,
|
||||
|
@ -20,27 +21,32 @@ from starlette.middleware.sessions import SessionMiddleware
|
|||
from starlette.requests import HTTPConnection
|
||||
|
||||
from khoj.database.adapters import (
|
||||
ClientApplicationAdapters,
|
||||
ConversationAdapters,
|
||||
SubscriptionState,
|
||||
aget_or_create_user_by_phone_number,
|
||||
aget_user_by_phone_number,
|
||||
aget_user_subscription_state,
|
||||
get_all_users,
|
||||
get_or_create_search_models,
|
||||
)
|
||||
from khoj.database.models import KhojUser, Subscription
|
||||
from khoj.database.models import ClientApplication, KhojUser, Subscription
|
||||
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
|
||||
from khoj.routers.indexer import configure_content, configure_search, load_content
|
||||
from khoj.utils import constants, state
|
||||
from khoj.utils.config import SearchType
|
||||
from khoj.utils.fs_syncer import collect_files
|
||||
from khoj.utils.helpers import is_none_or_empty
|
||||
from khoj.utils.rawconfig import FullConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuthenticatedKhojUser(SimpleUser):
|
||||
def __init__(self, user):
|
||||
def __init__(self, user, client_app: Optional[ClientApplication] = None):
|
||||
self.object = user
|
||||
super().__init__(user.email)
|
||||
self.client_app = client_app
|
||||
super().__init__(user.username)
|
||||
|
||||
|
||||
class UserAuthenticationBackend(AuthenticationBackend):
|
||||
|
@ -108,6 +114,53 @@ class UserAuthenticationBackend(AuthenticationBackend):
|
|||
if subscribed:
|
||||
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user)
|
||||
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user)
|
||||
# Get query params for client_id and client_secret
|
||||
client_id = request.query_params.get("client_id")
|
||||
if client_id:
|
||||
# Get the client secret, which is passed in the Authorization header
|
||||
client_secret = request.headers["Authorization"].split("Bearer ")[1]
|
||||
if not client_secret:
|
||||
return Response(
|
||||
status_code=401,
|
||||
content="Please provide a client secret in the Authorization header with a client_id query param.",
|
||||
)
|
||||
|
||||
# Get the client application
|
||||
client_application = await ClientApplicationAdapters.aget_client_application_by_id(client_id, client_secret)
|
||||
if client_application is None:
|
||||
return AuthCredentials(), UnauthenticatedUser()
|
||||
# Get the identifier used for the user
|
||||
phone_number = request.query_params.get("phone_number")
|
||||
if is_none_or_empty(phone_number):
|
||||
return AuthCredentials(), UnauthenticatedUser()
|
||||
|
||||
if not phone_number.startswith("+"):
|
||||
phone_number = f"+{phone_number}"
|
||||
|
||||
create_if_not_exists = request.query_params.get("create_if_not_exists")
|
||||
if create_if_not_exists:
|
||||
user = await aget_or_create_user_by_phone_number(phone_number)
|
||||
else:
|
||||
user = await aget_user_by_phone_number(phone_number)
|
||||
|
||||
if user is None:
|
||||
return AuthCredentials(), UnauthenticatedUser()
|
||||
|
||||
if not state.billing_enabled:
|
||||
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user, client_application)
|
||||
|
||||
subscription_state = await aget_user_subscription_state(user)
|
||||
subscribed = (
|
||||
subscription_state == SubscriptionState.SUBSCRIBED.value
|
||||
or subscription_state == SubscriptionState.TRIAL.value
|
||||
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
|
||||
)
|
||||
if subscribed:
|
||||
return (
|
||||
AuthCredentials(["authenticated", "premium"]),
|
||||
AuthenticatedKhojUser(user),
|
||||
)
|
||||
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user, client_application)
|
||||
if state.anonymous_mode:
|
||||
user = await self.khojuser_manager.filter(username="default").prefetch_related("subscription").afirst()
|
||||
if user:
|
||||
|
|
|
@ -17,6 +17,7 @@ from torch import Tensor
|
|||
|
||||
from khoj.database.models import (
|
||||
ChatModelOptions,
|
||||
ClientApplication,
|
||||
Conversation,
|
||||
Entry,
|
||||
GithubConfig,
|
||||
|
@ -40,7 +41,7 @@ from khoj.search_filter.file_filter import FileFilter
|
|||
from khoj.search_filter.word_filter import WordFilter
|
||||
from khoj.utils import state
|
||||
from khoj.utils.config import GPT4AllProcessorModel
|
||||
from khoj.utils.helpers import generate_random_name
|
||||
from khoj.utils.helpers import generate_random_name, is_none_or_empty
|
||||
|
||||
|
||||
class SubscriptionState(Enum):
|
||||
|
@ -85,6 +86,28 @@ async def get_or_create_user(token: dict) -> KhojUser:
|
|||
return user
|
||||
|
||||
|
||||
async def aget_or_create_user_by_phone_number(phone_number: str) -> KhojUser:
|
||||
if is_none_or_empty(phone_number):
|
||||
return None
|
||||
user = await aget_user_by_phone_number(phone_number)
|
||||
if not user:
|
||||
user = await acreate_user_by_phone_number(phone_number)
|
||||
return user
|
||||
|
||||
|
||||
async def acreate_user_by_phone_number(phone_number: str) -> KhojUser:
|
||||
if is_none_or_empty(phone_number):
|
||||
return None
|
||||
user, _ = await KhojUser.objects.filter(phone_number=phone_number).aupdate_or_create(
|
||||
defaults={"username": phone_number, "phone_number": phone_number}
|
||||
)
|
||||
await user.asave()
|
||||
|
||||
await Subscription.objects.acreate(user=user, type="trial")
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def get_or_create_user_by_email(email: str) -> KhojUser:
|
||||
user, _ = await KhojUser.objects.filter(email=email).aupdate_or_create(defaults={"username": email, "email": email})
|
||||
await user.asave()
|
||||
|
@ -187,6 +210,12 @@ async def get_user_by_token(token: dict) -> KhojUser:
|
|||
return google_user.user
|
||||
|
||||
|
||||
async def aget_user_by_phone_number(phone_number: str) -> KhojUser:
|
||||
if is_none_or_empty(phone_number):
|
||||
return None
|
||||
return await KhojUser.objects.filter(phone_number=phone_number).prefetch_related("subscription").afirst()
|
||||
|
||||
|
||||
async def retrieve_user(session_id: str) -> KhojUser:
|
||||
session = SessionStore(session_key=session_id)
|
||||
if not await sync_to_async(session.exists)(session_key=session_id):
|
||||
|
@ -270,6 +299,12 @@ async def aset_user_search_model(user: KhojUser, search_model_config_id: int):
|
|||
return new_config
|
||||
|
||||
|
||||
class ClientApplicationAdapters:
|
||||
@staticmethod
|
||||
async def aget_client_application_by_id(client_id: str, client_secret: str):
|
||||
return await ClientApplication.objects.filter(client_id=client_id, client_secret=client_secret).afirst()
|
||||
|
||||
|
||||
class ConversationAdapters:
|
||||
@staticmethod
|
||||
def get_conversation_by_user(user: KhojUser):
|
||||
|
@ -279,11 +314,11 @@ class ConversationAdapters:
|
|||
return Conversation.objects.create(user=user)
|
||||
|
||||
@staticmethod
|
||||
async def aget_conversation_by_user(user: KhojUser):
|
||||
conversation = Conversation.objects.filter(user=user)
|
||||
async def aget_conversation_by_user(user: KhojUser, client_application: ClientApplication = None):
|
||||
conversation = Conversation.objects.filter(user=user, client=client_application)
|
||||
if await conversation.aexists():
|
||||
return await conversation.afirst()
|
||||
return await Conversation.objects.acreate(user=user)
|
||||
return await Conversation.objects.acreate(user=user, client=client_application)
|
||||
|
||||
@staticmethod
|
||||
async def adelete_conversation_by_user(user: KhojUser):
|
||||
|
|
|
@ -7,6 +7,7 @@ from django.http import HttpResponse
|
|||
|
||||
from khoj.database.models import (
|
||||
ChatModelOptions,
|
||||
ClientApplication,
|
||||
Conversation,
|
||||
KhojUser,
|
||||
OfflineChatProcessorConversationConfig,
|
||||
|
@ -19,10 +20,24 @@ from khoj.database.models import (
|
|||
UserSearchModelConfig,
|
||||
)
|
||||
|
||||
# Register your models here.
|
||||
|
||||
class KhojUserAdmin(UserAdmin):
|
||||
list_display = (
|
||||
"id",
|
||||
"email",
|
||||
"username",
|
||||
"is_active",
|
||||
"is_staff",
|
||||
"is_superuser",
|
||||
"phone_number",
|
||||
)
|
||||
search_fields = ("email", "username", "phone_number")
|
||||
filter_horizontal = ("groups", "user_permissions")
|
||||
|
||||
fieldsets = (("Personal info", {"fields": ("phone_number",)}),) + UserAdmin.fieldsets
|
||||
|
||||
|
||||
admin.site.register(KhojUser, UserAdmin)
|
||||
admin.site.register(KhojUser, KhojUserAdmin)
|
||||
|
||||
admin.site.register(ChatModelOptions)
|
||||
admin.site.register(SpeechToTextModelOptions)
|
||||
|
@ -33,6 +48,7 @@ admin.site.register(Subscription)
|
|||
admin.site.register(ReflectiveQuestion)
|
||||
admin.site.register(UserSearchModelConfig)
|
||||
admin.site.register(TextToImageModelConfig)
|
||||
admin.site.register(ClientApplication)
|
||||
|
||||
|
||||
@admin.register(Conversation)
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
# Generated by Django 4.2.7 on 2024-01-04 12:22
|
||||
|
||||
import django.db.models.deletion
|
||||
import phonenumber_field.modelfields
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0024_alter_entry_embeddings"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name="ClientApplication",
|
||||
fields=[
|
||||
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||
("updated_at", models.DateTimeField(auto_now=True)),
|
||||
("name", models.CharField(max_length=200)),
|
||||
("client_id", models.CharField(max_length=200)),
|
||||
("client_secret", models.CharField(max_length=200)),
|
||||
],
|
||||
options={
|
||||
"abstract": False,
|
||||
},
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="khojuser",
|
||||
name="phone_number",
|
||||
field=phonenumber_field.modelfields.PhoneNumberField(
|
||||
blank=True, default=None, max_length=128, null=True, region=None
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="conversation",
|
||||
name="client",
|
||||
field=models.ForeignKey(
|
||||
blank=True,
|
||||
default=None,
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
to="database.clientapplication",
|
||||
),
|
||||
),
|
||||
]
|
13
src/khoj/database/migrations/0027_merge_20240118_1324.py
Normal file
13
src/khoj/database/migrations/0027_merge_20240118_1324.py
Normal file
|
@ -0,0 +1,13 @@
|
|||
# Generated by Django 4.2.7 on 2024-01-18 13:24
|
||||
from typing import List
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0025_clientapplication_khojuser_phone_number_and_more"),
|
||||
("database", "0026_searchmodelconfig_cross_encoder_inference_endpoint_and_more"),
|
||||
]
|
||||
|
||||
operations: List[str] = []
|
|
@ -3,6 +3,7 @@ import uuid
|
|||
from django.contrib.auth.models import AbstractUser
|
||||
from django.db import models
|
||||
from pgvector.django import VectorField
|
||||
from phonenumber_field.modelfields import PhoneNumberField
|
||||
|
||||
|
||||
class BaseModel(models.Model):
|
||||
|
@ -13,8 +14,18 @@ class BaseModel(models.Model):
|
|||
abstract = True
|
||||
|
||||
|
||||
class ClientApplication(BaseModel):
|
||||
name = models.CharField(max_length=200)
|
||||
client_id = models.CharField(max_length=200)
|
||||
client_secret = models.CharField(max_length=200)
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
class KhojUser(AbstractUser):
|
||||
uuid = models.UUIDField(models.UUIDField(default=uuid.uuid4, editable=False))
|
||||
phone_number = PhoneNumberField(null=True, default=None, blank=True)
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
if not self.uuid:
|
||||
|
@ -165,6 +176,7 @@ class UserSearchModelConfig(BaseModel):
|
|||
class Conversation(BaseModel):
|
||||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||
conversation_log = models.JSONField(default=dict)
|
||||
client = models.ForeignKey(ClientApplication, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
||||
|
||||
|
||||
class ReflectiveQuestion(BaseModel):
|
||||
|
|
|
@ -359,8 +359,8 @@ async def chat(
|
|||
n: Optional[int] = 5,
|
||||
d: Optional[float] = 0.18,
|
||||
stream: Optional[bool] = False,
|
||||
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=60, window=60)),
|
||||
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)),
|
||||
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60)),
|
||||
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24)),
|
||||
) -> Response:
|
||||
user: KhojUser = request.user.object
|
||||
q = unquote(q)
|
||||
|
@ -372,7 +372,7 @@ async def chat(
|
|||
|
||||
q = q.replace(f"/{conversation_command.value}", "").strip()
|
||||
|
||||
meta_log = (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log
|
||||
meta_log = (await ConversationAdapters.aget_conversation_by_user(user, request.user.client_app)).conversation_log
|
||||
|
||||
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
|
||||
request, common, meta_log, q, (n or 5), (d or math.inf), conversation_command
|
||||
|
@ -392,7 +392,11 @@ async def chat(
|
|||
|
||||
elif conversation_command == ConversationCommand.Notes and not await EntryAdapters.auser_has_entries(user):
|
||||
no_entries_found_format = no_entries_found.format()
|
||||
return StreamingResponse(iter([no_entries_found_format]), media_type="text/event-stream", status_code=200)
|
||||
if stream:
|
||||
return StreamingResponse(iter([no_entries_found_format]), media_type="text/event-stream", status_code=200)
|
||||
else:
|
||||
response_obj = {"response": no_entries_found_format}
|
||||
return Response(content=json.dumps(response_obj), media_type="text/plain", status_code=200)
|
||||
|
||||
elif conversation_command == ConversationCommand.Online:
|
||||
try:
|
||||
|
|
|
@ -13,7 +13,12 @@ from fastapi import Depends, Header, HTTPException, Request, UploadFile
|
|||
from starlette.authentication import has_required_scope
|
||||
|
||||
from khoj.database.adapters import ConversationAdapters, EntryAdapters
|
||||
from khoj.database.models import KhojUser, Subscription, TextToImageModelConfig
|
||||
from khoj.database.models import (
|
||||
ClientApplication,
|
||||
KhojUser,
|
||||
Subscription,
|
||||
TextToImageModelConfig,
|
||||
)
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.offline.chat_model import (
|
||||
converse_offline,
|
||||
|
@ -74,6 +79,7 @@ def update_telemetry_state(
|
|||
metadata: Optional[dict] = None,
|
||||
):
|
||||
user: KhojUser = request.user.object if request.user.is_authenticated else None
|
||||
client_app: ClientApplication = request.user.client_app if request.user.is_authenticated else None
|
||||
subscription: Subscription = user.subscription if user and hasattr(user, "subscription") else None
|
||||
user_state = {
|
||||
"client_host": request.client.host if request.client else None,
|
||||
|
@ -83,6 +89,7 @@ def update_telemetry_state(
|
|||
"server_id": str(user.uuid) if user else None,
|
||||
"subscription_type": subscription.type if subscription else None,
|
||||
"is_recurring": subscription.is_recurring if subscription else None,
|
||||
"client_id": str(client_app.name) if client_app else None,
|
||||
}
|
||||
|
||||
if metadata:
|
||||
|
@ -113,10 +120,6 @@ def get_conversation_command(query: str, any_references: bool = False) -> Conver
|
|||
return ConversationCommand.Default
|
||||
|
||||
|
||||
async def construct_conversation_logs(user: KhojUser):
|
||||
return (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log
|
||||
|
||||
|
||||
async def agenerate_chat_response(*args):
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(executor, generate_chat_response, *args)
|
||||
|
|
|
@ -5,6 +5,7 @@ from datetime import datetime, timezone
|
|||
import stripe
|
||||
from asgiref.sync import sync_to_async
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi.responses import Response
|
||||
from starlette.authentication import requires
|
||||
|
||||
from khoj.database import adapters
|
||||
|
|
Loading…
Add table
Reference in a new issue