mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-27 17:35:07 +01:00
Enable using Stable Diffusion 3 for Image Generation via API
This commit is contained in:
parent
d6fe5d9a63
commit
eda33e092f
3 changed files with 54 additions and 9 deletions
|
@ -0,0 +1,24 @@
|
|||
# Generated by Django 4.2.11 on 2024-06-20 19:02
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0047_alter_entry_file_type"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name="texttoimagemodelconfig",
|
||||
name="api_key",
|
||||
field=models.CharField(blank=True, default=None, max_length=200, null=True),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name="texttoimagemodelconfig",
|
||||
name="model_type",
|
||||
field=models.CharField(
|
||||
choices=[("openai", "Openai"), ("stability-ai", "Stabilityai")], default="openai", max_length=200
|
||||
),
|
||||
),
|
||||
]
|
|
@ -234,9 +234,11 @@ class SearchModelConfig(BaseModel):
|
|||
class TextToImageModelConfig(BaseModel):
|
||||
class ModelType(models.TextChoices):
|
||||
OPENAI = "openai"
|
||||
STABILITYAI = "stability-ai"
|
||||
|
||||
model_name = models.CharField(max_length=200, default="dall-e-3")
|
||||
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
|
||||
api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||
|
||||
|
||||
class SpeechToTextModelOptions(BaseModel):
|
||||
|
|
|
@ -753,7 +753,7 @@ async def text_to_image(
|
|||
status_code = 501
|
||||
message = "Failed to generate image. Setup image generation on the server."
|
||||
return image_url or image, status_code, message, intent_type.value
|
||||
elif state.openai_client and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
|
||||
elif state.openai_client:
|
||||
logger.info("Generating image with OpenAI")
|
||||
text2image_model = text_to_image_config.model_name
|
||||
chat_history = ""
|
||||
|
@ -775,17 +775,36 @@ async def text_to_image(
|
|||
note_references=references,
|
||||
online_results=online_results,
|
||||
)
|
||||
with timer("Generate image with OpenAI", logger):
|
||||
if send_status_func:
|
||||
await send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}")
|
||||
response = state.openai_client.images.generate(
|
||||
prompt=improved_image_prompt, model=text2image_model, response_format="b64_json"
|
||||
)
|
||||
image = response.data[0].b64_json
|
||||
if send_status_func:
|
||||
await send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}")
|
||||
|
||||
if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
|
||||
with timer("Generate image with OpenAI", logger):
|
||||
response = state.openai_client.images.generate(
|
||||
prompt=improved_image_prompt, model=text2image_model, response_format="b64_json"
|
||||
)
|
||||
image = response.data[0].b64_json
|
||||
decoded_image = base64.b64decode(image)
|
||||
|
||||
elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.STABILITYAI:
|
||||
with timer("Generate image with Stability AI", logger):
|
||||
response = requests.post(
|
||||
f"https://api.stability.ai/v2beta/stable-image/generate/sd3",
|
||||
headers={"authorization": f"Bearer {text_to_image_config.api_key}", "accept": "image/*"},
|
||||
files={"none": ""},
|
||||
data={
|
||||
"prompt": improved_image_prompt,
|
||||
"model": text2image_model,
|
||||
"mode": "text-to-image",
|
||||
"output_format": "png",
|
||||
"seed": 1032622926,
|
||||
"aspect_ratio": "1:1",
|
||||
},
|
||||
)
|
||||
decoded_image = response.content
|
||||
|
||||
with timer("Convert image to webp", logger):
|
||||
# Convert png to webp for faster loading
|
||||
decoded_image = base64.b64decode(image)
|
||||
image_io = io.BytesIO(decoded_image)
|
||||
png_image = Image.open(image_io)
|
||||
webp_image_io = io.BytesIO()
|
||||
|
|
Loading…
Reference in a new issue