Give Vision to Anthropic models in Khoj (#948)

### Major
- Give Vision to Anthropic models in Khoj

### Minor
- Reuse logic to format messages for chat with Anthropic models
- Make the get-image-from-URL helper more versatile and reusable
- Encourage the output mode chat actor to output only JSON and nothing else
Commit adee5a3e20 by Debanjum, 2024-10-24 18:02:38 -07:00 (committed via GitHub)
7 changed files with 123 additions and 41 deletions
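To make the data flow concrete, here is a small illustrative sketch (not taken from the commit itself) of how a vision query is shaped for Anthropic models; the query text and image URL are made-up placeholders.

```python
# Illustrative sketch only; the helpers named here are the ones touched by
# this commit, while the text and URL below are placeholders.

# A user turn with attached images is built as typed content parts
# (construct_structured_message), e.g.:
structured_message = [
    {"type": "text", "text": "What is in this screenshot?"},
    {"type": "image_url", "image_url": {"url": "https://example.com/screenshot.png"}},
]

# Before the Anthropic API call, format_messages_for_anthropic converts each
# image_url part into a base64 "image" block prefixed by an "Image N:" text
# block, and places the remaining text after the images.
```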

Changes to the Anthropic chat functions (extract_questions_anthropic, anthropic_send_message_to_model, converse_anthropic):

@@ -11,8 +11,12 @@ from khoj.processor.conversation import prompts
 from khoj.processor.conversation.anthropic.utils import (
     anthropic_chat_completion_with_backoff,
     anthropic_completion_with_backoff,
+    format_messages_for_anthropic,
+)
+from khoj.processor.conversation.utils import (
+    construct_structured_message,
+    generate_chatml_messages_with_context,
 )
-from khoj.processor.conversation.utils import generate_chatml_messages_with_context
 from khoj.utils.helpers import ConversationCommand, is_none_or_empty
 from khoj.utils.rawconfig import LocationData

@@ -27,6 +31,8 @@ def extract_questions_anthropic(
     temperature=0.7,
     location_data: LocationData = None,
     user: KhojUser = None,
+    query_images: Optional[list[str]] = None,
+    vision_enabled: bool = False,
     personality_context: Optional[str] = None,
 ):
     """

@@ -68,6 +74,13 @@ def extract_questions_anthropic(
         text=text,
     )

+    prompt = construct_structured_message(
+        message=prompt,
+        images=query_images,
+        model_type=ChatModelOptions.ModelType.ANTHROPIC,
+        vision_enabled=vision_enabled,
+    )
+
     messages = [ChatMessage(content=prompt, role="user")]

     response = anthropic_completion_with_backoff(

@@ -101,17 +114,7 @@ def anthropic_send_message_to_model(messages, api_key, model):
     """
     Send message to model
     """
-    # Anthropic requires the first message to be a 'user' message, and the system prompt is not to be sent in the messages parameter
-    system_prompt = None
-    if len(messages) == 1:
-        messages[0].role = "user"
-    else:
-        system_prompt = ""
-        for message in messages.copy():
-            if message.role == "system":
-                system_prompt += message.content
-                messages.remove(message)
+    messages, system_prompt = format_messages_for_anthropic(messages)

     # Get Response from GPT. Don't use response_type because Anthropic doesn't support it.
     return anthropic_completion_with_backoff(

@@ -127,7 +130,7 @@ def converse_anthropic(
     user_query,
     online_results: Optional[Dict[str, Dict]] = None,
     conversation_log={},
-    model: Optional[str] = "claude-instant-1.2",
+    model: Optional[str] = "claude-3-5-sonnet-20241022",
     api_key: Optional[str] = None,
     completion_func=None,
     conversation_commands=[ConversationCommand.Default],

@@ -136,6 +139,8 @@ def converse_anthropic(
     location_data: LocationData = None,
     user_name: str = None,
     agent: Agent = None,
+    query_images: Optional[list[str]] = None,
+    vision_available: bool = False,
 ):
     """
     Converse with user using Anthropic's Claude

@@ -189,17 +194,12 @@ def converse_anthropic(
         model_name=model,
         max_prompt_size=max_prompt_size,
         tokenizer_name=tokenizer_name,
+        query_images=query_images,
+        vision_enabled=vision_available,
         model_type=ChatModelOptions.ModelType.ANTHROPIC,
     )
-    if len(messages) > 1:
-        if messages[0].role == "assistant":
-            messages = messages[1:]
-
-    for message in messages.copy():
-        if message.role == "system":
-            system_prompt += message.content
-            messages.remove(message)
+    messages, system_prompt = format_messages_for_anthropic(messages, system_prompt)

     truncated_messages = "\n".join({f"{message.content[:40]}..." for message in messages})
     logger.debug(f"Conversation Context for Claude: {truncated_messages}")

Changes to the Anthropic utilities (anthropic_llm_thread; new format_messages_for_anthropic helper):

@@ -3,6 +3,7 @@ from threading import Thread
 from typing import Dict, List

 import anthropic
+from langchain.schema import ChatMessage
 from tenacity import (
     before_sleep_log,
     retry,

@@ -11,7 +12,8 @@ from tenacity import (
     wait_random_exponential,
 )

-from khoj.processor.conversation.utils import ThreadedGenerator
+from khoj.processor.conversation.utils import ThreadedGenerator, get_image_from_url
+from khoj.utils.helpers import is_none_or_empty

 logger = logging.getLogger(__name__)

@@ -115,3 +117,51 @@ def anthropic_llm_thread(
         logger.error(f"Error in anthropic_llm_thread: {e}", exc_info=True)
     finally:
         g.close()
+
+
+def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt=None):
+    """
+    Format messages for Anthropic
+    """
+    # Extract system prompt
+    system_prompt = system_prompt or ""
+    for message in messages.copy():
+        if message.role == "system":
+            system_prompt += message.content
+            messages.remove(message)
+    system_prompt = None if is_none_or_empty(system_prompt) else system_prompt
+
+    # Anthropic requires the first message to be a 'user' message
+    if len(messages) == 1:
+        messages[0].role = "user"
+    elif len(messages) > 1 and messages[0].role == "assistant":
+        messages = messages[1:]
+
+    # Convert image urls to base64 encoded images in Anthropic message format
+    for message in messages:
+        if isinstance(message.content, list):
+            content = []
+            # Sort the content. Anthropic models prefer that text comes after images.
+            message.content.sort(key=lambda x: 0 if x["type"] == "image_url" else 1)
+            for idx, part in enumerate(message.content):
+                if part["type"] == "text":
+                    content.append({"type": "text", "text": part["text"]})
+                elif part["type"] == "image_url":
+                    image = get_image_from_url(part["image_url"]["url"], type="b64")
+                    # Prefix each image with a text block enumerating the image number.
+                    # This helps the model reference the image in its response. Recommended by Anthropic.
+                    content.extend(
+                        [
+                            {
+                                "type": "text",
+                                "text": f"Image {idx + 1}:",
+                            },
+                            {
+                                "type": "image",
+                                "source": {"type": "base64", "media_type": image.type, "data": image.content},
+                            },
+                        ]
+                    )
+            message.content = content
+
+    return messages, system_prompt
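As a usage illustration (assumed, not part of the diff), the new helper might be exercised roughly like this; the image URL is a placeholder, and the call downloads and base64-encodes it via get_image_from_url.

```python
# Rough sketch of format_messages_for_anthropic on a single vision message.
from langchain.schema import ChatMessage

messages = [
    ChatMessage(
        role="user",
        content=[
            {"type": "text", "text": "Describe this image"},
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
        ],
    )
]
messages, system_prompt = format_messages_for_anthropic(messages, system_prompt="You are Khoj.")

# messages[0].content now looks roughly like:
# [
#     {"type": "text", "text": "Image 1:"},
#     {"type": "image", "source": {"type": "base64", "media_type": "image/png", "data": "<base64>"}},
#     {"type": "text", "text": "Describe this image"},
# ]
# system_prompt is returned unchanged ("You are Khoj.") since the list held no system message.
```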

Changes to the Gemini utilities (format_messages_for_gemini now uses the shared get_image_from_url helper):

@@ -1,11 +1,8 @@
 import logging
 import random
-from io import BytesIO
 from threading import Thread

 import google.generativeai as genai
-import PIL.Image
-import requests
 from google.generativeai.types.answer_types import FinishReason
 from google.generativeai.types.generation_types import StopCandidateException
 from google.generativeai.types.safety_types import (

@@ -22,7 +19,7 @@ from tenacity import (
     wait_random_exponential,
 )

-from khoj.processor.conversation.utils import ThreadedGenerator
+from khoj.processor.conversation.utils import ThreadedGenerator, get_image_from_url
 from khoj.utils.helpers import is_none_or_empty

 logger = logging.getLogger(__name__)

@@ -207,7 +204,7 @@ def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str =
         if isinstance(message.content, list):
             # Convert image_urls to PIL.Image and place them at beginning of list (better for Gemini)
             message.content = [
-                get_image_from_url(item["image_url"]["url"]) if item["type"] == "image_url" else item["text"]
+                get_image_from_url(item["image_url"]["url"]).content if item["type"] == "image_url" else item["text"]
                 for item in sorted(message.content, key=lambda x: 0 if x["type"] == "image_url" else 1)
             ]
         elif isinstance(message.content, str):

@@ -220,13 +217,3 @@ def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str =
         messages[0].role = "user"

     return messages, system_prompt
-
-
-def get_image_from_url(image_url: str) -> PIL.Image:
-    try:
-        response = requests.get(image_url)
-        response.raise_for_status()  # Check if the request was successful
-        return PIL.Image.open(BytesIO(response.content))
-    except requests.exceptions.RequestException as e:
-        logger.error(f"Failed to get image from URL {image_url}: {e}")
-        return None

Changes to the output mode prompt (instruct the chat actor to return only JSON):

@@ -619,7 +619,7 @@ AI: It's currently 28°C and partly cloudy in Bali.
 Q: Share a painting using the weather for Bali every morning.
 Khoj: {{"output": "automation"}}

-Now it's your turn to pick the mode you would like to use to answer the user's question. Provide your response as a JSON.
+Now it's your turn to pick the mode you would like to use to answer the user's question. Provide your response as a JSON. Do not say anything else.

 Chat History:
 {chat_history}
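For illustration (an assumption about intent, not stated in the diff), the tightened instruction targets replies that contain only the JSON object; the existing remove_json_codeblock helper, visible as context in the conversation utilities diff below, remains a fallback for stripping a stray markdown fence.

```python
# Sketch: a compliant output-mode reply and the existing fallback cleanup.
compliant = '{"output": "automation"}'

# If a model still wraps the JSON in a markdown fence, the fence can be stripped
# (this mirrors what remove_json_codeblock does):
wrapped = "```json" + '{"output": "automation"}' + "```"
cleaned = wrapped.removeprefix("```json").removesuffix("```")
assert cleaned == compliant
```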

Changes to the shared conversation utilities (construct_structured_message supports Anthropic; get_image_from_url moved here and generalized):

@@ -1,10 +1,16 @@
+import base64
 import logging
 import math
+import mimetypes
 import queue
+from dataclasses import dataclass
 from datetime import datetime
+from io import BytesIO
 from time import perf_counter
 from typing import Any, Dict, List, Optional

+import PIL.Image
+import requests
 import tiktoken
 from langchain.schema import ChatMessage
 from llama_cpp.llama import Llama

@@ -152,7 +158,11 @@ def construct_structured_message(message: str, images: list[str], model_type: st
     if not images or not vision_enabled:
         return message

-    if model_type in [ChatModelOptions.ModelType.OPENAI, ChatModelOptions.ModelType.GOOGLE]:
+    if model_type in [
+        ChatModelOptions.ModelType.OPENAI,
+        ChatModelOptions.ModelType.GOOGLE,
+        ChatModelOptions.ModelType.ANTHROPIC,
+    ]:
         return [
             {"type": "text", "text": message},
             *[{"type": "image_url", "image_url": {"url": image}} for image in images],

@@ -306,3 +316,31 @@ def reciprocal_conversation_to_chatml(message_pair):
 def remove_json_codeblock(response: str):
     """Remove any markdown json codeblock formatting if present. Useful for non schema enforceable models"""
     return response.removeprefix("```json").removesuffix("```")
+
+
+@dataclass
+class ImageWithType:
+    content: Any
+    type: str
+
+
+def get_image_from_url(image_url: str, type="pil"):
+    try:
+        response = requests.get(image_url)
+        response.raise_for_status()  # Check if the request was successful
+
+        # Get content type from response or infer from URL
+        content_type = response.headers.get("content-type") or mimetypes.guess_type(image_url)[0] or "image/webp"
+
+        # Convert image to desired format
+        if type == "b64":
+            image_data = base64.b64encode(response.content).decode("utf-8")
+        elif type == "pil":
+            image_data = PIL.Image.open(BytesIO(response.content))
+        else:
+            raise ValueError(f"Invalid image type: {type}")
+
+        return ImageWithType(content=image_data, type=content_type)
+    except requests.exceptions.RequestException as e:
+        logger.error(f"Failed to get image from URL {image_url}: {e}")
+        return ImageWithType(content=None, type=None)
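A quick usage sketch (assumed, not from the diff) of the relocated helper in both of its output formats; the URL is a placeholder and each call performs a real HTTP GET.

```python
# Sketch: fetch one image in the two supported formats.
b64_image = get_image_from_url("https://example.com/photo.jpg", type="b64")
# b64_image.content -> base64-encoded string; b64_image.type -> e.g. "image/jpeg"
# (this is the form format_messages_for_anthropic uses to build image blocks)

pil_image = get_image_from_url("https://example.com/photo.jpg", type="pil")
# pil_image.content -> PIL.Image.Image; .content is what format_messages_for_gemini passes along
```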

Changes to extract_references_and_questions (pass query images and the vision flag to the Anthropic question extractor):

@@ -448,11 +448,13 @@ async def extract_references_and_questions(
             chat_model = conversation_config.chat_model
             inferred_queries = extract_questions_anthropic(
                 defiltered_query,
+                query_images=query_images,
                 model=chat_model,
                 api_key=api_key,
                 conversation_log=meta_log,
                 location_data=location_data,
                 user=user,
+                vision_enabled=vision_enabled,
                 personality_context=personality_context,
             )
         elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:

Changes to send_message_to_model_wrapper and generate_chat_response (log vision usage and pass images to converse_anthropic):

@@ -820,10 +820,13 @@ async def send_message_to_model_wrapper(
     conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user)
     vision_available = conversation_config.vision_enabled
     if not vision_available and query_images:
+        logger.warning(f"Vision is not enabled for default model: {conversation_config.chat_model}.")
         vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config()
         if vision_enabled_config:
             conversation_config = vision_enabled_config
             vision_available = True
+    if vision_available and query_images:
+        logger.info(f"Using {conversation_config.chat_model} model to understand {len(query_images)} images.")

     subscribed = await ais_user_subscribed(user)
     chat_model = conversation_config.chat_model

@@ -1104,8 +1107,9 @@ def generate_chat_response(
             chat_response = converse_anthropic(
                 compiled_references,
                 q,
-                online_results,
-                meta_log,
+                query_images=query_images,
+                online_results=online_results,
+                conversation_log=meta_log,
                 model=conversation_config.chat_model,
                 api_key=api_key,
                 completion_func=partial_completion,

@@ -1115,6 +1119,7 @@ def generate_chat_response(
                 location_data=location_data,
                 user_name=user_name,
                 agent=agent,
+                vision_available=vision_available,
             )
         elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
             api_key = conversation_config.openai_config.api_key