diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py
index e8848806..10af8b4d 100644
--- a/src/khoj/processor/conversation/google/gemini_chat.py
+++ b/src/khoj/processor/conversation/google/gemini_chat.py
@@ -13,7 +13,10 @@
 from khoj.processor.conversation.google.utils import (
     gemini_chat_completion_with_backoff,
     gemini_completion_with_backoff,
 )
-from khoj.processor.conversation.utils import generate_chatml_messages_with_context
+from khoj.processor.conversation.utils import (
+    construct_structured_message,
+    generate_chatml_messages_with_context,
+)
 from khoj.utils.helpers import ConversationCommand, is_none_or_empty
 from khoj.utils.rawconfig import LocationData
 
@@ -29,6 +32,8 @@ def extract_questions_gemini(
     max_tokens=None,
     location_data: LocationData = None,
     user: KhojUser = None,
+    query_images: Optional[list[str]] = None,
+    vision_enabled: bool = False,
     personality_context: Optional[str] = None,
 ):
     """
@@ -70,17 +75,17 @@ def extract_questions_gemini(
         text=text,
     )
 
-    messages = [ChatMessage(content=prompt, role="user")]
+    prompt = construct_structured_message(
+        message=prompt,
+        images=query_images,
+        model_type=ChatModelOptions.ModelType.GOOGLE,
+        vision_enabled=vision_enabled,
+    )
 
-    model_kwargs = {"response_mime_type": "application/json"}
+    messages = [ChatMessage(content=prompt, role="user"), ChatMessage(content=system_prompt, role="system")]
 
-    response = gemini_completion_with_backoff(
-        messages=messages,
-        system_prompt=system_prompt,
-        model_name=model,
-        temperature=temperature,
-        api_key=api_key,
-        model_kwargs=model_kwargs,
+    response = gemini_send_message_to_model(
+        messages, api_key, model, response_type="json_object", temperature=temperature
     )
 
     # Extract, Clean Message from Gemini's Response
@@ -102,7 +107,7 @@
     return questions
 
 
-def gemini_send_message_to_model(messages, api_key, model, response_type="text"):
+def gemini_send_message_to_model(messages, api_key, model, response_type="text", temperature=0, model_kwargs=None):
     """
     Send message to model
     """
@@ -114,7 +119,12 @@ def gemini_send_message_to_model(messages, api_key, model, response_type="text")
 
     # Get Response from Gemini
     return gemini_completion_with_backoff(
-        messages=messages, system_prompt=system_prompt, model_name=model, api_key=api_key, model_kwargs=model_kwargs
+        messages=messages,
+        system_prompt=system_prompt,
+        model_name=model,
+        api_key=api_key,
+        temperature=temperature,
+        model_kwargs=model_kwargs,
     )
 
 
@@ -133,6 +143,8 @@ def converse_gemini(
     location_data: LocationData = None,
     user_name: str = None,
     agent: Agent = None,
+    query_images: Optional[list[str]] = None,
+    vision_available: bool = False,
 ):
     """
     Converse with user using Google's Gemini
@@ -187,6 +199,8 @@
         model_name=model,
         max_prompt_size=max_prompt_size,
         tokenizer_name=tokenizer_name,
+        query_images=query_images,
+        vision_enabled=vision_available,
         model_type=ChatModelOptions.ModelType.GOOGLE,
     )
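Note: with vision enabled, `construct_structured_message` swaps the plain prompt string for a chatml-style parts list. A minimal sketch of the structure it returns for Google (and OpenAI) models, per the `construct_structured_message` hunk further below; the query text and image URL are made-up examples:

```python
# Sketch of the content list construct_structured_message builds when
# images are attached and vision is enabled (example data only).
message = "What breed is this dog?"
query_images = ["https://example.com/dog.png"]  # illustrative URL

structured_content = [
    {"type": "text", "text": message},
    *[{"type": "image_url", "image_url": {"url": image}} for image in query_images],
]
# structured_content == [
#     {"type": "text", "text": "What breed is this dog?"},
#     {"type": "image_url", "image_url": {"url": "https://example.com/dog.png"}},
# ]
```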
diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py
index 5679ba4d..d19b02f2 100644
--- a/src/khoj/processor/conversation/google/utils.py
+++ b/src/khoj/processor/conversation/google/utils.py
@@ -1,8 +1,11 @@
 import logging
 import random
+from io import BytesIO
 from threading import Thread
 
 import google.generativeai as genai
+import PIL.Image
+import requests
 from google.generativeai.types.answer_types import FinishReason
 from google.generativeai.types.generation_types import StopCandidateException
 from google.generativeai.types.safety_types import (
@@ -53,14 +56,14 @@ def gemini_completion_with_backoff(
         },
     )
 
-    formatted_messages = [{"role": message.role, "parts": [message.content]} for message in messages]
+    formatted_messages = [{"role": message.role, "parts": message.content} for message in messages]
 
     # Start chat session. All messages up to the last are considered to be part of the chat history
     chat_session = model.start_chat(history=formatted_messages[0:-1])
 
     try:
         # Generate the response. The last message is considered to be the current prompt
-        aggregated_response = chat_session.send_message(formatted_messages[-1]["parts"][0])
+        aggregated_response = chat_session.send_message(formatted_messages[-1]["parts"])
         return aggregated_response.text
     except StopCandidateException as e:
         response_message, _ = handle_gemini_response(e.args)
@@ -117,11 +120,11 @@ def gemini_llm_thread(g, messages, system_prompt, model_name, temperature, api_k
             },
         )
 
-        formatted_messages = [{"role": message.role, "parts": [message.content]} for message in messages]
+        formatted_messages = [{"role": message.role, "parts": message.content} for message in messages]
         # all messages up to the last are considered to be part of the chat history
         chat_session = model.start_chat(history=formatted_messages[0:-1])
         # the last message is considered to be the current prompt
-        for chunk in chat_session.send_message(formatted_messages[-1]["parts"][0], stream=True):
+        for chunk in chat_session.send_message(formatted_messages[-1]["parts"], stream=True):
             message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
             message = message or chunk.text
             g.send(message)
@@ -191,14 +194,6 @@ def generate_safety_response(safety_ratings):
 
 
 def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str = None) -> tuple[list[str], str]:
-    if len(messages) == 1:
-        messages[0].role = "user"
-        return messages, system_prompt
-
-    for message in messages:
-        if message.role == "assistant":
-            message.role = "model"
-
     # Extract system message
     system_prompt = system_prompt or ""
     for message in messages.copy():
@@ -207,4 +202,31 @@ def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str =
             messages.remove(message)
     system_prompt = None if is_none_or_empty(system_prompt) else system_prompt
 
+    for message in messages:
+        # Convert message content to string list from chatml dictionary list
+        if isinstance(message.content, list):
+            # Convert image_urls to PIL.Image and place them at beginning of list (better for Gemini)
+            message.content = [
+                get_image_from_url(item["image_url"]["url"]) if item["type"] == "image_url" else item["text"]
+                for item in sorted(message.content, key=lambda x: 0 if x["type"] == "image_url" else 1)
+            ]
+        elif isinstance(message.content, str):
+            message.content = [message.content]
+
+        if message.role == "assistant":
+            message.role = "model"
+
+    if len(messages) == 1:
+        messages[0].role = "user"
+
     return messages, system_prompt
+
+
+def get_image_from_url(image_url: str) -> PIL.Image:
+    try:
+        response = requests.get(image_url)
+        response.raise_for_status()  # Check if the request was successful
+        return PIL.Image.open(BytesIO(response.content))
+    except requests.exceptions.RequestException as e:
+        logger.error(f"Failed to get image from URL {image_url}: {e}")
+        return None
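The reworked `format_messages_for_gemini` flattens each message's chatml content into a parts list that Gemini accepts directly: image parts are fetched into `PIL.Image` objects and sorted to the front, text parts follow, and `assistant` roles are renamed to `model`. A self-contained sketch of that conversion, with a stand-in for the image download so it runs offline; the chatml content is example data:

```python
# Offline sketch of the content conversion added to format_messages_for_gemini.
# A placeholder string stands in for the PIL.Image that get_image_from_url
# would return in the real code.
content = [
    {"type": "text", "text": "Describe this image."},
    {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
]

def to_part(item: dict):
    # Real code calls get_image_from_url(item["image_url"]["url"]) -> PIL.Image
    if item["type"] == "image_url":
        return f"<PIL.Image fetched from {item['image_url']['url']}>"
    return item["text"]

# Images sort to the front, which the diff's comment notes works better for Gemini
parts = [to_part(item) for item in sorted(content, key=lambda x: 0 if x["type"] == "image_url" else 1)]
# parts == ["<PIL.Image fetched from https://example.com/cat.png>", "Describe this image."]
```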
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index 8d799745..789be3a5 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -152,7 +152,7 @@ def construct_structured_message(message: str, images: list[str], model_type: st
     if not images or not vision_enabled:
         return message
 
-    if model_type == ChatModelOptions.ModelType.OPENAI:
+    if model_type in [ChatModelOptions.ModelType.OPENAI, ChatModelOptions.ModelType.GOOGLE]:
         return [
             {"type": "text", "text": message},
             *[{"type": "image_url", "image_url": {"url": image}} for image in images],
diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index 075c8c47..33edd61f 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -452,12 +452,14 @@ async def extract_references_and_questions(
             chat_model = conversation_config.chat_model
             inferred_queries = extract_questions_gemini(
                 defiltered_query,
+                query_images=query_images,
                 model=chat_model,
                 api_key=api_key,
                 conversation_log=meta_log,
                 location_data=location_data,
                 max_tokens=conversation_config.max_prompt_size,
                 user=user,
+                vision_enabled=vision_enabled,
                 personality_context=personality_context,
             )
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index 7ed9c72d..739a3ad6 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -995,8 +995,9 @@ def generate_chat_response(
             chat_response = converse_gemini(
                 compiled_references,
                 q,
-                online_results,
-                meta_log,
+                query_images=query_images,
+                online_results=online_results,
+                conversation_log=meta_log,
                 model=conversation_config.chat_model,
                 api_key=api_key,
                 completion_func=partial_completion,
@@ -1006,6 +1007,7 @@
                 location_data=location_data,
                 user_name=user_name,
                 agent=agent,
+                vision_available=vision_available,
             )
             metadata.update({"chat_model": conversation_config.chat_model})
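End to end, `parts` is now a list rather than a single string, so the final `send_message` call can carry text and `PIL.Image` objects together. A rough sketch of the resulting call shape against `google.generativeai` (API key, model name, and image path are placeholders; this mirrors, but is not lifted from, `gemini_completion_with_backoff`):

```python
import google.generativeai as genai
import PIL.Image

# Placeholders: substitute a real API key, model name, and image file.
genai.configure(api_key="YOUR_API_KEY")
model = genai.GenerativeModel("gemini-1.5-flash")

image = PIL.Image.open("dog.png")  # in the PR, images arrive via get_image_from_url
chat_session = model.start_chat(history=[])

# The last message's parts list may now interleave PIL images and text.
response = chat_session.send_message([image, "What breed is this dog?"])
print(response.text)
```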