diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py
index 00ad75f1..f00fe179 100644
--- a/src/khoj/database/adapters/__init__.py
+++ b/src/khoj/database/adapters/__init__.py
@@ -45,6 +45,7 @@ from khoj.database.models import (
Subscription,
TextToImageModelConfig,
UserConversationConfig,
+ UserPaintModelConfig,
UserRequests,
UserSearchModelConfig,
UserVoiceModelConfig,
@@ -891,6 +892,32 @@ class ConversationAdapters:
async def aget_text_to_image_model_config():
return await TextToImageModelConfig.objects.filter().afirst()
+ @staticmethod
+ def get_text_to_image_model_options():
+ return TextToImageModelConfig.objects.all()
+
+ @staticmethod
+ def get_user_paint_model_config(user: KhojUser):
+ config = UserPaintModelConfig.objects.filter(user=user).first()
+ if not config:
+ return None
+ return config.setting
+
+ @staticmethod
+ async def aget_user_paint_model(user: KhojUser):
+ config = await UserPaintModelConfig.objects.filter(user=user).prefetch_related("setting").afirst()
+ if not config:
+ return None
+ return config.setting
+
+ @staticmethod
+ async def aset_user_paint_model(user: KhojUser, text_to_image_model_config_id: int):
+ config = await TextToImageModelConfig.objects.filter(id=text_to_image_model_config_id).afirst()
+ if not config:
+ return None
+ new_config, _ = await UserPaintModelConfig.objects.aupdate_or_create(user=user, defaults={"setting": config})
+ return new_config
+
class FileObjectAdapters:
@staticmethod
diff --git a/src/khoj/database/admin.py b/src/khoj/database/admin.py
index 3bc0f76d..95e3508c 100644
--- a/src/khoj/database/admin.py
+++ b/src/khoj/database/admin.py
@@ -96,7 +96,6 @@ admin.site.register(SpeechToTextModelOptions)
admin.site.register(SearchModelConfig)
admin.site.register(ReflectiveQuestion)
admin.site.register(UserSearchModelConfig)
-admin.site.register(TextToImageModelConfig)
admin.site.register(ClientApplication)
admin.site.register(GithubConfig)
admin.site.register(NotionConfig)
@@ -153,6 +152,16 @@ class ChatModelOptionsAdmin(admin.ModelAdmin):
search_fields = ("id", "chat_model", "model_type")
+@admin.register(TextToImageModelConfig)
+class TextToImageModelOptionsAdmin(admin.ModelAdmin):
+ list_display = (
+ "id",
+ "model_name",
+ "model_type",
+ )
+ search_fields = ("id", "model_name", "model_type")
+
+
@admin.register(OpenAIProcessorConversationConfig)
class OpenAIProcessorConversationConfigAdmin(admin.ModelAdmin):
list_display = (
diff --git a/src/khoj/database/migrations/0048_texttoimagemodelconfig_api_key_and_more.py b/src/khoj/database/migrations/0048_texttoimagemodelconfig_api_key_and_more.py
index 069b16ca..d5d3e972 100644
--- a/src/khoj/database/migrations/0048_texttoimagemodelconfig_api_key_and_more.py
+++ b/src/khoj/database/migrations/0048_texttoimagemodelconfig_api_key_and_more.py
@@ -1,5 +1,7 @@
-# Generated by Django 4.2.11 on 2024-06-20 19:02
+# Generated by Django 4.2.11 on 2024-06-20 19:48
+import django.db.models.deletion
+from django.conf import settings
from django.db import migrations, models
@@ -21,4 +23,25 @@ class Migration(migrations.Migration):
choices=[("openai", "Openai"), ("stability-ai", "Stabilityai")], default="openai", max_length=200
),
),
+ migrations.CreateModel(
+ name="UserPaintModelConfig",
+ fields=[
+ ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
+ ("created_at", models.DateTimeField(auto_now_add=True)),
+ ("updated_at", models.DateTimeField(auto_now=True)),
+ (
+ "setting",
+ models.ForeignKey(
+ on_delete=django.db.models.deletion.CASCADE, to="database.texttoimagemodelconfig"
+ ),
+ ),
+ (
+ "user",
+ models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL),
+ ),
+ ],
+ options={
+ "abstract": False,
+ },
+ ),
]
diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py
index 89f9802b..6da7c928 100644
--- a/src/khoj/database/models/__init__.py
+++ b/src/khoj/database/models/__init__.py
@@ -265,6 +265,11 @@ class UserSearchModelConfig(BaseModel):
setting = models.ForeignKey(SearchModelConfig, on_delete=models.CASCADE)
+class UserPaintModelConfig(BaseModel):
+ user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
+ setting = models.ForeignKey(TextToImageModelConfig, on_delete=models.CASCADE)
+
+
class Conversation(BaseModel):
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
conversation_log = models.JSONField(default=dict)
diff --git a/src/khoj/interface/web/base_config.html b/src/khoj/interface/web/base_config.html
index 7747229e..d4712206 100644
--- a/src/khoj/interface/web/base_config.html
+++ b/src/khoj/interface/web/base_config.html
@@ -332,6 +332,7 @@
margin: 20px;
}
+ select#paint-models,
select#search-models,
select#voice-models,
select#chat-models {
diff --git a/src/khoj/interface/web/config.html b/src/khoj/interface/web/config.html
index 2c3c98db..88725c64 100644
--- a/src/khoj/interface/web/config.html
+++ b/src/khoj/interface/web/config.html
@@ -192,11 +192,37 @@
{% if (not billing_enabled) or (subscription_state != 'unsubscribed' and subscription_state != 'expired') %}
-
+
+
+
+
![Chat](/static/assets/icons/chat.svg)
+
+ Paint
+
+
+
+
+
+
+ {% if (not billing_enabled) or (subscription_state != 'unsubscribed' and subscription_state != 'expired') %}
+
+ Save
+
+ {% else %}
+
Subscribe to use different models
{% endif %}
@@ -425,7 +451,7 @@
function updateChatModel() {
const chatModel = document.getElementById("chat-models").value;
- const saveModelButton = document.getElementById("save-model");
+ const saveModelButton = document.getElementById("save-chat-model");
saveModelButton.disabled = true;
saveModelButton.innerHTML = "Saving...";
@@ -491,6 +517,38 @@
})
};
+ function updatePaintModel() {
+ const paintModel = document.getElementById("paint-models").value;
+ const saveModelButton = document.getElementById("save-paint-model");
+ saveModelButton.disabled = true;
+ saveModelButton.innerHTML = "Saving...";
+
+ fetch('/api/config/data/paint/model?id=' + paintModel, {
+ method: 'POST',
+ headers: {
+ 'Content-Type': 'application/json',
+ }
+ })
+ .then(response => response.json())
+ .then(data => {
+ if (data.status == "ok") {
+ saveModelButton.innerHTML = "Save";
+ saveModelButton.disabled = false;
+
+ let notificationBanner = document.getElementById("notification-banner");
+ notificationBanner.innerHTML = "Paint model has been updated!";
+ notificationBanner.style.display = "block";
+ setTimeout(function() {
+ notificationBanner.style.display = "none";
+ }, 5000);
+
+ } else {
+ saveModelButton.innerHTML = "Error";
+ saveModelButton.disabled = false;
+ }
+ })
+ };
+
function clearContentType(content_source) {
fetch('/api/config/data/content-source/' + content_source, {
method: 'DELETE',
diff --git a/src/khoj/routers/api_config.py b/src/khoj/routers/api_config.py
index 68757de6..f25ecacd 100644
--- a/src/khoj/routers/api_config.py
+++ b/src/khoj/routers/api_config.py
@@ -315,6 +315,35 @@ async def update_search_model(
return {"status": "ok"}
+@api_config.post("/data/paint/model", status_code=200)
+@requires(["authenticated"])
+async def update_paint_model(
+ request: Request,
+ id: str,
+ client: Optional[str] = None,
+):
+ user = request.user.object
+ subscribed = has_required_scope(request, ["premium"])
+
+ if not subscribed:
+ raise HTTPException(status_code=403, detail="User is not subscribed to premium")
+
+ new_config = await ConversationAdapters.aset_user_paint_model(user, int(id))
+
+ update_telemetry_state(
+ request=request,
+ telemetry_type="api",
+ api="set_paint_model",
+ client=client,
+ metadata={"paint_model": new_config.setting.model_name},
+ )
+
+ if new_config is None:
+ return {"status": "error", "message": "Model not found"}
+
+ return {"status": "ok"}
+
+
@api_config.get("/index/size", response_model=Dict[str, int])
@requires(["authenticated"])
async def get_indexed_data_size(request: Request, common: CommonQueryParams):
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index e7392738..976caef1 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -762,7 +762,7 @@ async def text_to_image(
image_url = None
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
- text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config()
+ text_to_image_config = await ConversationAdapters.aget_user_paint_model(user)
if not text_to_image_config:
# If the user has not configured a text to image model, return an unsupported on server error
status_code = 501
diff --git a/src/khoj/routers/web_client.py b/src/khoj/routers/web_client.py
index 1dc6a2f4..87075b49 100644
--- a/src/khoj/routers/web_client.py
+++ b/src/khoj/routers/web_client.py
@@ -262,6 +262,12 @@ def config_page(request: Request):
current_search_model_option = adapters.get_user_search_model_or_default(user)
+ selected_paint_model_config = ConversationAdapters.get_user_paint_model_config(user)
+ paint_model_options = ConversationAdapters.get_text_to_image_model_options().all()
+ all_paint_model_options = list()
+ for paint_model in paint_model_options:
+ all_paint_model_options.append({"model_name": paint_model.model_name, "id": paint_model.id})
+
notion_oauth_url = get_notion_auth_url(user)
eleven_labs_enabled = is_eleven_labs_enabled()
@@ -284,10 +290,12 @@ def config_page(request: Request):
"anonymous_mode": state.anonymous_mode,
"username": user.username,
"given_name": given_name,
- "conversation_options": all_conversation_options,
"search_model_options": all_search_model_options,
"selected_search_model_config": current_search_model_option.id,
+ "conversation_options": all_conversation_options,
"selected_conversation_config": selected_conversation_config.id if selected_conversation_config else None,
+ "paint_model_options": all_paint_model_options,
+ "selected_paint_model_config": selected_paint_model_config.id if selected_paint_model_config else None,
"user_photo": user_picture,
"billing_enabled": state.billing_enabled,
"subscription_state": user_subscription_state,