Allow user to set paint model to use from web client config page

This commit is contained in:
Debanjum Singh Solanky 2024-06-21 01:16:17 +05:30
parent eb09aba747
commit 2c4bf91a61
9 changed files with 167 additions and 7 deletions

View file

@ -45,6 +45,7 @@ from khoj.database.models import (
Subscription, Subscription,
TextToImageModelConfig, TextToImageModelConfig,
UserConversationConfig, UserConversationConfig,
UserPaintModelConfig,
UserRequests, UserRequests,
UserSearchModelConfig, UserSearchModelConfig,
UserVoiceModelConfig, UserVoiceModelConfig,
@ -891,6 +892,32 @@ class ConversationAdapters:
async def aget_text_to_image_model_config(): async def aget_text_to_image_model_config():
return await TextToImageModelConfig.objects.filter().afirst() return await TextToImageModelConfig.objects.filter().afirst()
@staticmethod
def get_text_to_image_model_options():
return TextToImageModelConfig.objects.all()
@staticmethod
def get_user_paint_model_config(user: KhojUser):
config = UserPaintModelConfig.objects.filter(user=user).first()
if not config:
return None
return config.setting
@staticmethod
async def aget_user_paint_model(user: KhojUser):
config = await UserPaintModelConfig.objects.filter(user=user).prefetch_related("setting").afirst()
if not config:
return None
return config.setting
@staticmethod
async def aset_user_paint_model(user: KhojUser, text_to_image_model_config_id: int):
config = await TextToImageModelConfig.objects.filter(id=text_to_image_model_config_id).afirst()
if not config:
return None
new_config, _ = await UserPaintModelConfig.objects.aupdate_or_create(user=user, defaults={"setting": config})
return new_config
class FileObjectAdapters: class FileObjectAdapters:
@staticmethod @staticmethod

View file

@ -96,7 +96,6 @@ admin.site.register(SpeechToTextModelOptions)
admin.site.register(SearchModelConfig) admin.site.register(SearchModelConfig)
admin.site.register(ReflectiveQuestion) admin.site.register(ReflectiveQuestion)
admin.site.register(UserSearchModelConfig) admin.site.register(UserSearchModelConfig)
admin.site.register(TextToImageModelConfig)
admin.site.register(ClientApplication) admin.site.register(ClientApplication)
admin.site.register(GithubConfig) admin.site.register(GithubConfig)
admin.site.register(NotionConfig) admin.site.register(NotionConfig)
@ -153,6 +152,16 @@ class ChatModelOptionsAdmin(admin.ModelAdmin):
search_fields = ("id", "chat_model", "model_type") search_fields = ("id", "chat_model", "model_type")
@admin.register(TextToImageModelConfig)
class TextToImageModelOptionsAdmin(admin.ModelAdmin):
list_display = (
"id",
"model_name",
"model_type",
)
search_fields = ("id", "model_name", "model_type")
@admin.register(OpenAIProcessorConversationConfig) @admin.register(OpenAIProcessorConversationConfig)
class OpenAIProcessorConversationConfigAdmin(admin.ModelAdmin): class OpenAIProcessorConversationConfigAdmin(admin.ModelAdmin):
list_display = ( list_display = (

View file

@ -1,5 +1,7 @@
# Generated by Django 4.2.11 on 2024-06-20 19:02 # Generated by Django 4.2.11 on 2024-06-20 19:48
import django.db.models.deletion
from django.conf import settings
from django.db import migrations, models from django.db import migrations, models
@ -21,4 +23,25 @@ class Migration(migrations.Migration):
choices=[("openai", "Openai"), ("stability-ai", "Stabilityai")], default="openai", max_length=200 choices=[("openai", "Openai"), ("stability-ai", "Stabilityai")], default="openai", max_length=200
), ),
), ),
migrations.CreateModel(
name="UserPaintModelConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
(
"setting",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to="database.texttoimagemodelconfig"
),
),
(
"user",
models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL),
),
],
options={
"abstract": False,
},
),
] ]

View file

@ -265,6 +265,11 @@ class UserSearchModelConfig(BaseModel):
setting = models.ForeignKey(SearchModelConfig, on_delete=models.CASCADE) setting = models.ForeignKey(SearchModelConfig, on_delete=models.CASCADE)
class UserPaintModelConfig(BaseModel):
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
setting = models.ForeignKey(TextToImageModelConfig, on_delete=models.CASCADE)
class Conversation(BaseModel): class Conversation(BaseModel):
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
conversation_log = models.JSONField(default=dict) conversation_log = models.JSONField(default=dict)

View file

@ -332,6 +332,7 @@
margin: 20px; margin: 20px;
} }
select#paint-models,
select#search-models, select#search-models,
select#voice-models, select#voice-models,
select#chat-models { select#chat-models {

View file

