mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 23:48:56 +01:00
Make the get image from url function more versatile and reusable
It was previously added under the google utils. Now it can be used by other conversation processors as well. The updated function - can get both base64 encoded and PIL formatted images from url - will return the media type of the image as well in response
This commit is contained in:
parent
9f2c02d9f7
commit
82eac5a043
2 changed files with 26 additions and 15 deletions
|
@ -1,11 +1,8 @@
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
from io import BytesIO
|
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
|
|
||||||
import google.generativeai as genai
|
import google.generativeai as genai
|
||||||
import PIL.Image
|
|
||||||
import requests
|
|
||||||
from google.generativeai.types.answer_types import FinishReason
|
from google.generativeai.types.answer_types import FinishReason
|
||||||
from google.generativeai.types.generation_types import StopCandidateException
|
from google.generativeai.types.generation_types import StopCandidateException
|
||||||
from google.generativeai.types.safety_types import (
|
from google.generativeai.types.safety_types import (
|
||||||
|
@ -22,7 +19,7 @@ from tenacity import (
|
||||||
wait_random_exponential,
|
wait_random_exponential,
|
||||||
)
|
)
|
||||||
|
|
||||||
from khoj.processor.conversation.utils import ThreadedGenerator
|
from khoj.processor.conversation.utils import ThreadedGenerator, get_image_from_url
|
||||||
from khoj.utils.helpers import is_none_or_empty
|
from khoj.utils.helpers import is_none_or_empty
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -207,7 +204,7 @@ def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str =
|
||||||
if isinstance(message.content, list):
|
if isinstance(message.content, list):
|
||||||
# Convert image_urls to PIL.Image and place them at beginning of list (better for Gemini)
|
# Convert image_urls to PIL.Image and place them at beginning of list (better for Gemini)
|
||||||
message.content = [
|
message.content = [
|
||||||
get_image_from_url(item["image_url"]["url"]) if item["type"] == "image_url" else item["text"]
|
get_image_from_url(item["image_url"]["url"])[0] if item["type"] == "image_url" else item["text"]
|
||||||
for item in sorted(message.content, key=lambda x: 0 if x["type"] == "image_url" else 1)
|
for item in sorted(message.content, key=lambda x: 0 if x["type"] == "image_url" else 1)
|
||||||
]
|
]
|
||||||
elif isinstance(message.content, str):
|
elif isinstance(message.content, str):
|
||||||
|
@ -220,13 +217,3 @@ def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str =
|
||||||
messages[0].role = "user"
|
messages[0].role = "user"
|
||||||
|
|
||||||
return messages, system_prompt
|
return messages, system_prompt
|
||||||
|
|
||||||
|
|
||||||
def get_image_from_url(image_url: str) -> PIL.Image:
|
|
||||||
try:
|
|
||||||
response = requests.get(image_url)
|
|
||||||
response.raise_for_status() # Check if the request was successful
|
|
||||||
return PIL.Image.open(BytesIO(response.content))
|
|
||||||
except requests.exceptions.RequestException as e:
|
|
||||||
logger.error(f"Failed to get image from URL {image_url}: {e}")
|
|
||||||
return None
|
|
||||||
|
|
|
@ -1,10 +1,15 @@
|
||||||
|
import base64
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
|
import mimetypes
|
||||||
import queue
|
import queue
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from io import BytesIO
|
||||||
from time import perf_counter
|
from time import perf_counter
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import PIL.Image
|
||||||
|
import requests
|
||||||
import tiktoken
|
import tiktoken
|
||||||
from langchain.schema import ChatMessage
|
from langchain.schema import ChatMessage
|
||||||
from llama_cpp.llama import Llama
|
from llama_cpp.llama import Llama
|
||||||
|
@ -306,3 +311,22 @@ def reciprocal_conversation_to_chatml(message_pair):
|
||||||
def remove_json_codeblock(response: str):
|
def remove_json_codeblock(response: str):
|
||||||
"""Remove any markdown json codeblock formatting if present. Useful for non schema enforceable models"""
|
"""Remove any markdown json codeblock formatting if present. Useful for non schema enforceable models"""
|
||||||
return response.removeprefix("```json").removesuffix("```")
|
return response.removeprefix("```json").removesuffix("```")
|
||||||
|
|
||||||
|
|
||||||
|
def get_image_from_url(image_url: str, type="pil"):
|
||||||
|
try:
|
||||||
|
response = requests.get(image_url)
|
||||||
|
response.raise_for_status() # Check if the request was successful
|
||||||
|
|
||||||
|
# Get content type from response or infer from URL
|
||||||
|
content_type = response.headers.get("content-type") or mimetypes.guess_type(image_url)[0] or "image/webp"
|
||||||
|
|
||||||
|
if type == "b64":
|
||||||
|
return base64.b64encode(response.content).decode("utf-8"), content_type
|
||||||
|
elif type == "pil":
|
||||||
|
return PIL.Image.open(BytesIO(response.content)), content_type
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid image type: {type}")
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
logger.error(f"Failed to get image from URL {image_url}: {e}")
|
||||||
|
return None, None
|
||||||
|
|
Loading…
Reference in a new issue