mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-27 17:35:07 +01:00
Allow user to set paint model to use from web client config page
This commit is contained in:
parent
eb09aba747
commit
2c4bf91a61
9 changed files with 167 additions and 7 deletions
|
@ -45,6 +45,7 @@ from khoj.database.models import (
|
||||||
Subscription,
|
Subscription,
|
||||||
TextToImageModelConfig,
|
TextToImageModelConfig,
|
||||||
UserConversationConfig,
|
UserConversationConfig,
|
||||||
|
UserPaintModelConfig,
|
||||||
UserRequests,
|
UserRequests,
|
||||||
UserSearchModelConfig,
|
UserSearchModelConfig,
|
||||||
UserVoiceModelConfig,
|
UserVoiceModelConfig,
|
||||||
|
@ -891,6 +892,32 @@ class ConversationAdapters:
|
||||||
async def aget_text_to_image_model_config():
|
async def aget_text_to_image_model_config():
|
||||||
return await TextToImageModelConfig.objects.filter().afirst()
|
return await TextToImageModelConfig.objects.filter().afirst()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_text_to_image_model_options():
|
||||||
|
return TextToImageModelConfig.objects.all()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_user_paint_model_config(user: KhojUser):
|
||||||
|
config = UserPaintModelConfig.objects.filter(user=user).first()
|
||||||
|
if not config:
|
||||||
|
return None
|
||||||
|
return config.setting
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def aget_user_paint_model(user: KhojUser):
|
||||||
|
config = await UserPaintModelConfig.objects.filter(user=user).prefetch_related("setting").afirst()
|
||||||
|
if not config:
|
||||||
|
return None
|
||||||
|
return config.setting
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def aset_user_paint_model(user: KhojUser, text_to_image_model_config_id: int):
|
||||||
|
config = await TextToImageModelConfig.objects.filter(id=text_to_image_model_config_id).afirst()
|
||||||
|
if not config:
|
||||||
|
return None
|
||||||
|
new_config, _ = await UserPaintModelConfig.objects.aupdate_or_create(user=user, defaults={"setting": config})
|
||||||
|
return new_config
|
||||||
|
|
||||||
|
|
||||||
class FileObjectAdapters:
|
class FileObjectAdapters:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
@ -96,7 +96,6 @@ admin.site.register(SpeechToTextModelOptions)
|
||||||
admin.site.register(SearchModelConfig)
|
admin.site.register(SearchModelConfig)
|
||||||
admin.site.register(ReflectiveQuestion)
|
admin.site.register(ReflectiveQuestion)
|
||||||
admin.site.register(UserSearchModelConfig)
|
admin.site.register(UserSearchModelConfig)
|
||||||
admin.site.register(TextToImageModelConfig)
|
|
||||||
admin.site.register(ClientApplication)
|
admin.site.register(ClientApplication)
|
||||||
admin.site.register(GithubConfig)
|
admin.site.register(GithubConfig)
|
||||||
admin.site.register(NotionConfig)
|
admin.site.register(NotionConfig)
|
||||||
|
@ -153,6 +152,16 @@ class ChatModelOptionsAdmin(admin.ModelAdmin):
|
||||||
search_fields = ("id", "chat_model", "model_type")
|
search_fields = ("id", "chat_model", "model_type")
|
||||||
|
|
||||||
|
|
||||||
|
@admin.register(TextToImageModelConfig)
|
||||||
|
class TextToImageModelOptionsAdmin(admin.ModelAdmin):
|
||||||
|
list_display = (
|
||||||
|
"id",
|
||||||
|
"model_name",
|
||||||
|
"model_type",
|
||||||
|
)
|
||||||
|
search_fields = ("id", "model_name", "model_type")
|
||||||
|
|
||||||
|
|
||||||
@admin.register(OpenAIProcessorConversationConfig)
|
@admin.register(OpenAIProcessorConversationConfig)
|
||||||
class OpenAIProcessorConversationConfigAdmin(admin.ModelAdmin):
|
class OpenAIProcessorConversationConfigAdmin(admin.ModelAdmin):
|
||||||
list_display = (
|
list_display = (
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
# Generated by Django 4.2.11 on 2024-06-20 19:02
|
# Generated by Django 4.2.11 on 2024-06-20 19:48
|
||||||
|
|
||||||
|
import django.db.models.deletion
|
||||||
|
from django.conf import settings
|
||||||
from django.db import migrations, models
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
@ -21,4 +23,25 @@ class Migration(migrations.Migration):
|
||||||
choices=[("openai", "Openai"), ("stability-ai", "Stabilityai")], default="openai", max_length=200
|
choices=[("openai", "Openai"), ("stability-ai", "Stabilityai")], default="openai", max_length=200
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
|
migrations.CreateModel(
|
||||||
|
name="UserPaintModelConfig",
|
||||||
|
fields=[
|
||||||
|
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||||
|
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||||
|
("updated_at", models.DateTimeField(auto_now=True)),
|
||||||
|
(
|
||||||
|
"setting",
|
||||||
|
models.ForeignKey(
|
||||||
|
on_delete=django.db.models.deletion.CASCADE, to="database.texttoimagemodelconfig"
|
||||||
|
),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"user",
|
||||||
|
models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
options={
|
||||||
|
"abstract": False,
|
||||||
|
},
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
|
@ -265,6 +265,11 @@ class UserSearchModelConfig(BaseModel):
|
||||||
setting = models.ForeignKey(SearchModelConfig, on_delete=models.CASCADE)
|
setting = models.ForeignKey(SearchModelConfig, on_delete=models.CASCADE)
|
||||||
|
|
||||||
|
|
||||||
|
class UserPaintModelConfig(BaseModel):
|
||||||
|
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
|
||||||
|
setting = models.ForeignKey(TextToImageModelConfig, on_delete=models.CASCADE)
|
||||||
|
|
||||||
|
|
||||||
class Conversation(BaseModel):
|
class Conversation(BaseModel):
|
||||||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||||
conversation_log = models.JSONField(default=dict)
|
conversation_log = models.JSONField(default=dict)
|
||||||
|
|
|
@ -332,6 +332,7 @@
|
||||||
margin: 20px;
|
margin: 20px;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
select#paint-models,
|
||||||
select#search-models,
|
select#search-models,
|
||||||
select#voice-models,
|
select#voice-models,
|
||||||
select#chat-models {
|
select#chat-models {
|
||||||
|
|
|
@ -192,11 +192,37 @@
|
||||||
</div>
|
</div>
|
||||||
<div class="card-action-row">
|
<div class="card-action-row">
|
||||||
{% if (not billing_enabled) or (subscription_state != 'unsubscribed' and subscription_state != 'expired') %}
|
{% if (not billing_enabled) or (subscription_state != 'unsubscribed' and subscription_state != 'expired') %}
|
||||||
<button id="save-model" class="card-button happy" onclick="updateChatModel()">
|
<button id="save-chat-model" class="card-button happy" onclick="updateChatModel()">
|
||||||
Save
|
Save
|
||||||
</button>
|
</button>
|
||||||
{% else %}
|
{% else %}
|
||||||
<button id="save-model" class="card-button" disabled>
|
<button id="save-chat-model" class="card-button" disabled>
|
||||||
|
Subscribe to use different models
|
||||||
|
</button>
|
||||||
|
{% endif %}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class="card">
|
||||||
|
<div class="card-title-row">
|
||||||
|
<img class="card-icon" src="/static/assets/icons/chat.svg" alt="Chat">
|
||||||
|
<h3 class="card-title">
|
||||||
|
<span>Paint</span>
|
||||||
|
</h3>
|
||||||
|
</div>
|
||||||
|
<div class="card-description-row">
|
||||||
|
<select id="paint-models">
|
||||||
|
{% for option in paint_model_options %}
|
||||||
|
<option value="{{ option.id }}" {% if option.id == selected_paint_model_config %}selected{% endif %}>{{ option.model_name }}</option>
|
||||||
|
{% endfor %}
|
||||||
|
</select>
|
||||||
|
</div>
|
||||||
|
<div class="card-action-row">
|
||||||
|
{% if (not billing_enabled) or (subscription_state != 'unsubscribed' and subscription_state != 'expired') %}
|
||||||
|
<button id="save-paint-model" class="card-button happy" onclick="updatePaintModel()">
|
||||||
|
Save
|
||||||
|
</button>
|
||||||
|
{% else %}
|
||||||
|
<button id="save-paint-model" class="card-button" disabled>
|
||||||
Subscribe to use different models
|
Subscribe to use different models
|
||||||
</button>
|
</button>
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
@ -425,7 +451,7 @@
|
||||||
|
|
||||||
function updateChatModel() {
|
function updateChatModel() {
|
||||||
const chatModel = document.getElementById("chat-models").value;
|
const chatModel = document.getElementById("chat-models").value;
|
||||||
const saveModelButton = document.getElementById("save-model");
|
const saveModelButton = document.getElementById("save-chat-model");
|
||||||
saveModelButton.disabled = true;
|
saveModelButton.disabled = true;
|
||||||
saveModelButton.innerHTML = "Saving...";
|
saveModelButton.innerHTML = "Saving...";
|
||||||
|
|
||||||
|
@ -491,6 +517,38 @@
|
||||||
})
|
})
|
||||||
};
|
};
|
||||||
|
|
||||||
|
function updatePaintModel() {
|
||||||
|
const paintModel = document.getElementById("paint-models").value;
|
||||||
|
const saveModelButton = document.getElementById("save-paint-model");
|
||||||
|
saveModelButton.disabled = true;
|
||||||
|
saveModelButton.innerHTML = "Saving...";
|
||||||
|
|
||||||
|
fetch('/api/config/data/paint/model?id=' + paintModel, {
|
||||||
|
method: 'POST',
|
||||||
|
headers: {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.then(response => response.json())
|
||||||
|
.then(data => {
|
||||||
|
if (data.status == "ok") {
|
||||||
|
saveModelButton.innerHTML = "Save";
|
||||||
|
saveModelButton.disabled = false;
|
||||||
|
|
||||||
|
let notificationBanner = document.getElementById("notification-banner");
|
||||||
|
notificationBanner.innerHTML = "Paint model has been updated!";
|
||||||
|
notificationBanner.style.display = "block";
|
||||||
|
setTimeout(function() {
|
||||||
|
notificationBanner.style.display = "none";
|
||||||
|
}, 5000);
|
||||||
|
|
||||||
|
} else {
|
||||||
|
saveModelButton.innerHTML = "Error";
|
||||||
|
saveModelButton.disabled = false;
|
||||||
|
}
|
||||||
|
})
|
||||||
|
};
|
||||||
|
|
||||||
function clearContentType(content_source) {
|
function clearContentType(content_source) {
|
||||||
fetch('/api/config/data/content-source/' + content_source, {
|
fetch('/api/config/data/content-source/' + content_source, {
|
||||||
method: 'DELETE',
|
method: 'DELETE',
|
||||||
|
|
|
@ -315,6 +315,35 @@ async def update_search_model(
|
||||||
return {"status": "ok"}
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
|
@api_config.post("/data/paint/model", status_code=200)
|
||||||
|
@requires(["authenticated"])
|
||||||
|
async def update_paint_model(
|
||||||
|
request: Request,
|
||||||
|
id: str,
|
||||||
|
client: Optional[str] = None,
|
||||||
|
):
|
||||||
|
user = request.user.object
|
||||||
|
subscribed = has_required_scope(request, ["premium"])
|
||||||
|
|
||||||
|
if not subscribed:
|
||||||
|
raise HTTPException(status_code=403, detail="User is not subscribed to premium")
|
||||||
|
|
||||||
|
new_config = await ConversationAdapters.aset_user_paint_model(user, int(id))
|
||||||
|
|
||||||
|
update_telemetry_state(
|
||||||
|
request=request,
|
||||||
|
telemetry_type="api",
|
||||||
|
api="set_paint_model",
|
||||||
|
client=client,
|
||||||
|
metadata={"paint_model": new_config.setting.model_name},
|
||||||
|
)
|
||||||
|
|
||||||
|
if new_config is None:
|
||||||
|
return {"status": "error", "message": "Model not found"}
|
||||||
|
|
||||||
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
@api_config.get("/index/size", response_model=Dict[str, int])
|
@api_config.get("/index/size", response_model=Dict[str, int])
|
||||||
@requires(["authenticated"])
|
@requires(["authenticated"])
|
||||||
async def get_indexed_data_size(request: Request, common: CommonQueryParams):
|
async def get_indexed_data_size(request: Request, common: CommonQueryParams):
|
||||||
|
|
|
@ -762,7 +762,7 @@ async def text_to_image(
|
||||||
image_url = None
|
image_url = None
|
||||||
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
|
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
|
||||||
|
|
||||||
text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config()
|
text_to_image_config = await ConversationAdapters.aget_user_paint_model(user)
|
||||||
if not text_to_image_config:
|
if not text_to_image_config:
|
||||||
# If the user has not configured a text to image model, return an unsupported on server error
|
# If the user has not configured a text to image model, return an unsupported on server error
|
||||||
status_code = 501
|
status_code = 501
|
||||||
|
|
|
@ -262,6 +262,12 @@ def config_page(request: Request):
|
||||||
|
|
||||||
current_search_model_option = adapters.get_user_search_model_or_default(user)
|
current_search_model_option = adapters.get_user_search_model_or_default(user)
|
||||||
|
|
||||||
|
selected_paint_model_config = ConversationAdapters.get_user_paint_model_config(user)
|
||||||
|
paint_model_options = ConversationAdapters.get_text_to_image_model_options().all()
|
||||||
|
all_paint_model_options = list()
|
||||||
|
for paint_model in paint_model_options:
|
||||||
|
all_paint_model_options.append({"model_name": paint_model.model_name, "id": paint_model.id})
|
||||||
|
|
||||||
notion_oauth_url = get_notion_auth_url(user)
|
notion_oauth_url = get_notion_auth_url(user)
|
||||||
|
|
||||||
eleven_labs_enabled = is_eleven_labs_enabled()
|
eleven_labs_enabled = is_eleven_labs_enabled()
|
||||||
|
@ -284,10 +290,12 @@ def config_page(request: Request):
|
||||||
"anonymous_mode": state.anonymous_mode,
|
"anonymous_mode": state.anonymous_mode,
|
||||||
"username": user.username,
|
"username": user.username,
|
||||||
"given_name": given_name,
|
"given_name": given_name,
|
||||||
"conversation_options": all_conversation_options,
|
|
||||||
"search_model_options": all_search_model_options,
|
"search_model_options": all_search_model_options,
|
||||||
"selected_search_model_config": current_search_model_option.id,
|
"selected_search_model_config": current_search_model_option.id,
|
||||||
|
"conversation_options": all_conversation_options,
|
||||||
"selected_conversation_config": selected_conversation_config.id if selected_conversation_config else None,
|
"selected_conversation_config": selected_conversation_config.id if selected_conversation_config else None,
|
||||||
|
"paint_model_options": all_paint_model_options,
|
||||||
|
"selected_paint_model_config": selected_paint_model_config.id if selected_paint_model_config else None,
|
||||||
"user_photo": user_picture,
|
"user_photo": user_picture,
|
||||||
"billing_enabled": state.billing_enabled,
|
"billing_enabled": state.billing_enabled,
|
||||||
"subscription_state": user_subscription_state,
|
"subscription_state": user_subscription_state,
|
||||||
|
|
Loading…
Reference in a new issue