From 01d740debd4d1857ec9bc6595d227dc17678449c Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Thu, 24 Oct 2024 17:49:37 -0700 Subject: [PATCH] Return typed image from get_image_from_url function for readability --- .../processor/conversation/anthropic/utils.py | 8 +++++--- src/khoj/processor/conversation/google/utils.py | 2 +- src/khoj/processor/conversation/utils.py | 16 +++++++++++++--- 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/src/khoj/processor/conversation/anthropic/utils.py b/src/khoj/processor/conversation/anthropic/utils.py index cc020b0a..a4a71a6d 100644 --- a/src/khoj/processor/conversation/anthropic/utils.py +++ b/src/khoj/processor/conversation/anthropic/utils.py @@ -141,13 +141,15 @@ def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt=Non for message in messages: if isinstance(message.content, list): content = [] - # Sort the content as preferred if text comes after images + # Sort the content. Anthropic models prefer that text comes after images. message.content.sort(key=lambda x: 0 if x["type"] == "image_url" else 1) for idx, part in enumerate(message.content): if part["type"] == "text": content.append({"type": "text", "text": part["text"]}) elif part["type"] == "image_url": - b64_image, media_type = get_image_from_url(part["image_url"]["url"], type="b64") + image = get_image_from_url(part["image_url"]["url"], type="b64") + # Prefix each image with text block enumerating the image number + # This helps the model reference the image in its response. 
Recommended by Anthropic content.extend( [ { @@ -156,7 +158,7 @@ def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt=Non }, { "type": "image", - "source": {"type": "base64", "media_type": media_type, "data": b64_image}, + "source": {"type": "base64", "media_type": image.type, "data": image.content}, }, ] ) diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py index a4041a94..964fe80b 100644 --- a/src/khoj/processor/conversation/google/utils.py +++ b/src/khoj/processor/conversation/google/utils.py @@ -204,7 +204,7 @@ def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str = if isinstance(message.content, list): # Convert image_urls to PIL.Image and place them at beginning of list (better for Gemini) message.content = [ - get_image_from_url(item["image_url"]["url"])[0] if item["type"] == "image_url" else item["text"] + get_image_from_url(item["image_url"]["url"]).content if item["type"] == "image_url" else item["text"] for item in sorted(message.content, key=lambda x: 0 if x["type"] == "image_url" else 1) ] elif isinstance(message.content, str): diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 943c5616..fb6d1909 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -3,6 +3,7 @@ import logging import math import mimetypes import queue +from dataclasses import dataclass from datetime import datetime from io import BytesIO from time import perf_counter @@ -317,6 +318,12 @@ def remove_json_codeblock(response: str): return response.removeprefix("```json").removesuffix("```") +@dataclass +class ImageWithType: + content: Any + type: str + + def get_image_from_url(image_url: str, type="pil"): try: response = requests.get(image_url) @@ -325,12 +332,15 @@ def get_image_from_url(image_url: str, type="pil"): # Get content type from response or infer from URL content_type = 
response.headers.get("content-type") or mimetypes.guess_type(image_url)[0] or "image/webp" + # Convert image to desired format if type == "b64": - return base64.b64encode(response.content).decode("utf-8"), content_type + image_data = base64.b64encode(response.content).decode("utf-8") elif type == "pil": - return PIL.Image.open(BytesIO(response.content)), content_type + image_data = PIL.Image.open(BytesIO(response.content)) else: raise ValueError(f"Invalid image type: {type}") + + return ImageWithType(content=image_data, type=content_type) except requests.exceptions.RequestException as e: logger.error(f"Failed to get image from URL {image_url}: {e}") - return None, None + return ImageWithType(content=None, type=None)