Mirror of https://github.com/khoj-ai/khoj.git, synced 2024-11-23 23:48:56 +01:00
Enable using Stable Diffusion 3 for Image Generation via API (#830)
- Support Stable Diffusion 3 via its API. A server admin needs to set up the model, similar to DALL-E 3, via the Django admin panel
- Use a shorter prompt generator to prompt SD3 to create better images
- Allow users to set the paint model to use from the web client config page (a usage sketch follows the commit summary below)
Commit 826c3dc9cc: 10 changed files with 375 additions and 70 deletions
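This change has two user-facing pieces: the web client can now select a paint model per user by POSTing to the new /api/config/data/paint/model endpoint, and a server configured with a Stability AI key generates images through the hosted SD3 endpoint. Below is a minimal usage sketch of both calls; the base URL, session cookie, model id, API key, and example prompt are placeholder assumptions for illustration, not part of this commit.

import requests

# Placeholder values; substitute your own server URL, session cookie, model id and key.
KHOJ_BASE_URL = "http://localhost:42110"
PAINT_MODEL_ID = 1  # id of a TextToImageModelConfig row created via the Django admin panel
STABILITY_API_KEY = "sk-placeholder"

# 1. Pick the paint model for the logged-in user (endpoint added by this commit).
#    The endpoint requires an authenticated, subscribed user, as update_paint_model() enforces.
requests.post(
    f"{KHOJ_BASE_URL}/api/config/data/paint/model",
    params={"id": PAINT_MODEL_ID},
    cookies={"session": "<your session cookie>"},
)

# 2. What the server then does for a "stability-ai" model type: call Stability's SD3 API,
#    mirroring the requests.post() call added to text_to_image() in the diff below.
response = requests.post(
    "https://api.stability.ai/v2beta/stable-image/generate/sd3",
    headers={"authorization": f"Bearer {STABILITY_API_KEY}", "accept": "image/*"},
    files={"none": ""},
    data={"prompt": "a watercolor painting of a lighthouse at dusk", "output_format": "png"},
)
with open("painting.png", "wb") as f:
    f.write(response.content)  # raw image bytes on success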
@@ -48,6 +48,7 @@ from khoj.database.models import (
     UserConversationConfig,
     UserRequests,
     UserSearchModelConfig,
+    UserTextToImageModelConfig,
     UserVoiceModelConfig,
     VoiceModelOption,
 )
@@ -905,6 +906,34 @@ class ConversationAdapters:
     async def aget_text_to_image_model_config():
         return await TextToImageModelConfig.objects.filter().afirst()
 
+    @staticmethod
+    def get_text_to_image_model_options():
+        return TextToImageModelConfig.objects.all()
+
+    @staticmethod
+    def get_user_text_to_image_model_config(user: KhojUser):
+        config = UserTextToImageModelConfig.objects.filter(user=user).first()
+        if not config:
+            return None
+        return config.setting
+
+    @staticmethod
+    async def aget_user_text_to_image_model(user: KhojUser):
+        config = await UserTextToImageModelConfig.objects.filter(user=user).prefetch_related("setting").afirst()
+        if not config:
+            return None
+        return config.setting
+
+    @staticmethod
+    async def aset_user_text_to_image_model(user: KhojUser, text_to_image_model_config_id: int):
+        config = await TextToImageModelConfig.objects.filter(id=text_to_image_model_config_id).afirst()
+        if not config:
+            return None
+        new_config, _ = await UserTextToImageModelConfig.objects.aupdate_or_create(
+            user=user, defaults={"setting": config}
+        )
+        return new_config
+
 
 class FileObjectAdapters:
     @staticmethod
@@ -96,7 +96,6 @@ admin.site.register(SpeechToTextModelOptions)
 admin.site.register(SearchModelConfig)
 admin.site.register(ReflectiveQuestion)
 admin.site.register(UserSearchModelConfig)
-admin.site.register(TextToImageModelConfig)
 admin.site.register(ClientApplication)
 admin.site.register(GithubConfig)
 admin.site.register(NotionConfig)
@@ -153,6 +152,16 @@ class ChatModelOptionsAdmin(admin.ModelAdmin):
     search_fields = ("id", "chat_model", "model_type")
 
 
+@admin.register(TextToImageModelConfig)
+class TextToImageModelOptionsAdmin(admin.ModelAdmin):
+    list_display = (
+        "id",
+        "model_name",
+        "model_type",
+    )
+    search_fields = ("id", "model_name", "model_type")
+
+
 @admin.register(OpenAIProcessorConversationConfig)
 class OpenAIProcessorConversationConfigAdmin(admin.ModelAdmin):
     list_display = (
@@ -0,0 +1,58 @@
+# Generated by Django 4.2.11 on 2024-06-26 03:27
+
+import django.db.models.deletion
+from django.conf import settings
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+    dependencies = [
+        ("database", "0048_voicemodeloption_uservoicemodelconfig"),
+    ]
+
+    operations = [
+        migrations.AddField(
+            model_name="texttoimagemodelconfig",
+            name="api_key",
+            field=models.CharField(blank=True, default=None, max_length=200, null=True),
+        ),
+        migrations.AddField(
+            model_name="texttoimagemodelconfig",
+            name="openai_config",
+            field=models.ForeignKey(
+                blank=True,
+                default=None,
+                null=True,
+                on_delete=django.db.models.deletion.CASCADE,
+                to="database.openaiprocessorconversationconfig",
+            ),
+        ),
+        migrations.AlterField(
+            model_name="texttoimagemodelconfig",
+            name="model_type",
+            field=models.CharField(
+                choices=[("openai", "Openai"), ("stability-ai", "Stabilityai")], default="openai", max_length=200
+            ),
+        ),
+        migrations.CreateModel(
+            name="UserTextToImageModelConfig",
+            fields=[
+                ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
+                ("created_at", models.DateTimeField(auto_now_add=True)),
+                ("updated_at", models.DateTimeField(auto_now=True)),
+                (
+                    "setting",
+                    models.ForeignKey(
+                        on_delete=django.db.models.deletion.CASCADE, to="database.texttoimagemodelconfig"
+                    ),
+                ),
+                (
+                    "user",
+                    models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL),
+                ),
+            ],
+            options={
+                "abstract": False,
+            },
+        ),
+    ]
@@ -235,9 +235,37 @@ class SearchModelConfig(BaseModel):
 class TextToImageModelConfig(BaseModel):
     class ModelType(models.TextChoices):
         OPENAI = "openai"
+        STABILITYAI = "stability-ai"
 
     model_name = models.CharField(max_length=200, default="dall-e-3")
     model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
+    api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
+    openai_config = models.ForeignKey(
+        OpenAIProcessorConversationConfig, on_delete=models.CASCADE, default=None, null=True, blank=True
+    )
+
+    def clean(self):
+        # Custom validation logic
+        error = {}
+        if self.model_type == self.ModelType.OPENAI:
+            if self.api_key and self.openai_config:
+                error[
+                    "api_key"
+                ] = "Both API key and OpenAI config cannot be set for OpenAI models. Please set only one of them."
+                error[
+                    "openai_config"
+                ] = "Both API key and OpenAI config cannot be set for OpenAI models. Please set only one of them."
+        if self.model_type != self.ModelType.OPENAI:
+            if not self.api_key:
+                error["api_key"] = "The API key field must be set for non OpenAI models."
+            if self.openai_config:
+                error["openai_config"] = "OpenAI config cannot be set for non OpenAI models."
+        if error:
+            raise ValidationError(error)
+
+    def save(self, *args, **kwargs):
+        self.clean()
+        super().save(*args, **kwargs)
 
 
 class SpeechToTextModelOptions(BaseModel):
@@ -264,6 +292,11 @@ class UserSearchModelConfig(BaseModel):
     setting = models.ForeignKey(SearchModelConfig, on_delete=models.CASCADE)
 
 
+class UserTextToImageModelConfig(BaseModel):
+    user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
+    setting = models.ForeignKey(TextToImageModelConfig, on_delete=models.CASCADE)
+
+
 class Conversation(BaseModel):
     user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
     conversation_log = models.JSONField(default=dict)
@@ -332,6 +332,7 @@
     margin: 20px;
 }
 
+select#paint-models,
 select#search-models,
 select#voice-models,
 select#chat-models {
@@ -192,11 +192,37 @@
             </div>
             <div class="card-action-row">
                 {% if (not billing_enabled) or (subscription_state != 'unsubscribed' and subscription_state != 'expired') %}
-                <button id="save-model" class="card-button happy" onclick="updateChatModel()">
+                <button id="save-chat-model" class="card-button happy" onclick="updateChatModel()">
                     Save
                 </button>
                 {% else %}
-                <button id="save-model" class="card-button" disabled>
+                <button id="save-chat-model" class="card-button" disabled>
+                    Subscribe to use different models
+                </button>
+                {% endif %}
+            </div>
+        </div>
+        <div class="card">
+            <div class="card-title-row">
+                <img class="card-icon" src="/static/assets/icons/chat.svg" alt="Chat">
+                <h3 class="card-title">
+                    <span>Paint</span>
+                </h3>
+            </div>
+            <div class="card-description-row">
+                <select id="paint-models">
+                    {% for option in paint_model_options %}
+                    <option value="{{ option.id }}" {% if option.id == selected_paint_model_config %}selected{% endif %}>{{ option.model_name }}</option>
+                    {% endfor %}
+                </select>
+            </div>
+            <div class="card-action-row">
+                {% if (not billing_enabled) or (subscription_state != 'unsubscribed' and subscription_state != 'expired') %}
+                <button id="save-paint-model" class="card-button happy" onclick="updatePaintModel()">
+                    Save
+                </button>
+                {% else %}
+                <button id="save-paint-model" class="card-button" disabled>
                     Subscribe to use different models
                 </button>
                 {% endif %}
@@ -425,7 +451,7 @@
 
     function updateChatModel() {
         const chatModel = document.getElementById("chat-models").value;
-        const saveModelButton = document.getElementById("save-model");
+        const saveModelButton = document.getElementById("save-chat-model");
         saveModelButton.disabled = true;
         saveModelButton.innerHTML = "Saving...";
@@ -491,6 +517,38 @@
         })
     };
 
+    function updatePaintModel() {
+        const paintModel = document.getElementById("paint-models").value;
+        const saveModelButton = document.getElementById("save-paint-model");
+        saveModelButton.disabled = true;
+        saveModelButton.innerHTML = "Saving...";
+
+        fetch('/api/config/data/paint/model?id=' + paintModel, {
+            method: 'POST',
+            headers: {
+                'Content-Type': 'application/json',
+            }
+        })
+        .then(response => response.json())
+        .then(data => {
+            if (data.status == "ok") {
+                saveModelButton.innerHTML = "Save";
+                saveModelButton.disabled = false;
+
+                let notificationBanner = document.getElementById("notification-banner");
+                notificationBanner.innerHTML = "Paint model has been updated!";
+                notificationBanner.style.display = "block";
+                setTimeout(function() {
+                    notificationBanner.style.display = "none";
+                }, 5000);
+
+            } else {
+                saveModelButton.innerHTML = "Error";
+                saveModelButton.disabled = false;
+            }
+        })
+    };
+
     function clearContentType(content_source) {
         fetch('/api/config/data/content-source/' + content_source, {
             method: 'DELETE',
@@ -121,7 +121,7 @@ User's Notes:
 ## Image Generation
 ## --
 
-image_generation_improve_prompt = PromptTemplate.from_template(
+image_generation_improve_prompt_dalle = PromptTemplate.from_template(
     """
 You are a talented creator. Generate a detailed prompt to generate an image based on the following description. Update the query below to improve the image generation. Add additional context to the query to improve the image generation. Make sure to retain any important information originally from the query. You are provided with the following information to help you generate the prompt:
@@ -143,6 +143,35 @@ Remember, now you are generating a prompt to improve the image generation. Add a
 Improved Query:"""
 )
 
+image_generation_improve_prompt_sd = PromptTemplate.from_template(
+    """
+You are a talented creator. Write 2-5 sentences with precise image composition, position details to create an image.
+Use the provided context below to add specific, fine details to the image composition.
+Retain any important information and follow any instructions from the original prompt.
+Put any text to be rendered in the image within double quotes in your improved prompt.
+You are provided with the following context to help enhance the original prompt:
+
+Today's Date: {current_date}
+User's Location: {location}
+
+User's Notes:
+{references}
+
+Online References:
+{online_results}
+
+Conversation Log:
+{chat_history}
+
+Original Prompt: "{query}"
+
+Now create an improved prompt using the context provided above to generate an image.
+Retain any important information and follow any instructions from the original prompt.
+Use the additional context from the user's notes, online references and conversation log to improve the image generation.
+
+Improved Prompt:"""
+)
+
 ## Online Search Conversation
 ## --
 online_search_conversation = PromptTemplate.from_template(
@@ -341,6 +341,35 @@ async def update_search_model(
     return {"status": "ok"}
 
 
+@api_config.post("/data/paint/model", status_code=200)
+@requires(["authenticated"])
+async def update_paint_model(
+    request: Request,
+    id: str,
+    client: Optional[str] = None,
+):
+    user = request.user.object
+    subscribed = has_required_scope(request, ["premium"])
+
+    if not subscribed:
+        raise HTTPException(status_code=403, detail="User is not subscribed to premium")
+
+    new_config = await ConversationAdapters.aset_user_text_to_image_model(user, int(id))
+
+    update_telemetry_state(
+        request=request,
+        telemetry_type="api",
+        api="set_paint_model",
+        client=client,
+        metadata={"paint_model": new_config.setting.model_name},
+    )
+
+    if new_config is None:
+        return {"status": "error", "message": "Model not found"}
+
+    return {"status": "ok"}
+
+
 @api_config.get("/index/size", response_model=Dict[str, int])
 @requires(["authenticated"])
 async def get_indexed_data_size(request: Request, common: CommonQueryParams):
@@ -453,12 +453,14 @@ async def generate_better_image_prompt(
     location_data: LocationData,
     note_references: List[Dict[str, Any]],
     online_results: Optional[dict] = None,
+    model_type: Optional[str] = None,
 ) -> str:
     """
     Generate a better image prompt from the given query
     """
 
     today_date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d")
+    model_type = model_type or TextToImageModelConfig.ModelType.OPENAI
 
     if location_data:
         location = f"{location_data.city}, {location_data.region}, {location_data.country}"
@@ -477,21 +479,34 @@ async def generate_better_image_prompt(
         elif online_results[result].get("webpages"):
             simplified_online_results[result] = online_results[result]["webpages"]
 
-    image_prompt = prompts.image_generation_improve_prompt.format(
-        query=q,
-        chat_history=conversation_history,
-        location=location_prompt,
-        current_date=today_date,
-        references=user_references,
-        online_results=simplified_online_results,
-    )
+    if model_type == TextToImageModelConfig.ModelType.OPENAI:
+        image_prompt = prompts.image_generation_improve_prompt_dalle.format(
+            query=q,
+            chat_history=conversation_history,
+            location=location_prompt,
+            current_date=today_date,
+            references=user_references,
+            online_results=simplified_online_results,
+        )
+    elif model_type == TextToImageModelConfig.ModelType.STABILITYAI:
+        image_prompt = prompts.image_generation_improve_prompt_sd.format(
+            query=q,
+            chat_history=conversation_history,
+            location=location_prompt,
+            current_date=today_date,
+            references=user_references,
+            online_results=simplified_online_results,
+        )
 
     summarizer_model: ChatModelOptions = await ConversationAdapters.aget_summarizer_conversation_config()
 
     with timer("Chat actor: Generate contextual image prompt", logger):
         response = await send_message_to_model_wrapper(image_prompt, chat_model_option=summarizer_model)
+        response = response.strip()
+        if response.startswith(('"', "'")) and response.endswith(('"', "'")):
+            response = response[1:-1]
 
-    return response.strip()
+    return response
 
 
 async def send_message_to_model_wrapper(
@@ -747,74 +762,110 @@ async def text_to_image(
     image_url = None
     intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
 
-    text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config()
+    text_to_image_config = await ConversationAdapters.aget_user_text_to_image_model(user)
     if not text_to_image_config:
         # If the user has not configured a text to image model, return an unsupported on server error
         status_code = 501
         message = "Failed to generate image. Setup image generation on the server."
         return image_url or image, status_code, message, intent_type.value
-    elif state.openai_client and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
-        logger.info("Generating image with OpenAI")
-        text2image_model = text_to_image_config.model_name
-        chat_history = ""
-        for chat in conversation_log.get("chat", [])[-4:]:
-            if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder"]:
-                chat_history += f"Q: {chat['intent']['query']}\n"
-                chat_history += f"A: {chat['message']}\n"
-            elif chat["by"] == "khoj" and "text-to-image" in chat["intent"].get("type"):
-                chat_history += f"Q: Query: {chat['intent']['query']}\n"
-                chat_history += f"A: Improved Query: {chat['intent']['inferred-queries'][0]}\n"
-        try:
-            with timer("Improve the original user query", logger):
-                if send_status_func:
-                    await send_status_func("**✍🏽 Enhancing the Painting Prompt**")
-                improved_image_prompt = await generate_better_image_prompt(
-                    message,
-                    chat_history,
-                    location_data=location_data,
-                    note_references=references,
-                    online_results=online_results,
+
+    text2image_model = text_to_image_config.model_name
+    chat_history = ""
+    for chat in conversation_log.get("chat", [])[-4:]:
+        if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder"]:
+            chat_history += f"Q: {chat['intent']['query']}\n"
+            chat_history += f"A: {chat['message']}\n"
+        elif chat["by"] == "khoj" and "text-to-image" in chat["intent"].get("type"):
+            chat_history += f"Q: Query: {chat['intent']['query']}\n"
+            chat_history += f"A: Improved Query: {chat['intent']['inferred-queries'][0]}\n"
+
+    with timer("Improve the original user query", logger):
+        if send_status_func:
+            await send_status_func("**✍🏽 Enhancing the Painting Prompt**")
+        improved_image_prompt = await generate_better_image_prompt(
+            message,
+            chat_history,
+            location_data=location_data,
+            note_references=references,
+            online_results=online_results,
+            model_type=text_to_image_config.model_type,
         )
-            with timer("Generate image with OpenAI", logger):
+
     if send_status_func:
         await send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}")
 
+    if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
+        with timer("Generate image with OpenAI", logger):
+            if text_to_image_config.api_key:
+                api_key = text_to_image_config.api_key
+            elif text_to_image_config.openai_config:
+                api_key = text_to_image_config.openai_config.api_key
+            elif state.openai_client:
+                api_key = state.openai_client.api_key
+            auth_header = {"Authorization": f"Bearer {api_key}"} if api_key else {}
+            try:
                 response = state.openai_client.images.generate(
-                    prompt=improved_image_prompt, model=text2image_model, response_format="b64_json"
+                    prompt=improved_image_prompt,
+                    model=text2image_model,
+                    response_format="b64_json",
+                    extra_headers=auth_header,
                 )
                 image = response.data[0].b64_json
-
-            with timer("Convert image to webp", logger):
-                # Convert png to webp for faster loading
                 decoded_image = base64.b64decode(image)
-                image_io = io.BytesIO(decoded_image)
-                png_image = Image.open(image_io)
-                webp_image_io = io.BytesIO()
-                png_image.save(webp_image_io, "WEBP")
-                webp_image_bytes = webp_image_io.getvalue()
-                webp_image_io.close()
-                image_io.close()
+            except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
+                if "content_policy_violation" in e.message:
+                    logger.error(f"Image Generation blocked by OpenAI: {e}")
+                    status_code = e.status_code  # type: ignore
+                    message = f"Image generation blocked by OpenAI: {e.message}"  # type: ignore
+                    return image_url or image, status_code, message, intent_type.value
+                else:
+                    logger.error(f"Image Generation failed with {e}", exc_info=True)
+                    message = f"Image generation failed with OpenAI error: {e.message}"  # type: ignore
+                    status_code = e.status_code  # type: ignore
+                    return image_url or image, status_code, message, intent_type.value
 
-            with timer("Upload image to S3", logger):
-                image_url = upload_image(webp_image_bytes, user.uuid)
-            if image_url:
-                intent_type = ImageIntentType.TEXT_TO_IMAGE2
-            else:
-                intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
-                image = base64.b64encode(webp_image_bytes).decode("utf-8")
-
-            return image_url or image, status_code, improved_image_prompt, intent_type.value
-        except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
-            if "content_policy_violation" in e.message:
-                logger.error(f"Image Generation blocked by OpenAI: {e}")
-                status_code = e.status_code  # type: ignore
-                message = f"Image generation blocked by OpenAI: {e.message}"  # type: ignore
-                return image_url or image, status_code, message, intent_type.value
-            else:
+    elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.STABILITYAI:
+        with timer("Generate image with Stability AI", logger):
+            try:
+                response = requests.post(
+                    f"https://api.stability.ai/v2beta/stable-image/generate/sd3",
+                    headers={"authorization": f"Bearer {text_to_image_config.api_key}", "accept": "image/*"},
+                    files={"none": ""},
+                    data={
+                        "prompt": improved_image_prompt,
+                        "model": text2image_model,
+                        "mode": "text-to-image",
+                        "output_format": "png",
+                        "seed": 1032622926,
+                        "aspect_ratio": "1:1",
+                    },
+                )
+                decoded_image = response.content
+            except requests.RequestException as e:
                 logger.error(f"Image Generation failed with {e}", exc_info=True)
-                message = f"Image generation failed with OpenAI error: {e.message}"  # type: ignore
+                message = f"Image generation failed with Stability AI error: {e}"
                 status_code = e.status_code  # type: ignore
                 return image_url or image, status_code, message, intent_type.value
-    return image_url or image, status_code, response, intent_type.value
+
+    with timer("Convert image to webp", logger):
+        # Convert png to webp for faster loading
+        image_io = io.BytesIO(decoded_image)
+        png_image = Image.open(image_io)
+        webp_image_io = io.BytesIO()
+        png_image.save(webp_image_io, "WEBP")
+        webp_image_bytes = webp_image_io.getvalue()
+        webp_image_io.close()
+        image_io.close()
+
+    with timer("Upload image to S3", logger):
+        image_url = upload_image(webp_image_bytes, user.uuid)
+    if image_url:
+        intent_type = ImageIntentType.TEXT_TO_IMAGE2
+    else:
+        intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
+        image = base64.b64encode(webp_image_bytes).decode("utf-8")
+
+    return image_url or image, status_code, improved_image_prompt, intent_type.value
 
 
 class ApiUserRateLimiter:
@@ -249,6 +249,12 @@ def config_page(request: Request):
     current_search_model_option = adapters.get_user_search_model_or_default(user)
 
+    selected_paint_model_config = ConversationAdapters.get_user_text_to_image_model_config(user)
+    paint_model_options = ConversationAdapters.get_text_to_image_model_options().all()
+    all_paint_model_options = list()
+    for paint_model in paint_model_options:
+        all_paint_model_options.append({"model_name": paint_model.model_name, "id": paint_model.id})
+
     notion_oauth_url = get_notion_auth_url(user)
 
     eleven_labs_enabled = is_eleven_labs_enabled()
@@ -271,10 +277,12 @@ def config_page(request: Request):
             "anonymous_mode": state.anonymous_mode,
             "username": user.username,
             "given_name": given_name,
-            "conversation_options": all_conversation_options,
             "search_model_options": all_search_model_options,
             "selected_search_model_config": current_search_model_option.id,
+            "conversation_options": all_conversation_options,
             "selected_conversation_config": selected_conversation_config.id if selected_conversation_config else None,
+            "paint_model_options": all_paint_model_options,
+            "selected_paint_model_config": selected_paint_model_config.id if selected_paint_model_config else None,
             "user_photo": user_picture,
             "billing_enabled": state.billing_enabled,
             "subscription_state": user_subscription_state,