mirror of
https://github.com/khoj-ai/khoj.git
synced 2025-02-17 08:04:21 +00:00
Support using image generation models like Flux via Replicate
Enables using any image generation model on Replicate's Predictions API endpoints. The server admin just needs to add text-to-image model on the server/admin panel in organization/model_name format and input their Replicate API key with it Create db migration (including merge)
This commit is contained in:
parent
1d512b4986
commit
1b82aea753
4 changed files with 94 additions and 6 deletions
|
@ -0,0 +1,21 @@
|
|||
# Generated by Django 5.0.7 on 2024-09-12 05:43
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0060_merge_20240905_1828"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name="texttoimagemodelconfig",
|
||||
name="model_type",
|
||||
field=models.CharField(
|
||||
choices=[("openai", "Openai"), ("stability-ai", "Stabilityai"), ("replicate", "Replicate")],
|
||||
default="openai",
|
||||
max_length=200,
|
||||
),
|
||||
),
|
||||
]
|
14
src/khoj/database/migrations/0062_merge_20240913_0222.py
Normal file
14
src/khoj/database/migrations/0062_merge_20240913_0222.py
Normal file
|
@ -0,0 +1,14 @@
|
|||
# Generated by Django 5.0.8 on 2024-09-13 02:22
|
||||
|
||||
from typing import List
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0061_alter_chatmodeloptions_model_type"),
|
||||
("database", "0061_alter_texttoimagemodelconfig_model_type"),
|
||||
]
|
||||
|
||||
operations: List[str] = []
|
|
@ -280,6 +280,7 @@ class TextToImageModelConfig(BaseModel):
|
|||
class ModelType(models.TextChoices):
|
||||
OPENAI = "openai"
|
||||
STABILITYAI = "stability-ai"
|
||||
REPLICATE = "replicate"
|
||||
|
||||
model_name = models.CharField(max_length=200, default="dall-e-3")
|
||||
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
|
||||
|
|
|
@ -7,6 +7,7 @@ import logging
|
|||
import math
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from enum import Enum
|
||||
|
@ -568,7 +569,7 @@ async def generate_better_image_prompt(
|
|||
references=user_references,
|
||||
online_results=simplified_online_results,
|
||||
)
|
||||
elif model_type == TextToImageModelConfig.ModelType.STABILITYAI:
|
||||
elif model_type in [TextToImageModelConfig.ModelType.STABILITYAI, TextToImageModelConfig.ModelType.REPLICATE]:
|
||||
image_prompt = prompts.image_generation_improve_prompt_sd.format(
|
||||
query=q,
|
||||
chat_history=conversation_history,
|
||||
|
@ -991,7 +992,8 @@ async def text_to_image(
|
|||
extra_headers=auth_header,
|
||||
)
|
||||
image = response.data[0].b64_json
|
||||
decoded_image = base64.b64decode(image)
|
||||
# Decode base64 png and convert it to webp for faster loading
|
||||
webp_image_bytes = convert_image_to_webp(base64.b64decode(image))
|
||||
except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
|
||||
if "content_policy_violation" in e.message:
|
||||
logger.error(f"Image Generation blocked by OpenAI: {e}")
|
||||
|
@ -1021,7 +1023,8 @@ async def text_to_image(
|
|||
"aspect_ratio": "1:1",
|
||||
},
|
||||
)
|
||||
decoded_image = response.content
|
||||
# Convert png to webp for faster loading
|
||||
webp_image_bytes = convert_image_to_webp(response.content)
|
||||
except requests.RequestException as e:
|
||||
logger.error(f"Image Generation failed with {e}", exc_info=True)
|
||||
message = f"Image generation failed with Stability AI error: {e}"
|
||||
|
@ -1029,9 +1032,58 @@ async def text_to_image(
|
|||
yield image_url or image, status_code, message, intent_type.value
|
||||
return
|
||||
|
||||
with timer("Convert image to webp", logger):
|
||||
# Convert png to webp for faster loading
|
||||
webp_image_bytes = convert_image_to_webp(decoded_image)
|
||||
elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.REPLICATE:
|
||||
with timer("Generate image using Replicate", logger):
|
||||
try:
|
||||
# Create image generation task on Replicate
|
||||
create_prediction_url = f"https://api.replicate.com/v1/models/{text2image_model}/predictions"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {text_to_image_config.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
json = {
|
||||
"input": {
|
||||
"prompt": improved_image_prompt,
|
||||
"num_outputs": 1,
|
||||
"aspect_ratio": "1:1",
|
||||
"output_format": "webp",
|
||||
"output_quality": 100,
|
||||
}
|
||||
}
|
||||
create_prediction = requests.post(create_prediction_url, headers=headers, json=json).json()
|
||||
|
||||
# Get status of image generation task
|
||||
get_prediction_url = create_prediction["urls"]["get"]
|
||||
get_prediction = requests.get(get_prediction_url, headers=headers).json()
|
||||
status = get_prediction["status"]
|
||||
retry_count = 1
|
||||
|
||||
# Poll the image generation task for completion status
|
||||
while status not in ["succeeded", "failed", "canceled"] and retry_count < 20:
|
||||
time.sleep(2)
|
||||
get_prediction = requests.get(get_prediction_url, headers=headers).json()
|
||||
status = get_prediction["status"]
|
||||
retry_count += 1
|
||||
|
||||
# Raise exception if the image generation task fails
|
||||
if status != "succeeded":
|
||||
if retry_count >= 10:
|
||||
raise requests.RequestException("Image generation timed out")
|
||||
raise requests.RequestException(f"Image generation failed with status: {status}")
|
||||
|
||||
# Get the generated image
|
||||
image_url = (
|
||||
get_prediction["output"][0]
|
||||
if isinstance(get_prediction["output"], list)
|
||||
else get_prediction["output"]
|
||||
)
|
||||
webp_image_bytes = io.BytesIO(requests.get(image_url).content).getvalue()
|
||||
except requests.RequestException as e:
|
||||
logger.error(f"Image Generation failed with {e}", exc_info=True)
|
||||
message = f"Image generation for {text2image_model} failed with Replicate API error: {e}"
|
||||
status_code = 500
|
||||
yield image_url or image, status_code, message, intent_type.value
|
||||
return
|
||||
|
||||
with timer("Upload image to S3", logger):
|
||||
image_url = upload_image(webp_image_bytes, user.uuid)
|
||||
|
|
Loading…
Add table
Reference in a new issue