mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-24 07:55:07 +01:00
Enable using Stable Diffusion 3 for Image Generation via API
This commit is contained in:
parent
d6fe5d9a63
commit
eda33e092f
3 changed files with 54 additions and 9 deletions
|
@@ -0,0 +1,24 @@
|
||||||
|
# Generated by Django 4.2.11 on 2024-06-20 19:02
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("database", "0047_alter_entry_file_type"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="texttoimagemodelconfig",
|
||||||
|
name="api_key",
|
||||||
|
field=models.CharField(blank=True, default=None, max_length=200, null=True),
|
||||||
|
),
|
||||||
|
migrations.AlterField(
|
||||||
|
model_name="texttoimagemodelconfig",
|
||||||
|
name="model_type",
|
||||||
|
field=models.CharField(
|
||||||
|
choices=[("openai", "Openai"), ("stability-ai", "Stabilityai")], default="openai", max_length=200
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
|
@@ -234,9 +234,11 @@ class SearchModelConfig(BaseModel):
|
||||||
class TextToImageModelConfig(BaseModel):
|
class TextToImageModelConfig(BaseModel):
|
||||||
class ModelType(models.TextChoices):
|
class ModelType(models.TextChoices):
|
||||||
OPENAI = "openai"
|
OPENAI = "openai"
|
||||||
|
STABILITYAI = "stability-ai"
|
||||||
|
|
||||||
model_name = models.CharField(max_length=200, default="dall-e-3")
|
model_name = models.CharField(max_length=200, default="dall-e-3")
|
||||||
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
|
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
|
||||||
|
api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||||
|
|
||||||
|
|
||||||
class SpeechToTextModelOptions(BaseModel):
|
class SpeechToTextModelOptions(BaseModel):
|
||||||
|
|
|
@@ -753,7 +753,7 @@ async def text_to_image(
|
||||||
status_code = 501
|
status_code = 501
|
||||||
message = "Failed to generate image. Setup image generation on the server."
|
message = "Failed to generate image. Setup image generation on the server."
|
||||||
return image_url or image, status_code, message, intent_type.value
|
return image_url or image, status_code, message, intent_type.value
|
||||||
elif state.openai_client and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
|
elif state.openai_client:
|
||||||
logger.info("Generating image with OpenAI")
|
logger.info("Generating image with OpenAI")
|
||||||
text2image_model = text_to_image_config.model_name
|
text2image_model = text_to_image_config.model_name
|
||||||
chat_history = ""
|
chat_history = ""
|
||||||
|
@@ -775,17 +775,36 @@ async def text_to_image(
|
||||||
note_references=references,
|
note_references=references,
|
||||||
online_results=online_results,
|
online_results=online_results,
|
||||||
)
|
)
|
||||||
with timer("Generate image with OpenAI", logger):
|
if send_status_func:
|
||||||
if send_status_func:
|
await send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}")
|
||||||
await send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}")
|
|
||||||
response = state.openai_client.images.generate(
|
if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
|
||||||
prompt=improved_image_prompt, model=text2image_model, response_format="b64_json"
|
with timer("Generate image with OpenAI", logger):
|
||||||
)
|
response = state.openai_client.images.generate(
|
||||||
image = response.data[0].b64_json
|
prompt=improved_image_prompt, model=text2image_model, response_format="b64_json"
|
||||||
|
)
|
||||||
|
image = response.data[0].b64_json
|
||||||
|
decoded_image = base64.b64decode(image)
|
||||||
|
|
||||||
|
elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.STABILITYAI:
|
||||||
|
with timer("Generate image with Stability AI", logger):
|
||||||
|
response = requests.post(
|
||||||
|
f"https://api.stability.ai/v2beta/stable-image/generate/sd3",
|
||||||
|
headers={"authorization": f"Bearer {text_to_image_config.api_key}", "accept": "image/*"},
|
||||||
|
files={"none": ""},
|
||||||
|
data={
|
||||||
|
"prompt": improved_image_prompt,
|
||||||
|
"model": text2image_model,
|
||||||
|
"mode": "text-to-image",
|
||||||
|
"output_format": "png",
|
||||||
|
"seed": 1032622926,
|
||||||
|
"aspect_ratio": "1:1",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
decoded_image = response.content
|
||||||
|
|
||||||
with timer("Convert image to webp", logger):
|
with timer("Convert image to webp", logger):
|
||||||
# Convert png to webp for faster loading
|
# Convert png to webp for faster loading
|
||||||
decoded_image = base64.b64decode(image)
|
|
||||||
image_io = io.BytesIO(decoded_image)
|
image_io = io.BytesIO(decoded_image)
|
||||||
png_image = Image.open(image_io)
|
png_image = Image.open(image_io)
|
||||||
webp_image_io = io.BytesIO()
|
webp_image_io = io.BytesIO()
|
||||||
|
|
Loading…
Reference in a new issue