Enable using Stable Diffusion 3 for Image Generation via API (#830)

- Support Stable Diffusion 3 via API
  Server Admin needs to setup model similar to DALLE-3 via Django Admin Panel
- Use shorter prompt generator to prompt SD3 to create better images
- Allow users to set paint model to use from web client config page
This commit is contained in:
Debanjum 2024-07-02 17:28:50 +05:30 committed by GitHub
commit 826c3dc9cc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 375 additions and 70 deletions

View file

@ -48,6 +48,7 @@ from khoj.database.models import (
UserConversationConfig,
UserRequests,
UserSearchModelConfig,
UserTextToImageModelConfig,
UserVoiceModelConfig,
VoiceModelOption,
)
@ -905,6 +906,34 @@ class ConversationAdapters:
async def aget_text_to_image_model_config():
return await TextToImageModelConfig.objects.filter().afirst()
@staticmethod
def get_text_to_image_model_options():
return TextToImageModelConfig.objects.all()
@staticmethod
def get_user_text_to_image_model_config(user: KhojUser):
config = UserTextToImageModelConfig.objects.filter(user=user).first()
if not config:
return None
return config.setting
@staticmethod
async def aget_user_text_to_image_model(user: KhojUser):
config = await UserTextToImageModelConfig.objects.filter(user=user).prefetch_related("setting").afirst()
if not config:
return None
return config.setting
@staticmethod
async def aset_user_text_to_image_model(user: KhojUser, text_to_image_model_config_id: int):
config = await TextToImageModelConfig.objects.filter(id=text_to_image_model_config_id).afirst()
if not config:
return None
new_config, _ = await UserTextToImageModelConfig.objects.aupdate_or_create(
user=user, defaults={"setting": config}
)
return new_config
class FileObjectAdapters:
@staticmethod

View file

@ -96,7 +96,6 @@ admin.site.register(SpeechToTextModelOptions)
admin.site.register(SearchModelConfig)
admin.site.register(ReflectiveQuestion)
admin.site.register(UserSearchModelConfig)
admin.site.register(TextToImageModelConfig)
admin.site.register(ClientApplication)
admin.site.register(GithubConfig)
admin.site.register(NotionConfig)
@ -153,6 +152,16 @@ class ChatModelOptionsAdmin(admin.ModelAdmin):
search_fields = ("id", "chat_model", "model_type")
@admin.register(TextToImageModelConfig)
class TextToImageModelOptionsAdmin(admin.ModelAdmin):
list_display = (
"id",
"model_name",
"model_type",
)
search_fields = ("id", "model_name", "model_type")
@admin.register(OpenAIProcessorConversationConfig)
class OpenAIProcessorConversationConfigAdmin(admin.ModelAdmin):
list_display = (

View file

@ -0,0 +1,58 @@
# Generated by Django 4.2.11 on 2024-06-26 03:27
import django.db.models.deletion
from django.conf import settings
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("database", "0048_voicemodeloption_uservoicemodelconfig"),
]
operations = [
migrations.AddField(
model_name="texttoimagemodelconfig",
name="api_key",
field=models.CharField(blank=True, default=None, max_length=200, null=True),
),
migrations.AddField(
model_name="texttoimagemodelconfig",
name="openai_config",
field=models.ForeignKey(
blank=True,
default=None,
null=True,
on_delete=django.db.models.deletion.CASCADE,
to="database.openaiprocessorconversationconfig",
),
),
migrations.AlterField(
model_name="texttoimagemodelconfig",
name="model_type",
field=models.CharField(
choices=[("openai", "Openai"), ("stability-ai", "Stabilityai")], default="openai", max_length=200
),
),
migrations.CreateModel(
name="UserTextToImageModelConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
(
"setting",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to="database.texttoimagemodelconfig"
),
),
(
"user",
models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL),
),
],
options={
"abstract": False,
},
),
]

View file

