diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py
index cb51abb4..5e403c7b 100644
--- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py
+++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py
@@ -11,8 +11,12 @@ from khoj.processor.conversation import prompts
 from khoj.processor.conversation.anthropic.utils import (
     anthropic_chat_completion_with_backoff,
     anthropic_completion_with_backoff,
+    format_messages_for_anthropic,
+)
+from khoj.processor.conversation.utils import (
+    construct_structured_message,
+    generate_chatml_messages_with_context,
 )
-from khoj.processor.conversation.utils import generate_chatml_messages_with_context
 from khoj.utils.helpers import ConversationCommand, is_none_or_empty
 from khoj.utils.rawconfig import LocationData
 
@@ -27,6 +31,8 @@ def extract_questions_anthropic(
     temperature=0.7,
     location_data: LocationData = None,
     user: KhojUser = None,
+    query_images: Optional[list[str]] = None,
+    vision_enabled: bool = False,
     personality_context: Optional[str] = None,
 ):
     """
@@ -68,6 +74,13 @@ def extract_questions_anthropic(
         text=text,
     )
 
+    prompt = construct_structured_message(
+        message=prompt,
+        images=query_images,
+        model_type=ChatModelOptions.ModelType.ANTHROPIC,
+        vision_enabled=vision_enabled,
+    )
+
     messages = [ChatMessage(content=prompt, role="user")]
 
     response = anthropic_completion_with_backoff(
@@ -101,17 +114,7 @@ def anthropic_send_message_to_model(messages, api_key, model):
     """
     Send message to model
     """
-    # Anthropic requires the first message to be a 'user' message, and the system prompt is not to be sent in the messages parameter
-    system_prompt = None
-
-    if len(messages) == 1:
-        messages[0].role = "user"
-    else:
-        system_prompt = ""
-        for message in messages.copy():
-            if message.role == "system":
-                system_prompt += message.content
-                messages.remove(message)
+    messages, system_prompt = format_messages_for_anthropic(messages)
 
     # Get Response from GPT. Don't use response_type because Anthropic doesn't support it.
     return anthropic_completion_with_backoff(
@@ -127,7 +130,7 @@ def converse_anthropic(
     user_query,
     online_results: Optional[Dict[str, Dict]] = None,
     conversation_log={},
-    model: Optional[str] = "claude-instant-1.2",
+    model: Optional[str] = "claude-3-5-sonnet-20241022",
     api_key: Optional[str] = None,
     completion_func=None,
     conversation_commands=[ConversationCommand.Default],
@@ -136,6 +139,8 @@ def converse_anthropic(
     location_data: LocationData = None,
     user_name: str = None,
     agent: Agent = None,
+    query_images: Optional[list[str]] = None,
+    vision_available: bool = False,
 ):
     """
     Converse with user using Anthropic's Claude
@@ -189,17 +194,12 @@ def converse_anthropic(
         model_name=model,
         max_prompt_size=max_prompt_size,
         tokenizer_name=tokenizer_name,
+        query_images=query_images,
+        vision_enabled=vision_available,
         model_type=ChatModelOptions.ModelType.ANTHROPIC,
     )
 
-    if len(messages) > 1:
-        if messages[0].role == "assistant":
-            messages = messages[1:]
-
-        for message in messages.copy():
-            if message.role == "system":
-                system_prompt += message.content
-                messages.remove(message)
+    messages, system_prompt = format_messages_for_anthropic(messages, system_prompt)
 
     truncated_messages = "\n".join({f"{message.content[:40]}..." for message in messages})
     logger.debug(f"Conversation Context for Claude: {truncated_messages}")
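For context, here is a minimal sketch of the intermediate message shape this change builds when vision is enabled. The query and image URL are hypothetical; `construct_structured_message` is shown in the `conversation/utils.py` hunk below, and the import paths assume the package layout used elsewhere in this diff:

```python
from khoj.database.models import ChatModelOptions
from khoj.processor.conversation.utils import construct_structured_message

# Hypothetical query and image URL, for illustration only.
prompt = construct_structured_message(
    message="What is in this screenshot?",
    images=["https://example.com/screenshot.png"],
    model_type=ChatModelOptions.ModelType.ANTHROPIC,
    vision_enabled=True,
)
# prompt == [
#     {"type": "text", "text": "What is in this screenshot?"},
#     {"type": "image_url", "image_url": {"url": "https://example.com/screenshot.png"}},
# ]
```

This generic `image_url` form is shared with the OpenAI and Google paths; the Anthropic-specific conversion happens later in `format_messages_for_anthropic`.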
diff --git a/src/khoj/processor/conversation/anthropic/utils.py b/src/khoj/processor/conversation/anthropic/utils.py
index 79ccac4e..a4a71a6d 100644
--- a/src/khoj/processor/conversation/anthropic/utils.py
+++ b/src/khoj/processor/conversation/anthropic/utils.py
@@ -3,6 +3,7 @@ from threading import Thread
 from typing import Dict, List
 
 import anthropic
+from langchain.schema import ChatMessage
 from tenacity import (
     before_sleep_log,
     retry,
@@ -11,7 +12,8 @@ from tenacity import (
     wait_random_exponential,
 )
 
-from khoj.processor.conversation.utils import ThreadedGenerator
+from khoj.processor.conversation.utils import ThreadedGenerator, get_image_from_url
+from khoj.utils.helpers import is_none_or_empty
 
 logger = logging.getLogger(__name__)
 
@@ -115,3 +117,51 @@ def anthropic_llm_thread(
         logger.error(f"Error in anthropic_llm_thread: {e}", exc_info=True)
     finally:
         g.close()
+
+
+def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt=None):
+    """
+    Format messages for Anthropic
+    """
+    # Extract system prompt
+    system_prompt = system_prompt or ""
+    for message in messages.copy():
+        if message.role == "system":
+            system_prompt += message.content
+            messages.remove(message)
+    system_prompt = None if is_none_or_empty(system_prompt) else system_prompt
+
+    # Anthropic requires the first message to be a 'user' message
+    if len(messages) == 1:
+        messages[0].role = "user"
+    elif len(messages) > 1 and messages[0].role == "assistant":
+        messages = messages[1:]
+
+    # Convert image urls to base64 encoded images in Anthropic message format
+    for message in messages:
+        if isinstance(message.content, list):
+            content = []
+            # Sort the content. Anthropic models prefer that text comes after images.
+            message.content.sort(key=lambda x: 0 if x["type"] == "image_url" else 1)
+            for idx, part in enumerate(message.content):
+                if part["type"] == "text":
+                    content.append({"type": "text", "text": part["text"]})
+                elif part["type"] == "image_url":
+                    image = get_image_from_url(part["image_url"]["url"], type="b64")
+                    # Prefix each image with text block enumerating the image number
+                    # This helps the model reference the image in its response. Recommended by Anthropic
+                    content.extend(
+                        [
+                            {
+                                "type": "text",
+                                "text": f"Image {idx + 1}:",
+                            },
+                            {
+                                "type": "image",
+                                "source": {"type": "base64", "media_type": image.type, "data": image.content},
+                            },
+                        ]
+                    )
+            message.content = content
+
+    return messages, system_prompt
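A quick sketch of the new helper's behavior on a hypothetical two-message history: the system prompt is pulled out of the list (Anthropic takes it as a separate API parameter), and any `image_url` parts in a content list would be refetched as base64 `image` blocks, each preceded by an "Image N:" text label:

```python
from langchain.schema import ChatMessage

from khoj.processor.conversation.anthropic.utils import format_messages_for_anthropic

# Hypothetical conversation history.
messages = [
    ChatMessage(content="You are Khoj, a helpful personal assistant.", role="system"),
    ChatMessage(content="Summarize my notes on Bali.", role="user"),
]
messages, system_prompt = format_messages_for_anthropic(messages)
# system_prompt == "You are Khoj, a helpful personal assistant."
# messages == [ChatMessage(content="Summarize my notes on Bali.", role="user")]
```

One caveat worth noting: `get_image_from_url` returns `content=None` on a failed fetch, and this formatter inserts the block unguarded, so a dead image URL would surface as an invalid image block in the request.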
diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py
index d19b02f2..964fe80b 100644
--- a/src/khoj/processor/conversation/google/utils.py
+++ b/src/khoj/processor/conversation/google/utils.py
@@ -1,11 +1,8 @@
 import logging
 import random
-from io import BytesIO
 from threading import Thread
 
 import google.generativeai as genai
-import PIL.Image
-import requests
 from google.generativeai.types.answer_types import FinishReason
 from google.generativeai.types.generation_types import StopCandidateException
 from google.generativeai.types.safety_types import (
@@ -22,7 +19,7 @@ from tenacity import (
     wait_random_exponential,
 )
 
-from khoj.processor.conversation.utils import ThreadedGenerator
+from khoj.processor.conversation.utils import ThreadedGenerator, get_image_from_url
 from khoj.utils.helpers import is_none_or_empty
 
 logger = logging.getLogger(__name__)
@@ -207,7 +204,7 @@ def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str =
         if isinstance(message.content, list):
             # Convert image_urls to PIL.Image and place them at beginning of list (better for Gemini)
             message.content = [
-                get_image_from_url(item["image_url"]["url"]) if item["type"] == "image_url" else item["text"]
+                get_image_from_url(item["image_url"]["url"]).content if item["type"] == "image_url" else item["text"]
                 for item in sorted(message.content, key=lambda x: 0 if x["type"] == "image_url" else 1)
             ]
         elif isinstance(message.content, str):
@@ -220,13 +217,3 @@ def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str =
         messages[0].role = "user"
 
     return messages, system_prompt
-
-
-def get_image_from_url(image_url: str) -> PIL.Image:
-    try:
-        response = requests.get(image_url)
-        response.raise_for_status()  # Check if the request was successful
-        return PIL.Image.open(BytesIO(response.content))
-    except requests.exceptions.RequestException as e:
-        logger.error(f"Failed to get image from URL {image_url}: {e}")
-        return None
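The only functional change in the Gemini path is the trailing `.content`: `get_image_from_url` now lives in `conversation/utils.py` (see the hunk below) and returns an `ImageWithType` wrapper rather than a bare `PIL.Image`, so the formatter unwraps the PIL image explicitly. A minimal sketch, with a hypothetical URL:

```python
from khoj.processor.conversation.utils import get_image_from_url

# Hypothetical URL; the call performs a network fetch.
image = get_image_from_url("https://example.com/photo.jpg", type="pil")
pil_image = image.content  # PIL.Image.Image, or None if the fetch failed
media_type = image.type    # e.g. "image/jpeg", from response headers or guessed from the URL
```

Gemini only needs the PIL object, but Anthropic needs the media type alongside the payload, which is what motivated the wrapper.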
diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py
index 38db7477..7988cc43 100644
--- a/src/khoj/processor/conversation/prompts.py
+++ b/src/khoj/processor/conversation/prompts.py
@@ -619,7 +619,7 @@ AI: It's currently 28°C and partly cloudy in Bali.
 Q: Share a painting using the weather for Bali every morning.
 Khoj: {{"output": "automation"}}
 
-Now it's your turn to pick the mode you would like to use to answer the user's question. Provide your response as a JSON. 
+Now it's your turn to pick the mode you would like to use to answer the user's question. Provide your response as a JSON. Do not say anything else.
 
 Chat History:
 {chat_history}
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index e8e96314..fb6d1909 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -1,10 +1,16 @@
+import base64
 import logging
 import math
+import mimetypes
 import queue
+from dataclasses import dataclass
 from datetime import datetime
+from io import BytesIO
 from time import perf_counter
 from typing import Any, Dict, List, Optional
 
+import PIL.Image
+import requests
 import tiktoken
 from langchain.schema import ChatMessage
 from llama_cpp.llama import Llama
@@ -152,7 +158,11 @@ def construct_structured_message(message: str, images: list[str], model_type: st
     if not images or not vision_enabled:
         return message
 
-    if model_type in [ChatModelOptions.ModelType.OPENAI, ChatModelOptions.ModelType.GOOGLE]:
+    if model_type in [
+        ChatModelOptions.ModelType.OPENAI,
+        ChatModelOptions.ModelType.GOOGLE,
+        ChatModelOptions.ModelType.ANTHROPIC,
+    ]:
         return [
             {"type": "text", "text": message},
             *[{"type": "image_url", "image_url": {"url": image}} for image in images],
@@ -306,3 +316,31 @@ def reciprocal_conversation_to_chatml(message_pair):
 def remove_json_codeblock(response: str):
     """Remove any markdown json codeblock formatting if present. Useful for non schema enforceable models"""
     return response.removeprefix("```json").removesuffix("```")
+
+
+@dataclass
+class ImageWithType:
+    content: Any
+    type: str
+
+
+def get_image_from_url(image_url: str, type="pil"):
+    try:
+        response = requests.get(image_url)
+        response.raise_for_status()  # Check if the request was successful
+
+        # Get content type from response or infer from URL
+        content_type = response.headers.get("content-type") or mimetypes.guess_type(image_url)[0] or "image/webp"
+
+        # Convert image to desired format
+        if type == "b64":
+            image_data = base64.b64encode(response.content).decode("utf-8")
+        elif type == "pil":
+            image_data = PIL.Image.open(BytesIO(response.content))
+        else:
+            raise ValueError(f"Invalid image type: {type}")
+
+        return ImageWithType(content=image_data, type=content_type)
+    except requests.exceptions.RequestException as e:
+        logger.error(f"Failed to get image from URL {image_url}: {e}")
+        return ImageWithType(content=None, type=None)
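In `"b64"` mode, which the Anthropic formatter uses, the relocated helper returns base64-encoded text plus its media type, ready to drop into an Anthropic image block. A failed fetch yields `ImageWithType(content=None, type=None)`, so callers that cannot tolerate a `None` payload should guard for it. A sketch with a hypothetical URL:

```python
from khoj.processor.conversation.utils import get_image_from_url

# Hypothetical URL; the call performs a network fetch.
image = get_image_from_url("https://example.com/diagram.png", type="b64")
if image.content:
    block = {
        "type": "image",
        "source": {"type": "base64", "media_type": image.type, "data": image.content},
    }
```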
diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index 8254da4d..f89ca87a 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -448,11 +448,13 @@ async def extract_references_and_questions(
             chat_model = conversation_config.chat_model
             inferred_queries = extract_questions_anthropic(
                 defiltered_query,
+                query_images=query_images,
                 model=chat_model,
                 api_key=api_key,
                 conversation_log=meta_log,
                 location_data=location_data,
                 user=user,
+                vision_enabled=vision_enabled,
                 personality_context=personality_context,
             )
         elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index d323abfe..6cc44c4f 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -820,10 +820,13 @@ async def send_message_to_model_wrapper(
     conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user)
     vision_available = conversation_config.vision_enabled
     if not vision_available and query_images:
+        logger.warning(f"Vision is not enabled for default model: {conversation_config.chat_model}.")
         vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config()
         if vision_enabled_config:
             conversation_config = vision_enabled_config
             vision_available = True
+    if vision_available and query_images:
+        logger.info(f"Using {conversation_config.chat_model} model to understand {len(query_images)} images.")
 
     subscribed = await ais_user_subscribed(user)
     chat_model = conversation_config.chat_model
@@ -1104,8 +1107,9 @@ def generate_chat_response(
             chat_response = converse_anthropic(
                 compiled_references,
                 q,
-                online_results,
-                meta_log,
+                query_images=query_images,
+                online_results=online_results,
+                conversation_log=meta_log,
                 model=conversation_config.chat_model,
                 api_key=api_key,
                 completion_func=partial_completion,
@@ -1115,6 +1119,7 @@ def generate_chat_response(
                 location_data=location_data,
                 user_name=user_name,
                 agent=agent,
+                vision_available=vision_available,
             )
         elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
             api_key = conversation_config.openai_config.api_key
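Taken together, a vision query now flows through both helpers before reaching Anthropic. A rough end-to-end sketch, with hypothetical data (the image fetch hits the network):

```python
from langchain.schema import ChatMessage

from khoj.database.models import ChatModelOptions
from khoj.processor.conversation.anthropic.utils import format_messages_for_anthropic
from khoj.processor.conversation.utils import construct_structured_message

# Hypothetical user query with one attached image URL.
content = construct_structured_message(
    message="What does this chart show?",
    images=["https://example.com/chart.png"],
    model_type=ChatModelOptions.ModelType.ANTHROPIC,
    vision_enabled=True,
)
messages, system_prompt = format_messages_for_anthropic([ChatMessage(content=content, role="user")])
# messages[0].content is now ordered: the "Image 1:" text label, the base64
# image block, then the query text (Anthropic prefers text after images).
```

The keyword-argument change to the `converse_anthropic` call above keeps the call site correct as `query_images` slots into the signature; passing `online_results` and `meta_log` positionally would otherwise silently bind to the wrong parameters.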