Do not require OpenAI to generate image as local chat + sd3 works now

Previously the text_to_image helper would only trigger the image
generation flow if OpenAI client was setup. This is not required
anymore as offline chat model + sd3 API works. So remove that check
This commit is contained in:
Debanjum Singh Solanky 2024-06-21 02:08:02 +05:30
parent 2c4bf91a61
commit 1acf969c6e

View file

@ -768,8 +768,7 @@ async def text_to_image(
status_code = 501
message = "Failed to generate image. Setup image generation on the server."
return image_url or image, status_code, message, intent_type.value
elif state.openai_client:
logger.info("Generating image with OpenAI")
text2image_model = text_to_image_config.model_name
chat_history = ""
for chat in conversation_log.get("chat", [])[-4:]:
@ -779,7 +778,7 @@ async def text_to_image(
elif chat["by"] == "khoj" and "text-to-image" in chat["intent"].get("type"):
chat_history += f"Q: Query: {chat['intent']['query']}\n"
chat_history += f"A: Improved Query: {chat['intent']['inferred-queries'][0]}\n"
try:
with timer("Improve the original user query", logger):
if send_status_func:
await send_status_func("**✍🏽 Enhancing the Painting Prompt**")
@ -791,19 +790,33 @@ async def text_to_image(
online_results=online_results,
model_type=text_to_image_config.model_type,
)
if send_status_func:
await send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}")
if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
with timer("Generate image with OpenAI", logger):
try:
response = state.openai_client.images.generate(
prompt=improved_image_prompt, model=text2image_model, response_format="b64_json"
)
image = response.data[0].b64_json
decoded_image = base64.b64decode(image)
except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
if "content_policy_violation" in e.message:
logger.error(f"Image Generation blocked by OpenAI: {e}")
status_code = e.status_code # type: ignore
message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore
return image_url or image, status_code, message, intent_type.value
else:
logger.error(f"Image Generation failed with {e}", exc_info=True)
message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore
status_code = e.status_code # type: ignore
return image_url or image, status_code, message, intent_type.value
elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.STABILITYAI:
with timer("Generate image with Stability AI", logger):
try:
response = requests.post(
f"https://api.stability.ai/v2beta/stable-image/generate/sd3",
headers={"authorization": f"Bearer {text_to_image_config.api_key}", "accept": "image/*"},
@ -818,6 +831,11 @@ async def text_to_image(
},
)
decoded_image = response.content
except requests.RequestException as e:
logger.error(f"Image Generation failed with {e}", exc_info=True)
message = f"Image generation failed with Stability AI error: {e}"
status_code = e.status_code # type: ignore
return image_url or image, status_code, message, intent_type.value
with timer("Convert image to webp", logger):
# Convert png to webp for faster loading
@ -838,18 +856,6 @@ async def text_to_image(
image = base64.b64encode(webp_image_bytes).decode("utf-8")
return image_url or image, status_code, improved_image_prompt, intent_type.value
except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
if "content_policy_violation" in e.message:
logger.error(f"Image Generation blocked by OpenAI: {e}")
status_code = e.status_code # type: ignore
message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore
return image_url or image, status_code, message, intent_type.value
else:
logger.error(f"Image Generation failed with {e}", exc_info=True)
message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore
status_code = e.status_code # type: ignore
return image_url or image, status_code, message, intent_type.value
return image_url or image, status_code, response, intent_type.value
class ApiUserRateLimiter: