mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 23:48:56 +01:00
Allow user to set paint model to use from web client config page
This commit is contained in:
parent
eb09aba747
commit
2c4bf91a61
9 changed files with 167 additions and 7 deletions
|
@ -45,6 +45,7 @@ from khoj.database.models import (
|
|||
Subscription,
|
||||
TextToImageModelConfig,
|
||||
UserConversationConfig,
|
||||
UserPaintModelConfig,
|
||||
UserRequests,
|
||||
UserSearchModelConfig,
|
||||
UserVoiceModelConfig,
|
||||
|
@ -891,6 +892,32 @@ class ConversationAdapters:
|
|||
async def aget_text_to_image_model_config():
|
||||
return await TextToImageModelConfig.objects.filter().afirst()
|
||||
|
||||
@staticmethod
|
||||
def get_text_to_image_model_options():
|
||||
return TextToImageModelConfig.objects.all()
|
||||
|
||||
@staticmethod
|
||||
def get_user_paint_model_config(user: KhojUser):
|
||||
config = UserPaintModelConfig.objects.filter(user=user).first()
|
||||
if not config:
|
||||
return None
|
||||
return config.setting
|
||||
|
||||
@staticmethod
|
||||
async def aget_user_paint_model(user: KhojUser):
|
||||
config = await UserPaintModelConfig.objects.filter(user=user).prefetch_related("setting").afirst()
|
||||
if not config:
|
||||
return None
|
||||
return config.setting
|
||||
|
||||
@staticmethod
|
||||
async def aset_user_paint_model(user: KhojUser, text_to_image_model_config_id: int):
|
||||
config = await TextToImageModelConfig.objects.filter(id=text_to_image_model_config_id).afirst()
|
||||
if not config:
|
||||
return None
|
||||
new_config, _ = await UserPaintModelConfig.objects.aupdate_or_create(user=user, defaults={"setting": config})
|
||||
return new_config
|
||||
|
||||
|
||||
class FileObjectAdapters:
|
||||
@staticmethod
|
||||
|
|
|
@ -96,7 +96,6 @@ admin.site.register(SpeechToTextModelOptions)
|
|||
admin.site.register(SearchModelConfig)
|
||||
admin.site.register(ReflectiveQuestion)
|
||||
admin.site.register(UserSearchModelConfig)
|
||||
admin.site.register(TextToImageModelConfig)
|
||||
admin.site.register(ClientApplication)
|
||||
admin.site.register(GithubConfig)
|
||||
admin.site.register(NotionConfig)
|
||||
|
@ -153,6 +152,16 @@ class ChatModelOptionsAdmin(admin.ModelAdmin):
|
|||
search_fields = ("id", "chat_model", "model_type")
|
||||
|
||||
|
||||
@admin.register(TextToImageModelConfig)
|
||||
class TextToImageModelOptionsAdmin(admin.ModelAdmin):
|
||||
list_display = (
|
||||
"id",
|
||||
"model_name",
|
||||
"model_type",
|
||||
)
|
||||
search_fields = ("id", "model_name", "model_type")
|
||||
|
||||
|
||||
@admin.register(OpenAIProcessorConversationConfig)
|
||||
class OpenAIProcessorConversationConfigAdmin(admin.ModelAdmin):
|
||||
list_display = (
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
# Generated by Django 4.2.11 on 2024-06-20 19:02
|
||||
# Generated by Django 4.2.11 on 2024-06-20 19:48
|
||||
|
||||
import django.db.models.deletion
|
||||
from django.conf import settings
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
|
@ -21,4 +23,25 @@ class Migration(migrations.Migration):
|
|||
choices=[("openai", "Openai"), ("stability-ai", "Stabilityai")], default="openai", max_length=200
|
||||
),
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name="UserPaintModelConfig",
|
||||
fields=[
|
||||
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||
("updated_at", models.DateTimeField(auto_now=True)),
|
||||
(
|
||||
"setting",
|
||||
models.ForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE, to="database.texttoimagemodelconfig"
|
||||
),
|
||||
),
|
||||
(
|
||||
"user",
|
||||
models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL),
|
||||
),
|
||||
],
|
||||
options={
|
||||
"abstract": False,
|
||||
},
|
||||
),
|
||||
]
|
||||
|
|
|
@ -265,6 +265,11 @@ class UserSearchModelConfig(BaseModel):
|
|||
setting = models.ForeignKey(SearchModelConfig, on_delete=models.CASCADE)
|
||||
|
||||
|
||||
class UserPaintModelConfig(BaseModel):
|
||||
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
|
||||
setting = models.ForeignKey(TextToImageModelConfig, on_delete=models.CASCADE)
|
||||
|
||||
|
||||
class Conversation(BaseModel):
|
||||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||
conversation_log = models.JSONField(default=dict)
|
||||
|
|
|
@ -332,6 +332,7 @@
|
|||
margin: 20px;
|
||||
}
|
||||
|
||||
select#paint-models,
|
||||
select#search-models,
|
||||
select#voice-models,
|
||||
select#chat-models {
|
||||
|
|
|
@ -192,11 +192,37 @@
|
|||
</div>
|
||||
<div class="card-action-row">
|
||||
{% if (not billing_enabled) or (subscription_state != 'unsubscribed' and subscription_state != 'expired') %}
|
||||
<button id="save-model" class="card-button happy" onclick="updateChatModel()">
|
||||
<button id="save-chat-model" class="card-button happy" onclick="updateChatModel()">
|
||||
Save
|
||||
</button>
|
||||
{% else %}
|
||||
<button id="save-model" class="card-button" disabled>
|
||||
<button id="save-chat-model" class="card-button" disabled>
|
||||
Subscribe to use different models
|
||||
</button>
|
||||
{% endif %}
|
||||
</div>
|
||||
</div>
|
||||
<div class="card">
|
||||
<div class="card-title-row">
|
||||
<img class="card-icon" src="/static/assets/icons/chat.svg" alt="Chat">
|
||||
<h3 class="card-title">
|
||||
<span>Paint</span>
|
||||
</h3>
|
||||
</div>
|
||||
<div class="card-description-row">
|
||||
<select id="paint-models">
|
||||
{% for option in paint_model_options %}
|
||||
<option value="{{ option.id }}" {% if option.id == selected_paint_model_config %}selected{% endif %}>{{ option.model_name }}</option>
|
||||
{% endfor %}
|
||||
</select>
|
||||
</div>
|
||||
<div class="card-action-row">
|
||||
{% if (not billing_enabled) or (subscription_state != 'unsubscribed' and subscription_state != 'expired') %}
|
||||
<button id="save-paint-model" class="card-button happy" onclick="updatePaintModel()">
|
||||
Save
|
||||
</button>
|
||||
{% else %}
|
||||
<button id="save-paint-model" class="card-button" disabled>
|
||||
Subscribe to use different models
|
||||
</button>
|
||||
{% endif %}
|
||||
|
@ -425,7 +451,7 @@
|
|||
|
||||
function updateChatModel() {
|
||||
const chatModel = document.getElementById("chat-models").value;
|
||||
const saveModelButton = document.getElementById("save-model");
|
||||
const saveModelButton = document.getElementById("save-chat-model");
|
||||
saveModelButton.disabled = true;
|
||||
saveModelButton.innerHTML = "Saving...";
|
||||
|
||||
|
@ -491,6 +517,38 @@
|
|||
})
|
||||
};
|
||||
|
||||
function updatePaintModel() {
|
||||
const paintModel = document.getElementById("paint-models").value;
|
||||
const saveModelButton = document.getElementById("save-paint-model");
|
||||
saveModelButton.disabled = true;
|
||||
saveModelButton.innerHTML = "Saving...";
|
||||
|
||||
fetch('/api/config/data/paint/model?id=' + paintModel, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
})
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
if (data.status == "ok") {
|
||||
saveModelButton.innerHTML = "Save";
|
||||
saveModelButton.disabled = false;
|
||||
|
||||
let notificationBanner = document.getElementById("notification-banner");
|
||||
notificationBanner.innerHTML = "Paint model has been updated!";
|
||||
notificationBanner.style.display = "block";
|
||||
setTimeout(function() {
|
||||
notificationBanner.style.display = "none";
|
||||
}, 5000);
|
||||
|
||||
} else {
|
||||
saveModelButton.innerHTML = "Error";
|
||||
saveModelButton.disabled = false;
|
||||
}
|
||||
})
|
||||
};
|
||||
|
||||
function clearContentType(content_source) {
|
||||
fetch('/api/config/data/content-source/' + content_source, {
|
||||
method: 'DELETE',
|
||||
|
|
|
@ -315,6 +315,35 @@ async def update_search_model(
|
|||
return {"status": "ok"}
|
||||
|
||||
|
||||
@api_config.post("/data/paint/model", status_code=200)
|
||||
@requires(["authenticated"])
|
||||
async def update_paint_model(
|
||||
request: Request,
|
||||
id: str,
|
||||
client: Optional[str] = None,
|
||||
):
|
||||
user = request.user.object
|
||||
subscribed = has_required_scope(request, ["premium"])
|
||||
|
||||
if not subscribed:
|
||||
raise HTTPException(status_code=403, detail="User is not subscribed to premium")
|
||||
|
||||
new_config = await ConversationAdapters.aset_user_paint_model(user, int(id))
|
||||
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
telemetry_type="api",
|
||||
api="set_paint_model",
|
||||
client=client,
|
||||
metadata={"paint_model": new_config.setting.model_name},
|
||||
)
|
||||
|
||||
if new_config is None:
|
||||
return {"status": "error", "message": "Model not found"}
|
||||
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@api_config.get("/index/size", response_model=Dict[str, int])
|
||||
@requires(["authenticated"])
|
||||
async def get_indexed_data_size(request: Request, common: CommonQueryParams):
|
||||
|
|
|
@ -762,7 +762,7 @@ async def text_to_image(
|
|||
image_url = None
|
||||
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
|
||||
|
||||
text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config()
|
||||
text_to_image_config = await ConversationAdapters.aget_user_paint_model(user)
|
||||
if not text_to_image_config:
|
||||
# If the user has not configured a text to image model, return an unsupported on server error
|
||||
status_code = 501
|
||||
|
|
|
@ -262,6 +262,12 @@ def config_page(request: Request):
|
|||
|
||||
current_search_model_option = adapters.get_user_search_model_or_default(user)
|
||||
|
||||
selected_paint_model_config = ConversationAdapters.get_user_paint_model_config(user)
|
||||
paint_model_options = ConversationAdapters.get_text_to_image_model_options().all()
|
||||
all_paint_model_options = list()
|
||||
for paint_model in paint_model_options:
|
||||
all_paint_model_options.append({"model_name": paint_model.model_name, "id": paint_model.id})
|
||||
|
||||
notion_oauth_url = get_notion_auth_url(user)
|
||||
|
||||
eleven_labs_enabled = is_eleven_labs_enabled()
|
||||
|
@ -284,10 +290,12 @@ def config_page(request: Request):
|
|||
"anonymous_mode": state.anonymous_mode,
|
||||
"username": user.username,
|
||||
"given_name": given_name,
|
||||
"conversation_options": all_conversation_options,
|
||||
"search_model_options": all_search_model_options,
|
||||
"selected_search_model_config": current_search_model_option.id,
|
||||
"conversation_options": all_conversation_options,
|
||||
"selected_conversation_config": selected_conversation_config.id if selected_conversation_config else None,
|
||||
"paint_model_options": all_paint_model_options,
|
||||
"selected_paint_model_config": selected_paint_model_config.id if selected_paint_model_config else None,
|
||||
"user_photo": user_picture,
|
||||
"billing_enabled": state.billing_enabled,
|
||||
"subscription_state": user_subscription_state,
|
||||
|
|
Loading…
Reference in a new issue