Add vision support for Gemini models in Khoj

commit 3e39fac455
parent 0d6a54c10f
Author: Debanjum Singh Solanky
Date: 2024-10-18 19:13:06 -07:00

5 changed files with 67 additions and 27 deletions

File 1 of 5

@@ -13,7 +13,10 @@ from khoj.processor.conversation.google.utils import (
     gemini_chat_completion_with_backoff,
     gemini_completion_with_backoff,
 )
-from khoj.processor.conversation.utils import generate_chatml_messages_with_context
+from khoj.processor.conversation.utils import (
+    construct_structured_message,
+    generate_chatml_messages_with_context,
+)
 from khoj.utils.helpers import ConversationCommand, is_none_or_empty
 from khoj.utils.rawconfig import LocationData
@@ -29,6 +32,8 @@ def extract_questions_gemini(
     max_tokens=None,
     location_data: LocationData = None,
     user: KhojUser = None,
+    query_images: Optional[list[str]] = None,
+    vision_enabled: bool = False,
     personality_context: Optional[str] = None,
 ):
     """
@@ -70,17 +75,17 @@ def extract_questions_gemini(
         text=text,
     )
 
-    messages = [ChatMessage(content=prompt, role="user")]
+    prompt = construct_structured_message(
+        message=prompt,
+        images=query_images,
+        model_type=ChatModelOptions.ModelType.GOOGLE,
+        vision_enabled=vision_enabled,
+    )
 
-    model_kwargs = {"response_mime_type": "application/json"}
+    messages = [ChatMessage(content=prompt, role="user"), ChatMessage(content=system_prompt, role="system")]
 
-    response = gemini_completion_with_backoff(
-        messages=messages,
-        system_prompt=system_prompt,
-        model_name=model,
-        temperature=temperature,
-        api_key=api_key,
-        model_kwargs=model_kwargs,
+    response = gemini_send_message_to_model(
+        messages, api_key, model, response_type="json_object", temperature=temperature
     )
 
     # Extract, Clean Message from Gemini's Response
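Note: construct_structured_message returns the prompt unchanged when vision is disabled; with vision enabled it wraps the text and image URLs in an OpenAI-style content list, as the conversation/utils.py hunk further down shows. A minimal sketch of the resulting prompt, using a hypothetical image URL:

    prompt = construct_structured_message(
        message="What is in this image?",
        images=["https://example.com/photo.png"],  # hypothetical URL
        model_type=ChatModelOptions.ModelType.GOOGLE,
        vision_enabled=True,
    )
    # prompt is now:
    # [
    #     {"type": "text", "text": "What is in this image?"},
    #     {"type": "image_url", "image_url": {"url": "https://example.com/photo.png"}},
    # ]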
@@ -102,7 +107,7 @@ def extract_questions_gemini(
     return questions
 
 
-def gemini_send_message_to_model(messages, api_key, model, response_type="text"):
+def gemini_send_message_to_model(messages, api_key, model, response_type="text", temperature=0, model_kwargs=None):
     """
     Send message to model
     """
@@ -114,7 +119,12 @@ def gemini_send_message_to_model(messages, api_key, model, response_type="text")
 
     # Get Response from Gemini
     return gemini_completion_with_backoff(
-        messages=messages, system_prompt=system_prompt, model_name=model, api_key=api_key, model_kwargs=model_kwargs
+        messages=messages,
+        system_prompt=system_prompt,
+        model_name=model,
+        api_key=api_key,
+        temperature=temperature,
+        model_kwargs=model_kwargs,
     )
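The removed call in the earlier hunk passed model_kwargs={"response_mime_type": "application/json"} to Gemini directly; gemini_send_message_to_model now owns that mapping behind its response_type parameter. Roughly, since the relevant function body is elided from this hunk:

    # Hedged sketch of the response_type handling inside
    # gemini_send_message_to_model (not shown in this diff):
    model_kwargs = model_kwargs or dict()
    if response_type == "json_object":
        model_kwargs["response_mime_type"] = "application/json"  # Gemini's JSON mode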
@@ -133,6 +143,8 @@ def converse_gemini(
     location_data: LocationData = None,
     user_name: str = None,
     agent: Agent = None,
+    query_images: Optional[list[str]] = None,
+    vision_available: bool = False,
 ):
     """
     Converse with user using Google's Gemini
@@ -187,6 +199,8 @@ def converse_gemini(
         model_name=model,
         max_prompt_size=max_prompt_size,
         tokenizer_name=tokenizer_name,
+        query_images=query_images,
+        vision_enabled=vision_available,
         model_type=ChatModelOptions.ModelType.GOOGLE,
     )
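generate_chatml_messages_with_context (not shown in this diff) is what ultimately consumes query_images and vision_enabled. Presumably it routes the current user message through construct_structured_message, along these lines:

    # Hedged sketch of the relevant step inside
    # generate_chatml_messages_with_context (not part of this diff):
    content = construct_structured_message(
        message=user_message,
        images=query_images,
        model_type=ChatModelOptions.ModelType.GOOGLE,
        vision_enabled=vision_enabled,
    )
    messages.append(ChatMessage(content=content, role="user"))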

File 2 of 5

@@ -1,8 +1,11 @@
 import logging
 import random
+from io import BytesIO
 from threading import Thread
 
 import google.generativeai as genai
+import PIL.Image
+import requests
 from google.generativeai.types.answer_types import FinishReason
 from google.generativeai.types.generation_types import StopCandidateException
 from google.generativeai.types.safety_types import (
@@ -53,14 +56,14 @@ def gemini_completion_with_backoff(
         },
     )
 
-    formatted_messages = [{"role": message.role, "parts": [message.content]} for message in messages]
+    formatted_messages = [{"role": message.role, "parts": message.content} for message in messages]
 
     # Start chat session. All messages up to the last are considered to be part of the chat history
     chat_session = model.start_chat(history=formatted_messages[0:-1])
 
     try:
         # Generate the response. The last message is considered to be the current prompt
-        aggregated_response = chat_session.send_message(formatted_messages[-1]["parts"][0])
+        aggregated_response = chat_session.send_message(formatted_messages[-1]["parts"])
         return aggregated_response.text
     except StopCandidateException as e:
         response_message, _ = handle_gemini_response(e.args)
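The [message.content] wrapper is dropped because format_messages_for_gemini (changed further down) now normalizes every message.content into a list, and the google.generativeai SDK accepts a mixed list of strings and PIL images as a message's parts. So the prompt handed to send_message can now look like:

    # Shape of formatted_messages[-1]["parts"] after this change,
    # with a hypothetical local image standing in for a fetched one:
    parts = [
        PIL.Image.open("photo.png"),  # images are sorted to the front
        "What is in this image?",
    ]
    # chat_session.send_message(parts) accepts the mixed list natively.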
@@ -117,11 +120,11 @@ def gemini_llm_thread(g, messages, system_prompt, model_name, temperature, api_k
         },
     )
 
-    formatted_messages = [{"role": message.role, "parts": [message.content]} for message in messages]
+    formatted_messages = [{"role": message.role, "parts": message.content} for message in messages]
     # all messages up to the last are considered to be part of the chat history
     chat_session = model.start_chat(history=formatted_messages[0:-1])
     # the last message is considered to be the current prompt
-    for chunk in chat_session.send_message(formatted_messages[-1]["parts"][0], stream=True):
+    for chunk in chat_session.send_message(formatted_messages[-1]["parts"], stream=True):
         message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
         message = message or chunk.text
         g.send(message)
@@ -191,14 +194,6 @@ def generate_safety_response(safety_ratings):
 
 
 def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str = None) -> tuple[list[str], str]:
-    if len(messages) == 1:
-        messages[0].role = "user"
-        return messages, system_prompt
-
-    for message in messages:
-        if message.role == "assistant":
-            message.role = "model"
-
     # Extract system message
     system_prompt = system_prompt or ""
     for message in messages.copy():
@@ -207,4 +202,31 @@ def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str =
             messages.remove(message)
     system_prompt = None if is_none_or_empty(system_prompt) else system_prompt
 
+    for message in messages:
+        # Convert message content to string list from chatml dictionary list
+        if isinstance(message.content, list):
+            # Convert image_urls to PIL.Image and place them at beginning of list (better for Gemini)
+            message.content = [
+                get_image_from_url(item["image_url"]["url"]) if item["type"] == "image_url" else item["text"]
+                for item in sorted(message.content, key=lambda x: 0 if x["type"] == "image_url" else 1)
+            ]
+        elif isinstance(message.content, str):
+            message.content = [message.content]
+
+        if message.role == "assistant":
+            message.role = "model"
+
+    if len(messages) == 1:
+        messages[0].role = "user"
+
     return messages, system_prompt
+
+
+def get_image_from_url(image_url: str) -> PIL.Image:
+    try:
+        response = requests.get(image_url)
+        response.raise_for_status()  # Check if the request was successful
+        return PIL.Image.open(BytesIO(response.content))
+    except requests.exceptions.RequestException as e:
+        logger.error(f"Failed to get image from URL {image_url}: {e}")
+        return None
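Taken together, a minimal sketch of the new normalization on a single vision message (image URL hypothetical):

    msg = ChatMessage(
        role="assistant",
        content=[
            {"type": "text", "text": "Describe this."},
            {"type": "image_url", "image_url": {"url": "https://example.com/a.png"}},
        ],
    )
    messages, system = format_messages_for_gemini([msg], system_prompt=None)
    # msg.content is now [<PIL.Image>, "Describe this."]: images are fetched,
    # converted and sorted first; the role flips "assistant" -> "model", then
    # to "user", since the sole remaining message must open the Gemini chat.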

File 3 of 5

@@ -152,7 +152,7 @@ def construct_structured_message(message: str, images: list[str], model_type: st
     if not images or not vision_enabled:
         return message
 
-    if model_type == ChatModelOptions.ModelType.OPENAI:
+    if model_type in [ChatModelOptions.ModelType.OPENAI, ChatModelOptions.ModelType.GOOGLE]:
         return [
             {"type": "text", "text": message},
             *[{"type": "image_url", "image_url": {"url": image}} for image in images],

File 4 of 5

@@ -452,12 +452,14 @@ async def extract_references_and_questions(
         chat_model = conversation_config.chat_model
         inferred_queries = extract_questions_gemini(
             defiltered_query,
+            query_images=query_images,
             model=chat_model,
             api_key=api_key,
             conversation_log=meta_log,
             location_data=location_data,
             max_tokens=conversation_config.max_prompt_size,
             user=user,
+            vision_enabled=vision_enabled,
            personality_context=personality_context,
         )

File 5 of 5

@@ -995,8 +995,9 @@ def generate_chat_response(
         chat_response = converse_gemini(
             compiled_references,
             q,
-            online_results,
-            meta_log,
+            query_images=query_images,
+            online_results=online_results,
+            conversation_log=meta_log,
             model=conversation_config.chat_model,
             api_key=api_key,
             completion_func=partial_completion,
@@ -1006,6 +1007,7 @@ def generate_chat_response(
             location_data=location_data,
             user_name=user_name,
             agent=agent,
+            vision_available=vision_available,
         )

     metadata.update({"chat_model": conversation_config.chat_model})