@ -192,11 +192,37 @@
</div> </div>
<div class="card-action-row"> <div class="card-action-row">
{% if (not billing_enabled) or (subscription_state != 'unsubscribed' and subscription_state != 'expired') %} {% if (not billing_enabled) or (subscription_state != 'unsubscribed' and subscription_state != 'expired') %}
<button id="save-model" class="card-button happy" onclick="updateChatModel()"> <button id="save-chat-model" class="card-button happy" onclick="updateChatModel()">
Save Save
</button> </button>
{% else %} {% else %}
<button id="save-model" class="card-button" disabled> <button id="save-chat-model" class="card-button" disabled>
Subscribe to use different models
</button>
{% endif %}
</div>
</div>
<div class="card">
<div class="card-title-row">
<img class="card-icon" src="/static/assets/icons/chat.svg" alt="Chat">
<h3 class="card-title">
<span>Paint</span>
</h3>
</div>
<div class="card-description-row">
<select id="paint-models">
{% for option in paint_model_options %}
<option value="{{ option.id }}" {% if option.id == selected_paint_model_config %}selected{% endif %}>{{ option.model_name }}</option>
{% endfor %}
</select>
</div>
<div class="card-action-row">
{% if (not billing_enabled) or (subscription_state != 'unsubscribed' and subscription_state != 'expired') %}
<button id="save-paint-model" class="card-button happy" onclick="updatePaintModel()">
Save
</button>
{% else %}
<button id="save-paint-model" class="card-button" disabled>
Subscribe to use different models Subscribe to use different models
</button> </button>
{% endif %} {% endif %}
@ -425,7 +451,7 @@
function updateChatModel() { function updateChatModel() {
const chatModel = document.getElementById("chat-models").value; const chatModel = document.getElementById("chat-models").value;
const saveModelButton = document.getElementById("save-model"); const saveModelButton = document.getElementById("save-chat-model");
saveModelButton.disabled = true; saveModelButton.disabled = true;
saveModelButton.innerHTML = "Saving..."; saveModelButton.innerHTML = "Saving...";
@ -491,6 +517,38 @@
}) })
}; };
function updatePaintModel() {
const paintModel = document.getElementById("paint-models").value;
const saveModelButton = document.getElementById("save-paint-model");
saveModelButton.disabled = true;
saveModelButton.innerHTML = "Saving...";
fetch('/api/config/data/paint/model?id=' + paintModel, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
}
})
.then(response => response.json())
.then(data => {
if (data.status == "ok") {
saveModelButton.innerHTML = "Save";
saveModelButton.disabled = false;
let notificationBanner = document.getElementById("notification-banner");
notificationBanner.innerHTML = "Paint model has been updated!";
notificationBanner.style.display = "block";
setTimeout(function() {
notificationBanner.style.display = "none";
}, 5000);
} else {
saveModelButton.innerHTML = "Error";
saveModelButton.disabled = false;
}
})
};
function clearContentType(content_source) { function clearContentType(content_source) {
fetch('/api/config/data/content-source/' + content_source, { fetch('/api/config/data/content-source/' + content_source, {
method: 'DELETE', method: 'DELETE',

View file

@ -315,6 +315,35 @@ async def update_search_model(
return {"status": "ok"} return {"status": "ok"}
@api_config.post("/data/paint/model", status_code=200)
@requires(["authenticated"])
async def update_paint_model(
request: Request,
id: str,
client: Optional[str] = None,
):
user = request.user.object
subscribed = has_required_scope(request, ["premium"])
if not subscribed:
raise HTTPException(status_code=403, detail="User is not subscribed to premium")
new_config = await ConversationAdapters.aset_user_paint_model(user, int(id))
update_telemetry_state(
request=request,
telemetry_type="api",
api="set_paint_model",
client=client,
metadata={"paint_model": new_config.setting.model_name},
)
if new_config is None:
return {"status": "error", "message": "Model not found"}
return {"status": "ok"}
@api_config.get("/index/size", response_model=Dict[str, int]) @api_config.get("/index/size", response_model=Dict[str, int])
@requires(["authenticated"]) @requires(["authenticated"])
async def get_indexed_data_size(request: Request, common: CommonQueryParams): async def get_indexed_data_size(request: Request, common: CommonQueryParams):

View file

@ -762,7 +762,7 @@ async def text_to_image(
image_url = None image_url = None
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3 intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config() text_to_image_config = await ConversationAdapters.aget_user_paint_model(user)
if not text_to_image_config: if not text_to_image_config:
# If the user has not configured a text to image model, return an unsupported on server error # If the user has not configured a text to image model, return an unsupported on server error
status_code = 501 status_code = 501

View file

@ -262,6 +262,12 @@ def config_page(request: Request):
current_search_model_option = adapters.get_user_search_model_or_default(user) current_search_model_option = adapters.get_user_search_model_or_default(user)
selected_paint_model_config = ConversationAdapters.get_user_paint_model_config(user)
paint_model_options = ConversationAdapters.get_text_to_image_model_options().all()
all_paint_model_options = list()
for paint_model in paint_model_options:
all_paint_model_options.append({"model_name": paint_model.model_name, "id": paint_model.id})
notion_oauth_url = get_notion_auth_url(user) notion_oauth_url = get_notion_auth_url(user)
eleven_labs_enabled = is_eleven_labs_enabled() eleven_labs_enabled = is_eleven_labs_enabled()
@ -284,10 +290,12 @@ def config_page(request: Request):
"anonymous_mode": state.anonymous_mode, "anonymous_mode": state.anonymous_mode,
"username": user.username, "username": user.username,
"given_name": given_name, "given_name": given_name,
"conversation_options": all_conversation_options,
"search_model_options": all_search_model_options, "search_model_options": all_search_model_options,
"selected_search_model_config": current_search_model_option.id, "selected_search_model_config": current_search_model_option.id,
"conversation_options": all_conversation_options,
"selected_conversation_config": selected_conversation_config.id if selected_conversation_config else None, "selected_conversation_config": selected_conversation_config.id if selected_conversation_config else None,
"paint_model_options": all_paint_model_options,
"selected_paint_model_config": selected_paint_model_config.id if selected_paint_model_config else None,
"user_photo": user_picture, "user_photo": user_picture,
"billing_enabled": state.billing_enabled, "billing_enabled": state.billing_enabled,
"subscription_state": user_subscription_state, "subscription_state": user_subscription_state,