Support using image generation models like Flux via Replicate

Enables using any image generation model on Replicate's Predictions
API endpoints.

The server admin just needs to add text-to-image model on the
server/admin panel in organization/model_name format and input their
Replicate API key with it

Create db migration (including merge)
This commit is contained in:
Debanjum Singh Solanky 2024-09-12 00:29:04 -07:00
parent 1d512b4986
commit 1b82aea753
4 changed files with 94 additions and 6 deletions

View file

@ -0,0 +1,21 @@
# Generated by Django 5.0.7 on 2024-09-12 05:43
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("database", "0060_merge_20240905_1828"),
]
operations = [
migrations.AlterField(
model_name="texttoimagemodelconfig",
name="model_type",
field=models.CharField(
choices=[("openai", "Openai"), ("stability-ai", "Stabilityai"), ("replicate", "Replicate")],
default="openai",
max_length=200,
),
),
]

View file

@ -0,0 +1,14 @@
# Generated by Django 5.0.8 on 2024-09-13 02:22
from typing import List
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
("database", "0061_alter_chatmodeloptions_model_type"),
("database", "0061_alter_texttoimagemodelconfig_model_type"),
]
operations: List[str] = []

View file

@ -280,6 +280,7 @@ class TextToImageModelConfig(BaseModel):
class ModelType(models.TextChoices):
OPENAI = "openai"
STABILITYAI = "stability-ai"
REPLICATE = "replicate"
model_name = models.CharField(max_length=200, default="dall-e-3")
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)

View file

@ -7,6 +7,7 @@ import logging
import math
import os
import re
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timedelta, timezone
from enum import Enum
@ -568,7 +569,7 @@ async def generate_better_image_prompt(
references=user_references,
online_results=simplified_online_results,
)
elif model_type == TextToImageModelConfig.ModelType.STABILITYAI:
elif model_type in [TextToImageModelConfig.ModelType.STABILITYAI, TextToImageModelConfig.ModelType.REPLICATE]:
image_prompt = prompts.image_generation_improve_prompt_sd.format(
query=q,
chat_history=conversation_history,
@ -991,7 +992,8 @@ async def text_to_image(
extra_headers=auth_header,
)
image = response.data[0].b64_json
decoded_image = base64.b64decode(image)
# Decode base64 png and convert it to webp for faster loading
webp_image_bytes = convert_image_to_webp(base64.b64decode(image))
except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
if "content_policy_violation" in e.message:
logger.error(f"Image Generation blocked by OpenAI: {e}")
@ -1021,7 +1023,8 @@ async def text_to_image(
"aspect_ratio": "1:1",
},
)
decoded_image = response.content
# Convert png to webp for faster loading
webp_image_bytes = convert_image_to_webp(response.content)
except requests.RequestException as e:
logger.error(f"Image Generation failed with {e}", exc_info=True)
message = f"Image generation failed with Stability AI error: {e}"
@ -1029,9 +1032,58 @@ async def text_to_image(
yield image_url or image, status_code, message, intent_type.value
return
with timer("Convert image to webp", logger):
# Convert png to webp for faster loading
webp_image_bytes = convert_image_to_webp(decoded_image)
elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.REPLICATE:
with timer("Generate image using Replicate", logger):
try:
# Create image generation task on Replicate
create_prediction_url = f"https://api.replicate.com/v1/models/{text2image_model}/predictions"
headers = {
"Authorization": f"Bearer {text_to_image_config.api_key}",
"Content-Type": "application/json",
}
json = {
"input": {
"prompt": improved_image_prompt,
"num_outputs": 1,
"aspect_ratio": "1:1",
"output_format": "webp",
"output_quality": 100,
}
}
create_prediction = requests.post(create_prediction_url, headers=headers, json=json).json()
# Get status of image generation task
get_prediction_url = create_prediction["urls"]["get"]
get_prediction = requests.get(get_prediction_url, headers=headers).json()
status = get_prediction["status"]
retry_count = 1
# Poll the image generation task for completion status
while status not in ["succeeded", "failed", "canceled"] and retry_count < 20:
time.sleep(2)
get_prediction = requests.get(get_prediction_url, headers=headers).json()
status = get_prediction["status"]
retry_count += 1
# Raise exception if the image generation task fails
if status != "succeeded":
if retry_count >= 10:
raise requests.RequestException("Image generation timed out")
raise requests.RequestException(f"Image generation failed with status: {status}")
# Get the generated image
image_url = (
get_prediction["output"][0]
if isinstance(get_prediction["output"], list)
else get_prediction["output"]
)
webp_image_bytes = io.BytesIO(requests.get(image_url).content).getvalue()
except requests.RequestException as e:
logger.error(f"Image Generation failed with {e}", exc_info=True)
message = f"Image generation for {text2image_model} failed with Replicate API error: {e}"
status_code = 500
yield image_url or image, status_code, message, intent_type.value
return
with timer("Upload image to S3", logger):
image_url = upload_image(webp_image_bytes, user.uuid)