Make the get image from url function more versatile and reusable

It was previously added under the google utils. Now it can be used by
other conversation processors as well.

The updated function
- can get both base64 encoded and PIL formatted images from url
- will return the media type of the image as well in response
This commit is contained in:
Debanjum Singh Solanky 2024-10-23 03:52:46 -07:00
parent 9f2c02d9f7
commit 82eac5a043
2 changed files with 26 additions and 15 deletions

View file

@ -1,11 +1,8 @@
import logging import logging
import random import random
from io import BytesIO
from threading import Thread from threading import Thread
import google.generativeai as genai import google.generativeai as genai
import PIL.Image
import requests
from google.generativeai.types.answer_types import FinishReason from google.generativeai.types.answer_types import FinishReason
from google.generativeai.types.generation_types import StopCandidateException from google.generativeai.types.generation_types import StopCandidateException
from google.generativeai.types.safety_types import ( from google.generativeai.types.safety_types import (
@ -22,7 +19,7 @@ from tenacity import (
wait_random_exponential, wait_random_exponential,
) )
from khoj.processor.conversation.utils import ThreadedGenerator from khoj.processor.conversation.utils import ThreadedGenerator, get_image_from_url
from khoj.utils.helpers import is_none_or_empty from khoj.utils.helpers import is_none_or_empty
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -207,7 +204,7 @@ def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str =
if isinstance(message.content, list): if isinstance(message.content, list):
# Convert image_urls to PIL.Image and place them at beginning of list (better for Gemini) # Convert image_urls to PIL.Image and place them at beginning of list (better for Gemini)
message.content = [ message.content = [
get_image_from_url(item["image_url"]["url"]) if item["type"] == "image_url" else item["text"] get_image_from_url(item["image_url"]["url"])[0] if item["type"] == "image_url" else item["text"]
for item in sorted(message.content, key=lambda x: 0 if x["type"] == "image_url" else 1) for item in sorted(message.content, key=lambda x: 0 if x["type"] == "image_url" else 1)
] ]
elif isinstance(message.content, str): elif isinstance(message.content, str):
@ -220,13 +217,3 @@ def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str =
messages[0].role = "user" messages[0].role = "user"
return messages, system_prompt return messages, system_prompt
def get_image_from_url(image_url: str) -> PIL.Image:
try:
response = requests.get(image_url)
response.raise_for_status() # Check if the request was successful
return PIL.Image.open(BytesIO(response.content))
except requests.exceptions.RequestException as e:
logger.error(f"Failed to get image from URL {image_url}: {e}")
return None

View file

@ -1,10 +1,15 @@
import base64
import logging import logging
import math import math
import mimetypes
import queue import queue
from datetime import datetime from datetime import datetime
from io import BytesIO
from time import perf_counter from time import perf_counter
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import PIL.Image
import requests
import tiktoken import tiktoken
from langchain.schema import ChatMessage from langchain.schema import ChatMessage
from llama_cpp.llama import Llama from llama_cpp.llama import Llama
@ -306,3 +311,22 @@ def reciprocal_conversation_to_chatml(message_pair):
def remove_json_codeblock(response: str): def remove_json_codeblock(response: str):
"""Remove any markdown json codeblock formatting if present. Useful for non schema enforceable models""" """Remove any markdown json codeblock formatting if present. Useful for non schema enforceable models"""
return response.removeprefix("```json").removesuffix("```") return response.removeprefix("```json").removesuffix("```")
def get_image_from_url(image_url: str, type="pil"):
try:
response = requests.get(image_url)
response.raise_for_status() # Check if the request was successful
# Get content type from response or infer from URL
content_type = response.headers.get("content-type") or mimetypes.guess_type(image_url)[0] or "image/webp"
if type == "b64":
return base64.b64encode(response.content).decode("utf-8"), content_type
elif type == "pil":
return PIL.Image.open(BytesIO(response.content)), content_type
else:
raise ValueError(f"Invalid image type: {type}")
except requests.exceptions.RequestException as e:
logger.error(f"Failed to get image from URL {image_url}: {e}")
return None, None