Extract image generation code into new image processor for modularity

This commit is contained in:
Debanjum Singh Solanky 2024-09-12 19:13:09 -07:00
parent 84051d7d89
commit 75d3b34452
3 changed files with 214 additions and 186 deletions

View file

@ -0,0 +1,212 @@
import base64
import io
import logging
import time
from typing import Any, Callable, Dict, List, Optional
import openai
import requests
from khoj.database.adapters import ConversationAdapters
from khoj.database.models import KhojUser, TextToImageModelConfig
from khoj.routers.helpers import ChatEvent, generate_better_image_prompt
from khoj.routers.storage import upload_image
from khoj.utils import state
from khoj.utils.helpers import ImageIntentType, convert_image_to_webp, timer
from khoj.utils.rawconfig import LocationData
logger = logging.getLogger(__name__)
async def text_to_image(
message: str,
user: KhojUser,
conversation_log: dict,
location_data: LocationData,
references: List[Dict[str, Any]],
online_results: Dict[str, Any],
subscribed: bool = False,
send_status_func: Optional[Callable] = None,
uploaded_image_url: Optional[str] = None,
):
status_code = 200
image = None
image_url = None
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
text_to_image_config = await ConversationAdapters.aget_user_text_to_image_model(user)
if not text_to_image_config:
# If the user has not configured a text to image model, return an unsupported on server error
status_code = 501
message = "Failed to generate image. Setup image generation on the server."
yield image_url or image, status_code, message, intent_type.value
return
text2image_model = text_to_image_config.model_name
chat_history = ""
for chat in conversation_log.get("chat", [])[-4:]:
if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder"]:
chat_history += f"Q: {chat['intent']['query']}\n"
chat_history += f"A: {chat['message']}\n"
elif chat["by"] == "khoj" and "text-to-image" in chat["intent"].get("type"):
chat_history += f"Q: Prompt: {chat['intent']['query']}\n"
chat_history += f"A: Improved Prompt: {chat['intent']['inferred-queries'][0]}\n"
if send_status_func:
async for event in send_status_func("**Enhancing the Painting Prompt**"):
yield {ChatEvent.STATUS: event}
# Generate a better image prompt
# Use the user's message, chat history, and other context
image_prompt = await generate_better_image_prompt(
message,
chat_history,
location_data=location_data,
note_references=references,
online_results=online_results,
model_type=text_to_image_config.model_type,
subscribed=subscribed,
uploaded_image_url=uploaded_image_url,
)
if send_status_func:
async for event in send_status_func(f"**Painting to Imagine**:\n{image_prompt}"):
yield {ChatEvent.STATUS: event}
# Generate image using the configured model and API
with timer(f"Generate image with {text_to_image_config.model_type}", logger):
try:
if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
webp_image_bytes = generate_image_with_openai(image_prompt, text_to_image_config, text2image_model)
elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.STABILITYAI:
webp_image_bytes = generate_image_with_stability(image_prompt, text_to_image_config, text2image_model)
elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.REPLICATE:
webp_image_bytes = generate_image_with_replicate(image_prompt, text_to_image_config, text2image_model)
except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
if "content_policy_violation" in e.message:
logger.error(f"Image Generation blocked by OpenAI: {e}")
status_code = e.status_code # type: ignore
message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore
yield image_url or image, status_code, message, intent_type.value
return
else:
logger.error(f"Image Generation failed with {e}", exc_info=True)
message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore
status_code = e.status_code # type: ignore
yield image_url or image, status_code, message, intent_type.value
return
except requests.RequestException as e:
logger.error(f"Image Generation failed with {e}", exc_info=True)
message = f"Image generation using {text2image_model} via {text_to_image_config.model_type} failed with error: {e}"
status_code = 502
yield image_url or image, status_code, message, intent_type.value
return
# Decide how to store the generated image
with timer("Upload image to S3", logger):
image_url = upload_image(webp_image_bytes, user.uuid)
if image_url:
intent_type = ImageIntentType.TEXT_TO_IMAGE2
else:
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
image = base64.b64encode(webp_image_bytes).decode("utf-8")
yield image_url or image, status_code, image_prompt, intent_type.value
def generate_image_with_openai(
improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
):
"Generate image using OpenAI API"
# Get the API key from the user's configuration
if text_to_image_config.api_key:
api_key = text_to_image_config.api_key
elif text_to_image_config.openai_config:
api_key = text_to_image_config.openai_config.api_key
elif state.openai_client:
api_key = state.openai_client.api_key
auth_header = {"Authorization": f"Bearer {api_key}"} if api_key else {}
# Generate image using OpenAI API
OPENAI_IMAGE_GEN_STYLE = "vivid"
response = state.openai_client.images.generate(
prompt=improved_image_prompt,
model=text2image_model,
style=OPENAI_IMAGE_GEN_STYLE,
response_format="b64_json",
extra_headers=auth_header,
)
# Extract the base64 image from the response
image = response.data[0].b64_json
# Decode base64 png and convert it to webp for faster loading
return convert_image_to_webp(base64.b64decode(image))
def generate_image_with_stability(
improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
):
"Generate image using Stability AI"
# Call Stability AI API to generate image
response = requests.post(
f"https://api.stability.ai/v2beta/stable-image/generate/sd3",
headers={"authorization": f"Bearer {text_to_image_config.api_key}", "accept": "image/*"},
files={"none": ""},
data={
"prompt": improved_image_prompt,
"model": text2image_model,
"mode": "text-to-image",
"output_format": "png",
"aspect_ratio": "1:1",
},
)
# Convert png to webp for faster loading
return convert_image_to_webp(response.content)
def generate_image_with_replicate(
improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
):
"Generate image using Replicate API"
# Create image generation task on Replicate
replicate_create_prediction_url = f"https://api.replicate.com/v1/models/{text2image_model}/predictions"
headers = {
"Authorization": f"Bearer {text_to_image_config.api_key}",
"Content-Type": "application/json",
}
json = {
"input": {
"prompt": improved_image_prompt,
"num_outputs": 1,
"aspect_ratio": "1:1",
"output_format": "webp",
"output_quality": 100,
}
}
create_prediction = requests.post(replicate_create_prediction_url, headers=headers, json=json).json()
# Get status of image generation task
get_prediction_url = create_prediction["urls"]["get"]
get_prediction = requests.get(get_prediction_url, headers=headers).json()
status = get_prediction["status"]
retry_count = 1
# Poll the image generation task for completion status
while status not in ["succeeded", "failed", "canceled"] and retry_count < 20:
time.sleep(2)
get_prediction = requests.get(get_prediction_url, headers=headers).json()
status = get_prediction["status"]
retry_count += 1
# Raise exception if the image generation task fails
if status != "succeeded":
if retry_count >= 10:
raise requests.RequestException("Image generation timed out")
raise requests.RequestException(f"Image generation failed with status: {status}")
# Get the generated image
image_url = get_prediction["output"][0] if isinstance(get_prediction["output"], list) else get_prediction["output"]
return io.BytesIO(requests.get(image_url).content).getvalue()

