diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 00ad75f1..f00fe179 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -45,6 +45,7 @@ from khoj.database.models import ( Subscription, TextToImageModelConfig, UserConversationConfig, + UserPaintModelConfig, UserRequests, UserSearchModelConfig, UserVoiceModelConfig, @@ -891,6 +892,32 @@ class ConversationAdapters: async def aget_text_to_image_model_config(): return await TextToImageModelConfig.objects.filter().afirst() + @staticmethod + def get_text_to_image_model_options(): + return TextToImageModelConfig.objects.all() + + @staticmethod + def get_user_paint_model_config(user: KhojUser): + config = UserPaintModelConfig.objects.filter(user=user).first() + if not config: + return None + return config.setting + + @staticmethod + async def aget_user_paint_model(user: KhojUser): + config = await UserPaintModelConfig.objects.filter(user=user).prefetch_related("setting").afirst() + if not config: + return None + return config.setting + + @staticmethod + async def aset_user_paint_model(user: KhojUser, text_to_image_model_config_id: int): + config = await TextToImageModelConfig.objects.filter(id=text_to_image_model_config_id).afirst() + if not config: + return None + new_config, _ = await UserPaintModelConfig.objects.aupdate_or_create(user=user, defaults={"setting": config}) + return new_config + class FileObjectAdapters: @staticmethod diff --git a/src/khoj/database/admin.py b/src/khoj/database/admin.py index 3bc0f76d..95e3508c 100644 --- a/src/khoj/database/admin.py +++ b/src/khoj/database/admin.py @@ -96,7 +96,6 @@ admin.site.register(SpeechToTextModelOptions) admin.site.register(SearchModelConfig) admin.site.register(ReflectiveQuestion) admin.site.register(UserSearchModelConfig) -admin.site.register(TextToImageModelConfig) admin.site.register(ClientApplication) admin.site.register(GithubConfig) admin.site.register(NotionConfig) @@ -153,6 +152,16 @@ class ChatModelOptionsAdmin(admin.ModelAdmin): search_fields = ("id", "chat_model", "model_type") +@admin.register(TextToImageModelConfig) +class TextToImageModelOptionsAdmin(admin.ModelAdmin): + list_display = ( + "id", + "model_name", + "model_type", + ) + search_fields = ("id", "model_name", "model_type") + + @admin.register(OpenAIProcessorConversationConfig) class OpenAIProcessorConversationConfigAdmin(admin.ModelAdmin): list_display = ( diff --git a/src/khoj/database/migrations/0048_texttoimagemodelconfig_api_key_and_more.py b/src/khoj/database/migrations/0048_texttoimagemodelconfig_api_key_and_more.py index 069b16ca..d5d3e972 100644 --- a/src/khoj/database/migrations/0048_texttoimagemodelconfig_api_key_and_more.py +++ b/src/khoj/database/migrations/0048_texttoimagemodelconfig_api_key_and_more.py @@ -1,5 +1,7 @@ -# Generated by Django 4.2.11 on 2024-06-20 19:02 +# Generated by Django 4.2.11 on 2024-06-20 19:48 +import django.db.models.deletion +from django.conf import settings from django.db import migrations, models @@ -21,4 +23,25 @@ class Migration(migrations.Migration): choices=[("openai", "Openai"), ("stability-ai", "Stabilityai")], default="openai", max_length=200 ), ), + migrations.CreateModel( + name="UserPaintModelConfig", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ( + "setting", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, to="database.texttoimagemodelconfig" + ), + ), + ( + "user", + models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL), + ), + ], + options={ + "abstract": False, + }, + ), ] diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 89f9802b..6da7c928 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -265,6 +265,11 @@ class UserSearchModelConfig(BaseModel): setting = models.ForeignKey(SearchModelConfig, on_delete=models.CASCADE) +class UserPaintModelConfig(BaseModel): + user = models.OneToOneField(KhojUser, on_delete=models.CASCADE) + setting = models.ForeignKey(TextToImageModelConfig, on_delete=models.CASCADE) + + class Conversation(BaseModel): user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) conversation_log = models.JSONField(default=dict) diff --git a/src/khoj/interface/web/base_config.html b/src/khoj/interface/web/base_config.html index 7747229e..d4712206 100644 --- a/src/khoj/interface/web/base_config.html +++ b/src/khoj/interface/web/base_config.html @@ -332,6 +332,7 @@ margin: 20px; } + select#paint-models, select#search-models, select#voice-models, select#chat-models { diff --git a/src/khoj/interface/web/config.html b/src/khoj/interface/web/config.html index 2c3c98db..88725c64 100644 --- a/src/khoj/interface/web/config.html +++ b/src/khoj/interface/web/config.html @@ -192,11 +192,37 @@
{% if (not billing_enabled) or (subscription_state != 'unsubscribed' and subscription_state != 'expired') %} - {% else %} - + {% endif %} +
+ +
+
+ Chat +

+ Paint +

+
+
+ +
+
+ {% if (not billing_enabled) or (subscription_state != 'unsubscribed' and subscription_state != 'expired') %} + + {% else %} + {% endif %} @@ -425,7 +451,7 @@ function updateChatModel() { const chatModel = document.getElementById("chat-models").value; - const saveModelButton = document.getElementById("save-model"); + const saveModelButton = document.getElementById("save-chat-model"); saveModelButton.disabled = true; saveModelButton.innerHTML = "Saving..."; @@ -491,6 +517,38 @@ }) }; + function updatePaintModel() { + const paintModel = document.getElementById("paint-models").value; + const saveModelButton = document.getElementById("save-paint-model"); + saveModelButton.disabled = true; + saveModelButton.innerHTML = "Saving..."; + + fetch('/api/config/data/paint/model?id=' + paintModel, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + } + }) + .then(response => response.json()) + .then(data => { + if (data.status == "ok") { + saveModelButton.innerHTML = "Save"; + saveModelButton.disabled = false; + + let notificationBanner = document.getElementById("notification-banner"); + notificationBanner.innerHTML = "Paint model has been updated!"; + notificationBanner.style.display = "block"; + setTimeout(function() { + notificationBanner.style.display = "none"; + }, 5000); + + } else { + saveModelButton.innerHTML = "Error"; + saveModelButton.disabled = false; + } + }) + }; + function clearContentType(content_source) { fetch('/api/config/data/content-source/' + content_source, { method: 'DELETE', diff --git a/src/khoj/routers/api_config.py b/src/khoj/routers/api_config.py index 68757de6..f25ecacd 100644 --- a/src/khoj/routers/api_config.py +++ b/src/khoj/routers/api_config.py @@ -315,6 +315,35 @@ async def update_search_model( return {"status": "ok"} +@api_config.post("/data/paint/model", status_code=200) +@requires(["authenticated"]) +async def update_paint_model( + request: Request, + id: str, + client: Optional[str] = None, +): + user = request.user.object + subscribed = has_required_scope(request, ["premium"]) + + if not subscribed: + raise HTTPException(status_code=403, detail="User is not subscribed to premium") + + new_config = await ConversationAdapters.aset_user_paint_model(user, int(id)) + + update_telemetry_state( + request=request, + telemetry_type="api", + api="set_paint_model", + client=client, + metadata={"paint_model": new_config.setting.model_name}, + ) + + if new_config is None: + return {"status": "error", "message": "Model not found"} + + return {"status": "ok"} + + @api_config.get("/index/size", response_model=Dict[str, int]) @requires(["authenticated"]) async def get_indexed_data_size(request: Request, common: CommonQueryParams): diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index e7392738..976caef1 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -762,7 +762,7 @@ async def text_to_image( image_url = None intent_type = ImageIntentType.TEXT_TO_IMAGE_V3 - text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config() + text_to_image_config = await ConversationAdapters.aget_user_paint_model(user) if not text_to_image_config: # If the user has not configured a text to image model, return an unsupported on server error status_code = 501 diff --git a/src/khoj/routers/web_client.py b/src/khoj/routers/web_client.py index 1dc6a2f4..87075b49 100644 --- a/src/khoj/routers/web_client.py +++ b/src/khoj/routers/web_client.py @@ -262,6 +262,12 @@ def config_page(request: Request): current_search_model_option = adapters.get_user_search_model_or_default(user) + selected_paint_model_config = ConversationAdapters.get_user_paint_model_config(user) + paint_model_options = ConversationAdapters.get_text_to_image_model_options().all() + all_paint_model_options = list() + for paint_model in paint_model_options: + all_paint_model_options.append({"model_name": paint_model.model_name, "id": paint_model.id}) + notion_oauth_url = get_notion_auth_url(user) eleven_labs_enabled = is_eleven_labs_enabled() @@ -284,10 +290,12 @@ def config_page(request: Request): "anonymous_mode": state.anonymous_mode, "username": user.username, "given_name": given_name, - "conversation_options": all_conversation_options, "search_model_options": all_search_model_options, "selected_search_model_config": current_search_model_option.id, + "conversation_options": all_conversation_options, "selected_conversation_config": selected_conversation_config.id if selected_conversation_config else None, + "paint_model_options": all_paint_model_options, + "selected_paint_model_config": selected_paint_model_config.id if selected_paint_model_config else None, "user_photo": user_picture, "billing_enabled": state.billing_enabled, "subscription_state": user_subscription_state,