Return typed image from get_image_from_url function for readability

This commit is contained in:
Debanjum Singh Solanky 2024-10-24 17:49:37 -07:00
parent 8d588e0765
commit 01d740debd
3 changed files with 19 additions and 7 deletions

View file

@ -141,13 +141,15 @@ def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt=Non
for message in messages: for message in messages:
if isinstance(message.content, list): if isinstance(message.content, list):
content = [] content = []
# Sort the content as preferred if text comes after images # Sort the content. Anthropic models prefer that text comes after images.
message.content.sort(key=lambda x: 0 if x["type"] == "image_url" else 1) message.content.sort(key=lambda x: 0 if x["type"] == "image_url" else 1)
for idx, part in enumerate(message.content): for idx, part in enumerate(message.content):
if part["type"] == "text": if part["type"] == "text":
content.append({"type": "text", "text": part["text"]}) content.append({"type": "text", "text": part["text"]})
elif part["type"] == "image_url": elif part["type"] == "image_url":
b64_image, media_type = get_image_from_url(part["image_url"]["url"], type="b64") image = get_image_from_url(part["image_url"]["url"], type="b64")
# Prefix each image with text block enumerating the image number
# This helps the model reference the image in its response. Recommended by Anthropic
content.extend( content.extend(
[ [
{ {
@ -156,7 +158,7 @@ def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt=Non
}, },
{ {
"type": "image", "type": "image",
"source": {"type": "base64", "media_type": media_type, "data": b64_image}, "source": {"type": "base64", "media_type": image.type, "data": image.content},
}, },
] ]
) )

View file

@ -204,7 +204,7 @@ def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str =
if isinstance(message.content, list): if isinstance(message.content, list):
# Convert image_urls to PIL.Image and place them at beginning of list (better for Gemini) # Convert image_urls to PIL.Image and place them at beginning of list (better for Gemini)
message.content = [ message.content = [
get_image_from_url(item["image_url"]["url"])[0] if item["type"] == "image_url" else item["text"] get_image_from_url(item["image_url"]["url"]).content if item["type"] == "image_url" else item["text"]
for item in sorted(message.content, key=lambda x: 0 if x["type"] == "image_url" else 1) for item in sorted(message.content, key=lambda x: 0 if x["type"] == "image_url" else 1)
] ]
elif isinstance(message.content, str): elif isinstance(message.content, str):

View file

@ -3,6 +3,7 @@ import logging
import math import math
import mimetypes import mimetypes
import queue import queue
from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from io import BytesIO from io import BytesIO
from time import perf_counter from time import perf_counter
@ -317,6 +318,12 @@ def remove_json_codeblock(response: str):
return response.removeprefix("```json").removesuffix("```") return response.removeprefix("```json").removesuffix("```")
@dataclass
class ImageWithType:
    """Image payload paired with its content type, as returned by get_image_from_url."""

    # Image data: a base64-encoded string when "b64" is requested, a PIL image when
    # "pil" is requested, or None when the request for the image fails.
    content: Any
    # Content type of the image, taken from the response header or guessed from the
    # URL, falling back to "image/webp".
    # NOTE(review): get_image_from_url sets this to None when the request fails, so
    # the `str` annotation is narrower than the values actually stored here.
    type: str
def get_image_from_url(image_url: str, type="pil"): def get_image_from_url(image_url: str, type="pil"):
try: try:
response = requests.get(image_url) response = requests.get(image_url)
@ -325,12 +332,15 @@ def get_image_from_url(image_url: str, type="pil"):
# Get content type from response or infer from URL # Get content type from response or infer from URL
content_type = response.headers.get("content-type") or mimetypes.guess_type(image_url)[0] or "image/webp" content_type = response.headers.get("content-type") or mimetypes.guess_type(image_url)[0] or "image/webp"
# Convert image to desired format
if type == "b64": if type == "b64":
return base64.b64encode(response.content).decode("utf-8"), content_type image_data = base64.b64encode(response.content).decode("utf-8")
elif type == "pil": elif type == "pil":
return PIL.Image.open(BytesIO(response.content)), content_type image_data = PIL.Image.open(BytesIO(response.content))
else: else:
raise ValueError(f"Invalid image type: {type}") raise ValueError(f"Invalid image type: {type}")
return ImageWithType(content=image_data, type=content_type)
except requests.exceptions.RequestException as e: except requests.exceptions.RequestException as e:
logger.error(f"Failed to get image from URL {image_url}: {e}") logger.error(f"Failed to get image from URL {image_url}: {e}")
return None, None return ImageWithType(content=None, type=None)