View file

@ -26,6 +26,7 @@ from khoj.database.adapters import (
from khoj.database.models import KhojUser
from khoj.processor.conversation.prompts import help_message, no_entries_found
from khoj.processor.conversation.utils import save_to_conversation_log
from khoj.processor.image.generate import text_to_image
from khoj.processor.speech.text_to_speech import generate_text_to_speech
from khoj.processor.tools.online_search import read_webpages, search_online
from khoj.routers.api import extract_references_and_questions
@ -44,7 +45,6 @@ from khoj.routers.helpers import (
is_query_empty,
is_ready_to_chat,
read_chat_stream,
text_to_image,
update_telemetry_state,
validate_conversation_config,
)

View file

@ -1,13 +1,10 @@
import asyncio
import base64
import hashlib
import io
import json
import logging
import math
import os
import re
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timedelta, timezone
from enum import Enum
@ -17,7 +14,6 @@ from typing import (
Annotated,
Any,
AsyncGenerator,
Callable,
Dict,
Iterator,
List,
@ -25,17 +21,15 @@ from typing import (
Tuple,
Union,
)
from urllib.parse import parse_qs, urlencode, urljoin, urlparse
from urllib.parse import parse_qs, urljoin, urlparse
import cron_descriptor
import openai
import pytz
import requests
from apscheduler.job import Job
from apscheduler.triggers.cron import CronTrigger
from asgiref.sync import sync_to_async
from fastapi import Depends, Header, HTTPException, Request, UploadFile
from PIL import Image
from starlette.authentication import has_required_scope
from starlette.requests import URL
@ -94,7 +88,6 @@ from khoj.processor.conversation.utils import (
)
from khoj.processor.speech.text_to_speech import is_eleven_labs_enabled
from khoj.routers.email import is_resend_enabled, send_task_email
from khoj.routers.storage import upload_image
from khoj.routers.twilio import is_twilio_enabled
from khoj.search_type import text_search
from khoj.utils import state
@ -102,8 +95,6 @@ from khoj.utils.config import OfflineChatProcessorModel
from khoj.utils.helpers import (
LRU,
ConversationCommand,
ImageIntentType,
convert_image_to_webp,
is_none_or_empty,
is_valid_url,
log_telemetry,
@ -922,181 +913,6 @@ def generate_chat_response(
return chat_response, metadata
async def text_to_image(
message: str,
user: KhojUser,
conversation_log: dict,
location_data: LocationData,
references: List[Dict[str, Any]],
online_results: Dict[str, Any],
subscribed: bool = False,
send_status_func: Optional[Callable] = None,
uploaded_image_url: Optional[str] = None,
):
status_code = 200
image = None
response = None
image_url = None
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
text_to_image_config = await ConversationAdapters.aget_user_text_to_image_model(user)
if not text_to_image_config:
# If the user has not configured a text to image model, return an unsupported on server error
status_code = 501
message = "Failed to generate image. Setup image generation on the server."
yield image_url or image, status_code, message, intent_type.value
return
text2image_model = text_to_image_config.model_name
chat_history = ""
for chat in conversation_log.get("chat", [])[-4:]:
if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder"]:
chat_history += f"Q: {chat['intent']['query']}\n"
chat_history += f"A: {chat['message']}\n"
elif chat["by"] == "khoj" and "text-to-image" in chat["intent"].get("type"):
chat_history += f"Q: Prompt: {chat['intent']['query']}\n"
chat_history += f"A: Improved Prompt: {chat['intent']['inferred-queries'][0]}\n"
if send_status_func:
async for event in send_status_func("**Enhancing the Painting Prompt**"):
yield {ChatEvent.STATUS: event}
improved_image_prompt = await generate_better_image_prompt(
message,
chat_history,
location_data=location_data,
note_references=references,
online_results=online_results,
model_type=text_to_image_config.model_type,
subscribed=subscribed,
uploaded_image_url=uploaded_image_url,
)
if send_status_func:
async for event in send_status_func(f"**Painting to Imagine**:\n{improved_image_prompt}"):
yield {ChatEvent.STATUS: event}
if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
with timer("Generate image with OpenAI", logger):
if text_to_image_config.api_key:
api_key = text_to_image_config.api_key
elif text_to_image_config.openai_config:
api_key = text_to_image_config.openai_config.api_key
elif state.openai_client:
api_key = state.openai_client.api_key
auth_header = {"Authorization": f"Bearer {api_key}"} if api_key else {}
try:
response = state.openai_client.images.generate(
prompt=improved_image_prompt,
model=text2image_model,
style="vivid",
response_format="b64_json",
extra_headers=auth_header,
)
image = response.data[0].b64_json
# Decode base64 png and convert it to webp for faster loading
webp_image_bytes = convert_image_to_webp(base64.b64decode(image))
except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
if "content_policy_violation" in e.message:
logger.error(f"Image Generation blocked by OpenAI: {e}")
status_code = e.status_code # type: ignore
message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore
yield image_url or image, status_code, message, intent_type.value
return
else:
logger.error(f"Image Generation failed with {e}", exc_info=True)
message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore
status_code = e.status_code # type: ignore
yield image_url or image, status_code, message, intent_type.value
return
elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.STABILITYAI:
with timer("Generate image with Stability AI", logger):
try:
response = requests.post(
f"https://api.stability.ai/v2beta/stable-image/generate/sd3",
headers={"authorization": f"Bearer {text_to_image_config.api_key}", "accept": "image/*"},
files={"none": ""},
data={
"prompt": improved_image_prompt,
"model": text2image_model,
"mode": "text-to-image",
"output_format": "png",
"aspect_ratio": "1:1",
},
)
# Convert png to webp for faster loading
webp_image_bytes = convert_image_to_webp(response.content)
except requests.RequestException as e:
logger.error(f"Image Generation failed with {e}", exc_info=True)
message = f"Image generation failed with Stability AI error: {e}"
status_code = e.status_code # type: ignore
yield image_url or image, status_code, message, intent_type.value
return
elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.REPLICATE:
with timer("Generate image using Replicate", logger):
try:
# Create image generation task on Replicate
create_prediction_url = f"https://api.replicate.com/v1/models/{text2image_model}/predictions"
headers = {
"Authorization": f"Bearer {text_to_image_config.api_key}",
"Content-Type": "application/json",
}
json = {
"input": {
"prompt": improved_image_prompt,
"num_outputs": 1,
"aspect_ratio": "1:1",
"output_format": "webp",
"output_quality": 100,
}
}
create_prediction = requests.post(create_prediction_url, headers=headers, json=json).json()
# Get status of image generation task
get_prediction_url = create_prediction["urls"]["get"]
get_prediction = requests.get(get_prediction_url, headers=headers).json()
status = get_prediction["status"]
retry_count = 1
# Poll the image generation task for completion status
while status not in ["succeeded", "failed", "canceled"] and retry_count < 20:
time.sleep(2)
get_prediction = requests.get(get_prediction_url, headers=headers).json()
status = get_prediction["status"]
retry_count += 1
# Raise exception if the image generation task fails
if status != "succeeded":
if retry_count >= 10:
raise requests.RequestException("Image generation timed out")
raise requests.RequestException(f"Image generation failed with status: {status}")
# Get the generated image
image_url = (
get_prediction["output"][0]
if isinstance(get_prediction["output"], list)
else get_prediction["output"]
)
webp_image_bytes = io.BytesIO(requests.get(image_url).content).getvalue()
except requests.RequestException as e:
logger.error(f"Image Generation failed with {e}", exc_info=True)
message = f"Image generation for {text2image_model} failed with Replicate API error: {e}"
status_code = 500
yield image_url or image, status_code, message, intent_type.value
return
with timer("Upload image to S3", logger):
image_url = upload_image(webp_image_bytes, user.uuid)
if image_url:
intent_type = ImageIntentType.TEXT_TO_IMAGE2
else:
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
image = base64.b64encode(webp_image_bytes).decode("utf-8")
yield image_url or image, status_code, improved_image_prompt, intent_type.value
class ApiUserRateLimiter:
def __init__(self, requests: int, subscribed_requests: int, window: int, slug: str):
self.requests = requests