@ -235,9 +235,37 @@ class SearchModelConfig(BaseModel):
class TextToImageModelConfig(BaseModel):
class ModelType(models.TextChoices):
OPENAI = "openai"
STABILITYAI = "stability-ai"
model_name = models.CharField(max_length=200, default="dall-e-3")
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
openai_config = models.ForeignKey(
OpenAIProcessorConversationConfig, on_delete=models.CASCADE, default=None, null=True, blank=True
)
def clean(self):
# Custom validation logic
error = {}
if self.model_type == self.ModelType.OPENAI:
if self.api_key and self.openai_config:
error[
"api_key"
] = "Both API key and OpenAI config cannot be set for OpenAI models. Please set only one of them."
error[
"openai_config"
] = "Both API key and OpenAI config cannot be set for OpenAI models. Please set only one of them."
if self.model_type != self.ModelType.OPENAI:
if not self.api_key:
error["api_key"] = "The API key field must be set for non OpenAI models."
if self.openai_config:
error["openai_config"] = "OpenAI config cannot be set for non OpenAI models."
if error:
raise ValidationError(error)
def save(self, *args, **kwargs):
self.clean()
super().save(*args, **kwargs)
class SpeechToTextModelOptions(BaseModel):
@ -264,6 +292,11 @@ class UserSearchModelConfig(BaseModel):
setting = models.ForeignKey(SearchModelConfig, on_delete=models.CASCADE)
class UserTextToImageModelConfig(BaseModel):
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
setting = models.ForeignKey(TextToImageModelConfig, on_delete=models.CASCADE)
class Conversation(BaseModel):
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
conversation_log = models.JSONField(default=dict)

View file

@ -332,6 +332,7 @@
margin: 20px;
}
select#paint-models,
select#search-models,
select#voice-models,
select#chat-models {

View file

