mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 23:48:56 +01:00
Return typed image from image_with_url function for readability
This commit is contained in:
parent
8d588e0765
commit
01d740debd
3 changed files with 19 additions and 7 deletions
|
@ -141,13 +141,15 @@ def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt=Non
|
||||||
for message in messages:
|
for message in messages:
|
||||||
if isinstance(message.content, list):
|
if isinstance(message.content, list):
|
||||||
content = []
|
content = []
|
||||||
# Sort the content as preferred if text comes after images
|
# Sort the content. Anthropic models prefer that text comes after images.
|
||||||
message.content.sort(key=lambda x: 0 if x["type"] == "image_url" else 1)
|
message.content.sort(key=lambda x: 0 if x["type"] == "image_url" else 1)
|
||||||
for idx, part in enumerate(message.content):
|
for idx, part in enumerate(message.content):
|
||||||
if part["type"] == "text":
|
if part["type"] == "text":
|
||||||
content.append({"type": "text", "text": part["text"]})
|
content.append({"type": "text", "text": part["text"]})
|
||||||
elif part["type"] == "image_url":
|
elif part["type"] == "image_url":
|
||||||
b64_image, media_type = get_image_from_url(part["image_url"]["url"], type="b64")
|
image = get_image_from_url(part["image_url"]["url"], type="b64")
|
||||||
|
# Prefix each image with text block enumerating the image number
|
||||||
|
# This helps the model reference the image in its response. Recommended by Anthropic
|
||||||
content.extend(
|
content.extend(
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
|
@ -156,7 +158,7 @@ def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt=Non
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"type": "image",
|
"type": "image",
|
||||||
"source": {"type": "base64", "media_type": media_type, "data": b64_image},
|
"source": {"type": "base64", "media_type": image.type, "data": image.content},
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
|
@ -204,7 +204,7 @@ def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str =
|
||||||
if isinstance(message.content, list):
|
if isinstance(message.content, list):
|
||||||
# Convert image_urls to PIL.Image and place them at beginning of list (better for Gemini)
|
# Convert image_urls to PIL.Image and place them at beginning of list (better for Gemini)
|
||||||
message.content = [
|
message.content = [
|
||||||
get_image_from_url(item["image_url"]["url"])[0] if item["type"] == "image_url" else item["text"]
|
get_image_from_url(item["image_url"]["url"]).content if item["type"] == "image_url" else item["text"]
|
||||||
for item in sorted(message.content, key=lambda x: 0 if x["type"] == "image_url" else 1)
|
for item in sorted(message.content, key=lambda x: 0 if x["type"] == "image_url" else 1)
|
||||||
]
|
]
|
||||||
elif isinstance(message.content, str):
|
elif isinstance(message.content, str):
|
||||||
|
|
|
@ -3,6 +3,7 @@ import logging
|
||||||
import math
|
import math
|
||||||
import mimetypes
|
import mimetypes
|
||||||
import queue
|
import queue
|
||||||
|
from dataclasses import dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from time import perf_counter
|
from time import perf_counter
|
||||||
|
@ -317,6 +318,12 @@ def remove_json_codeblock(response: str):
|
||||||
return response.removeprefix("```json").removesuffix("```")
|
return response.removeprefix("```json").removesuffix("```")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ImageWithType:
|
||||||
|
content: Any
|
||||||
|
type: str
|
||||||
|
|
||||||
|
|
||||||
def get_image_from_url(image_url: str, type="pil"):
|
def get_image_from_url(image_url: str, type="pil"):
|
||||||
try:
|
try:
|
||||||
response = requests.get(image_url)
|
response = requests.get(image_url)
|
||||||
|
@ -325,12 +332,15 @@ def get_image_from_url(image_url: str, type="pil"):
|
||||||
# Get content type from response or infer from URL
|
# Get content type from response or infer from URL
|
||||||
content_type = response.headers.get("content-type") or mimetypes.guess_type(image_url)[0] or "image/webp"
|
content_type = response.headers.get("content-type") or mimetypes.guess_type(image_url)[0] or "image/webp"
|
||||||
|
|
||||||
|
# Convert image to desired format
|
||||||
if type == "b64":
|
if type == "b64":
|
||||||
return base64.b64encode(response.content).decode("utf-8"), content_type
|
image_data = base64.b64encode(response.content).decode("utf-8")
|
||||||
elif type == "pil":
|
elif type == "pil":
|
||||||
return PIL.Image.open(BytesIO(response.content)), content_type
|
image_data = PIL.Image.open(BytesIO(response.content))
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid image type: {type}")
|
raise ValueError(f"Invalid image type: {type}")
|
||||||
|
|
||||||
|
return ImageWithType(content=image_data, type=content_type)
|
||||||
except requests.exceptions.RequestException as e:
|
except requests.exceptions.RequestException as e:
|
||||||
logger.error(f"Failed to get image from URL {image_url}: {e}")
|
logger.error(f"Failed to get image from URL {image_url}: {e}")
|
||||||
return None, None
|
return ImageWithType(content=None, type=None)
|
||||||
|
|
Loading…
Reference in a new issue