From eda33e092fd66adf33666569f31b2f1a863fe723 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Thu, 25 Apr 2024 19:12:27 +0530 Subject: [PATCH] Enable using Stable Diffusion 3 for Image Generation via API --- ...texttoimagemodelconfig_api_key_and_more.py | 24 ++++++++++++ src/khoj/database/models/__init__.py | 2 + src/khoj/routers/helpers.py | 37 ++++++++++++++----- 3 files changed, 54 insertions(+), 9 deletions(-) create mode 100644 src/khoj/database/migrations/0048_texttoimagemodelconfig_api_key_and_more.py diff --git a/src/khoj/database/migrations/0048_texttoimagemodelconfig_api_key_and_more.py b/src/khoj/database/migrations/0048_texttoimagemodelconfig_api_key_and_more.py new file mode 100644 index 00000000..069b16ca --- /dev/null +++ b/src/khoj/database/migrations/0048_texttoimagemodelconfig_api_key_and_more.py @@ -0,0 +1,24 @@ +# Generated by Django 4.2.11 on 2024-06-20 19:02 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0047_alter_entry_file_type"), + ] + + operations = [ + migrations.AddField( + model_name="texttoimagemodelconfig", + name="api_key", + field=models.CharField(blank=True, default=None, max_length=200, null=True), + ), + migrations.AlterField( + model_name="texttoimagemodelconfig", + name="model_type", + field=models.CharField( + choices=[("openai", "Openai"), ("stability-ai", "Stabilityai")], default="openai", max_length=200 + ), + ), + ] diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 00a594ca..89f9802b 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -234,9 +234,11 @@ class SearchModelConfig(BaseModel): class TextToImageModelConfig(BaseModel): class ModelType(models.TextChoices): OPENAI = "openai" + STABILITYAI = "stability-ai" model_name = models.CharField(max_length=200, default="dall-e-3") model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI) + api_key = models.CharField(max_length=200, default=None, null=True, blank=True) class SpeechToTextModelOptions(BaseModel): diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 90bfd8c6..8cc40f80 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -753,7 +753,7 @@ async def text_to_image( status_code = 501 message = "Failed to generate image. Setup image generation on the server." return image_url or image, status_code, message, intent_type.value - elif state.openai_client and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI: + elif state.openai_client: logger.info("Generating image with OpenAI") text2image_model = text_to_image_config.model_name chat_history = "" @@ -775,17 +775,36 @@ async def text_to_image( note_references=references, online_results=online_results, ) - with timer("Generate image with OpenAI", logger): - if send_status_func: - await send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}") - response = state.openai_client.images.generate( - prompt=improved_image_prompt, model=text2image_model, response_format="b64_json" - ) - image = response.data[0].b64_json + if send_status_func: + await send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}") + + if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI: + with timer("Generate image with OpenAI", logger): + response = state.openai_client.images.generate( + prompt=improved_image_prompt, model=text2image_model, response_format="b64_json" + ) + image = response.data[0].b64_json + decoded_image = base64.b64decode(image) + + elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.STABILITYAI: + with timer("Generate image with Stability AI", logger): + response = requests.post( + f"https://api.stability.ai/v2beta/stable-image/generate/sd3", + headers={"authorization": f"Bearer {text_to_image_config.api_key}", "accept": "image/*"}, + files={"none": ""}, + data={ + "prompt": improved_image_prompt, + "model": text2image_model, + "mode": "text-to-image", + "output_format": "png", + "seed": 1032622926, + "aspect_ratio": "1:1", + }, + ) + decoded_image = response.content with timer("Convert image to webp", logger): # Convert png to webp for faster loading - decoded_image = base64.b64decode(image) image_io = io.BytesIO(decoded_image) png_image = Image.open(image_io) webp_image_io = io.BytesIO()