@ -192,11 +192,37 @@
</div>
<div class="card-action-row">
{% if (not billing_enabled) or (subscription_state != 'unsubscribed' and subscription_state != 'expired') %}
<button id="save-model" class="card-button happy" onclick="updateChatModel()">
<button id="save-chat-model" class="card-button happy" onclick="updateChatModel()">
Save
</button>
{% else %}
<button id="save-model" class="card-button" disabled>
<button id="save-chat-model" class="card-button" disabled>
Subscribe to use different models
</button>
{% endif %}
</div>
</div>
<div class="card">
<div class="card-title-row">
<img class="card-icon" src="/static/assets/icons/chat.svg" alt="Chat">
<h3 class="card-title">
<span>Paint</span>
</h3>
</div>
<div class="card-description-row">
<select id="paint-models">
{% for option in paint_model_options %}
<option value="{{ option.id }}" {% if option.id == selected_paint_model_config %}selected{% endif %}>{{ option.model_name }}</option>
{% endfor %}
</select>
</div>
<div class="card-action-row">
{% if (not billing_enabled) or (subscription_state != 'unsubscribed' and subscription_state != 'expired') %}
<button id="save-paint-model" class="card-button happy" onclick="updatePaintModel()">
Save
</button>
{% else %}
<button id="save-paint-model" class="card-button" disabled>
Subscribe to use different models
</button>
{% endif %}
@ -425,7 +451,7 @@
function updateChatModel() {
const chatModel = document.getElementById("chat-models").value;
const saveModelButton = document.getElementById("save-model");
const saveModelButton = document.getElementById("save-chat-model");
saveModelButton.disabled = true;
saveModelButton.innerHTML = "Saving...";
@ -491,6 +517,38 @@
})
};
function updatePaintModel() {
const paintModel = document.getElementById("paint-models").value;
const saveModelButton = document.getElementById("save-paint-model");
saveModelButton.disabled = true;
saveModelButton.innerHTML = "Saving...";
fetch('/api/config/data/paint/model?id=' + paintModel, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
}
})
.then(response => response.json())
.then(data => {
if (data.status == "ok") {
saveModelButton.innerHTML = "Save";
saveModelButton.disabled = false;
let notificationBanner = document.getElementById("notification-banner");
notificationBanner.innerHTML = "Paint model has been updated!";
notificationBanner.style.display = "block";
setTimeout(function() {
notificationBanner.style.display = "none";
}, 5000);
} else {
saveModelButton.innerHTML = "Error";
saveModelButton.disabled = false;
}
})
};
function clearContentType(content_source) {
fetch('/api/config/data/content-source/' + content_source, {
method: 'DELETE',

View file

@ -121,7 +121,7 @@ User's Notes:
## Image Generation
## --
image_generation_improve_prompt = PromptTemplate.from_template(
image_generation_improve_prompt_dalle = PromptTemplate.from_template(
"""
You are a talented creator. Generate a detailed prompt to generate an image based on the following description. Update the query below to improve the image generation. Add additional context to the query to improve the image generation. Make sure to retain any important information originally from the query. You are provided with the following information to help you generate the prompt:
@ -143,6 +143,35 @@ Remember, now you are generating a prompt to improve the image generation. Add a
Improved Query:"""
)
image_generation_improve_prompt_sd = PromptTemplate.from_template(
"""
You are a talented creator. Write 2-5 sentences with precise image composition, position details to create an image.
Use the provided context below to add specific, fine details to the image composition.
Retain any important information and follow any instructions from the original prompt.
Put any text to be rendered in the image within double quotes in your improved prompt.
You are provided with the following context to help enhance the original prompt:
Today's Date: {current_date}
User's Location: {location}
User's Notes:
{references}
Online References:
{online_results}
Conversation Log:
{chat_history}
Original Prompt: "{query}"
Now create an improved prompt using the context provided above to generate an image.
Retain any important information and follow any instructions from the original prompt.
Use the additional context from the user's notes, online references and conversation log to improve the image generation.
Improved Prompt:"""
)
## Online Search Conversation
## --
online_search_conversation = PromptTemplate.from_template(

View file

@ -341,6 +341,35 @@ async def update_search_model(
return {"status": "ok"}
@api_config.post("/data/paint/model", status_code=200)
@requires(["authenticated"])
async def update_paint_model(
request: Request,
id: str,
client: Optional[str] = None,
):
user = request.user.object
subscribed = has_required_scope(request, ["premium"])
if not subscribed:
raise HTTPException(status_code=403, detail="User is not subscribed to premium")
new_config = await ConversationAdapters.aset_user_text_to_image_model(user, int(id))
update_telemetry_state(
request=request,
telemetry_type="api",
api="set_paint_model",
client=client,
metadata={"paint_model": new_config.setting.model_name},
)
if new_config is None:
return {"status": "error", "message": "Model not found"}
return {"status": "ok"}
@api_config.get("/index/size", response_model=Dict[str, int])
@requires(["authenticated"])
async def get_indexed_data_size(request: Request, common: CommonQueryParams):

View file

@ -453,12 +453,14 @@ async def generate_better_image_prompt(
location_data: LocationData,
note_references: List[Dict[str, Any]],
online_results: Optional[dict] = None,
model_type: Optional[str] = None,
) -> str:
"""
Generate a better image prompt from the given query
"""
today_date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d")
model_type = model_type or TextToImageModelConfig.ModelType.OPENAI
if location_data:
location = f"{location_data.city}, {location_data.region}, {location_data.country}"
@ -477,7 +479,17 @@ async def generate_better_image_prompt(
elif online_results[result].get("webpages"):
simplified_online_results[result] = online_results[result]["webpages"]
image_prompt = prompts.image_generation_improve_prompt.format(
if model_type == TextToImageModelConfig.ModelType.OPENAI:
image_prompt = prompts.image_generation_improve_prompt_dalle.format(
query=q,
chat_history=conversation_history,
location=location_prompt,
current_date=today_date,
references=user_references,
online_results=simplified_online_results,
)
elif model_type == TextToImageModelConfig.ModelType.STABILITYAI:
image_prompt = prompts.image_generation_improve_prompt_sd.format(
query=q,
chat_history=conversation_history,
location=location_prompt,
@ -490,8 +502,11 @@ async def generate_better_image_prompt(
with timer("Chat actor: Generate contextual image prompt", logger):
response = await send_message_to_model_wrapper(image_prompt, chat_model_option=summarizer_model)
response = response.strip()
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
response = response[1:-1]
return response.strip()
return response
async def send_message_to_model_wrapper(
@ -747,14 +762,13 @@ async def text_to_image(
image_url = None
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config()
text_to_image_config = await ConversationAdapters.aget_user_text_to_image_model(user)
if not text_to_image_config:
# If the user has not configured a text to image model, return an unsupported on server error
status_code = 501
message = "Failed to generate image. Setup image generation on the server."
return image_url or image, status_code, message, intent_type.value
elif state.openai_client and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
logger.info("Generating image with OpenAI")
text2image_model = text_to_image_config.model_name
chat_history = ""
for chat in conversation_log.get("chat", [])[-4:]:
@ -764,7 +778,7 @@ async def text_to_image(
elif chat["by"] == "khoj" and "text-to-image" in chat["intent"].get("type"):
chat_history += f"Q: Query: {chat['intent']['query']}\n"
chat_history += f"A: Improved Query: {chat['intent']['inferred-queries'][0]}\n"
try:
with timer("Improve the original user query", logger):
if send_status_func:
await send_status_func("**✍🏽 Enhancing the Painting Prompt**")
@ -774,18 +788,67 @@ async def text_to_image(
location_data=location_data,
note_references=references,
online_results=online_results,
model_type=text_to_image_config.model_type,
)
with timer("Generate image with OpenAI", logger):
if send_status_func:
await send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}")
if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
with timer("Generate image with OpenAI", logger):
if text_to_image_config.api_key:
api_key = text_to_image_config.api_key
elif text_to_image_config.openai_config:
api_key = text_to_image_config.openai_config.api_key
elif state.openai_client:
api_key = state.openai_client.api_key
auth_header = {"Authorization": f"Bearer {api_key}"} if api_key else {}
try:
response = state.openai_client.images.generate(
prompt=improved_image_prompt, model=text2image_model, response_format="b64_json"
prompt=improved_image_prompt,
model=text2image_model,
response_format="b64_json",
extra_headers=auth_header,
)
image = response.data[0].b64_json
decoded_image = base64.b64decode(image)
except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
if "content_policy_violation" in e.message:
logger.error(f"Image Generation blocked by OpenAI: {e}")
status_code = e.status_code # type: ignore
message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore
return image_url or image, status_code, message, intent_type.value
else:
logger.error(f"Image Generation failed with {e}", exc_info=True)
message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore
status_code = e.status_code # type: ignore
return image_url or image, status_code, message, intent_type.value
elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.STABILITYAI:
with timer("Generate image with Stability AI", logger):
try:
response = requests.post(
f"https://api.stability.ai/v2beta/stable-image/generate/sd3",
headers={"authorization": f"Bearer {text_to_image_config.api_key}", "accept": "image/*"},
files={"none": ""},
data={
"prompt": improved_image_prompt,
"model": text2image_model,
"mode": "text-to-image",
"output_format": "png",
"seed": 1032622926,
"aspect_ratio": "1:1",
},
)
decoded_image = response.content
except requests.RequestException as e:
logger.error(f"Image Generation failed with {e}", exc_info=True)
message = f"Image generation failed with Stability AI error: {e}"
status_code = e.status_code # type: ignore
return image_url or image, status_code, message, intent_type.value
with timer("Convert image to webp", logger):
# Convert png to webp for faster loading
decoded_image = base64.b64decode(image)
image_io = io.BytesIO(decoded_image)
png_image = Image.open(image_io)
webp_image_io = io.BytesIO()
@ -803,18 +866,6 @@ async def text_to_image(
image = base64.b64encode(webp_image_bytes).decode("utf-8")
return image_url or image, status_code, improved_image_prompt, intent_type.value
except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
if "content_policy_violation" in e.message:
logger.error(f"Image Generation blocked by OpenAI: {e}")
status_code = e.status_code # type: ignore
message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore
return image_url or image, status_code, message, intent_type.value
else:
logger.error(f"Image Generation failed with {e}", exc_info=True)
message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore
status_code = e.status_code # type: ignore
return image_url or image, status_code, message, intent_type.value
return image_url or image, status_code, response, intent_type.value
class ApiUserRateLimiter:

View file

@ -249,6 +249,12 @@ def config_page(request: Request):
current_search_model_option = adapters.get_user_search_model_or_default(user)
selected_paint_model_config = ConversationAdapters.get_user_text_to_image_model_config(user)
paint_model_options = ConversationAdapters.get_text_to_image_model_options().all()
all_paint_model_options = list()
for paint_model in paint_model_options:
all_paint_model_options.append({"model_name": paint_model.model_name, "id": paint_model.id})
notion_oauth_url = get_notion_auth_url(user)
eleven_labs_enabled = is_eleven_labs_enabled()
@ -271,10 +277,12 @@ def config_page(request: Request):
"anonymous_mode": state.anonymous_mode,
"username": user.username,
"given_name": given_name,
"conversation_options": all_conversation_options,
"search_model_options": all_search_model_options,
"selected_search_model_config": current_search_model_option.id,
"conversation_options": all_conversation_options,
"selected_conversation_config": selected_conversation_config.id if selected_conversation_config else None,
"paint_model_options": all_paint_model_options,
"selected_paint_model_config": selected_paint_model_config.id if selected_paint_model_config else None,
"user_photo": user_picture,
"billing_enabled": state.billing_enabled,
"subscription_state": user_subscription_state,