Add vision support for Gemini models in Khoj

Debanjum Singh Solanky 2024-10-18 19:13:06 -07:00
parent 0d6a54c10f
commit 3e39fac455
5 changed files with 67 additions and 27 deletions


@@ -13,7 +13,10 @@ from khoj.processor.conversation.google.utils import (
gemini_chat_completion_with_backoff,
gemini_completion_with_backoff,
)
from khoj.processor.conversation.utils import generate_chatml_messages_with_context
from khoj.processor.conversation.utils import (
construct_structured_message,
generate_chatml_messages_with_context,
)
from khoj.utils.helpers import ConversationCommand, is_none_or_empty
from khoj.utils.rawconfig import LocationData
@@ -29,6 +32,8 @@ def extract_questions_gemini(
max_tokens=None,
location_data: LocationData = None,
user: KhojUser = None,
query_images: Optional[list[str]] = None,
vision_enabled: bool = False,
personality_context: Optional[str] = None,
):
"""
@@ -70,17 +75,17 @@ def extract_questions_gemini(
text=text,
)
messages = [ChatMessage(content=prompt, role="user")]
prompt = construct_structured_message(
message=prompt,
images=query_images,
model_type=ChatModelOptions.ModelType.GOOGLE,
vision_enabled=vision_enabled,
)
model_kwargs = {"response_mime_type": "application/json"}
messages = [ChatMessage(content=prompt, role="user"), ChatMessage(content=system_prompt, role="system")]
response = gemini_completion_with_backoff(
messages=messages,
system_prompt=system_prompt,
model_name=model,
temperature=temperature,
api_key=api_key,
model_kwargs=model_kwargs,
response = gemini_send_message_to_model(
messages, api_key, model, response_type="json_object", temperature=temperature
)
# Extract, Clean Message from Gemini's Response
@@ -102,7 +107,7 @@ def extract_questions_gemini(
return questions
def gemini_send_message_to_model(messages, api_key, model, response_type="text"):
def gemini_send_message_to_model(messages, api_key, model, response_type="text", temperature=0, model_kwargs=None):
"""
Send message to model
"""
@@ -114,7 +119,12 @@ def gemini_send_message_to_model(messages, api_key, model, response_type="text")
# Get Response from Gemini
return gemini_completion_with_backoff(
messages=messages, system_prompt=system_prompt, model_name=model, api_key=api_key, model_kwargs=model_kwargs
messages=messages,
system_prompt=system_prompt,
model_name=model,
api_key=api_key,
temperature=temperature,
model_kwargs=model_kwargs,
)
@@ -133,6 +143,8 @@ def converse_gemini(
location_data: LocationData = None,
user_name: str = None,
agent: Agent = None,
query_images: Optional[list[str]] = None,
vision_available: bool = False,
):
"""
Converse with user using Google's Gemini
@@ -187,6 +199,8 @@ def converse_gemini(
model_name=model,
max_prompt_size=max_prompt_size,
tokenizer_name=tokenizer_name,
query_images=query_images,
vision_enabled=vision_available,
model_type=ChatModelOptions.ModelType.GOOGLE,
)
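For orientation, here is a minimal sketch of how the extended extract_questions_gemini signature can now be invoked with images attached. The call is hypothetical and not part of this commit: query_images and vision_enabled are the parameters added above, while the query text, image URL, model name, and API key are placeholders.

# Hypothetical call exercising the new vision parameters; all values are placeholders.
inferred_queries = extract_questions_gemini(
    "What does the whiteboard in this photo say?",
    model="gemini-1.5-flash",
    api_key="YOUR_GEMINI_API_KEY",
    query_images=["https://example.com/whiteboard.jpg"],
    vision_enabled=True,
)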


@@ -1,8 +1,11 @@
import logging
import random
from io import BytesIO
from threading import Thread
import google.generativeai as genai
import PIL.Image
import requests
from google.generativeai.types.answer_types import FinishReason
from google.generativeai.types.generation_types import StopCandidateException
from google.generativeai.types.safety_types import (
@@ -53,14 +56,14 @@ def gemini_completion_with_backoff(
},
)
formatted_messages = [{"role": message.role, "parts": [message.content]} for message in messages]
formatted_messages = [{"role": message.role, "parts": message.content} for message in messages]
# Start chat session. All messages up to the last are considered to be part of the chat history
chat_session = model.start_chat(history=formatted_messages[0:-1])
try:
# Generate the response. The last message is considered to be the current prompt
aggregated_response = chat_session.send_message(formatted_messages[-1]["parts"][0])
aggregated_response = chat_session.send_message(formatted_messages[-1]["parts"])
return aggregated_response.text
except StopCandidateException as e:
response_message, _ = handle_gemini_response(e.args)
@@ -117,11 +120,11 @@ def gemini_llm_thread(g, messages, system_prompt, model_name, temperature, api_k
},
)
formatted_messages = [{"role": message.role, "parts": [message.content]} for message in messages]
formatted_messages = [{"role": message.role, "parts": message.content} for message in messages]
# all messages up to the last are considered to be part of the chat history
chat_session = model.start_chat(history=formatted_messages[0:-1])
# the last message is considered to be the current prompt
for chunk in chat_session.send_message(formatted_messages[-1]["parts"][0], stream=True):
for chunk in chat_session.send_message(formatted_messages[-1]["parts"], stream=True):
message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
message = message or chunk.text
g.send(message)
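The two hunks above stop wrapping message.content in an extra list and instead pass it straight through as the Gemini "parts", so a part can now be a PIL image as well as a string. Below is a self-contained sketch of that multi-part pattern using the google-generativeai SDK already imported in this file; the model name, image file, and prompt are placeholder assumptions, not values from the commit.

# Standalone sketch of a multi-part Gemini chat turn (one PIL image plus text),
# mirroring the new formatted_messages "parts" structure. Placeholder values only.
import google.generativeai as genai
import PIL.Image

genai.configure(api_key="YOUR_GEMINI_API_KEY")
model = genai.GenerativeModel("gemini-1.5-flash")

chat_session = model.start_chat(history=[])
image = PIL.Image.open("receipt.png")
response = chat_session.send_message([image, "What is the total on this receipt?"])
print(response.text)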
@@ -191,14 +194,6 @@ def generate_safety_response(safety_ratings):
def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str = None) -> tuple[list[str], str]:
if len(messages) == 1:
messages[0].role = "user"
return messages, system_prompt
for message in messages:
if message.role == "assistant":
message.role = "model"
# Extract system message
system_prompt = system_prompt or ""
for message in messages.copy():
@@ -207,4 +202,31 @@ def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str =
messages.remove(message)
system_prompt = None if is_none_or_empty(system_prompt) else system_prompt
for message in messages:
# Convert message content to string list from chatml dictionary list
if isinstance(message.content, list):
# Convert image_urls to PIL.Image and place them at beginning of list (better for Gemini)
message.content = [
get_image_from_url(item["image_url"]["url"]) if item["type"] == "image_url" else item["text"]
for item in sorted(message.content, key=lambda x: 0 if x["type"] == "image_url" else 1)
]
elif isinstance(message.content, str):
message.content = [message.content]
if message.role == "assistant":
message.role = "model"
if len(messages) == 1:
messages[0].role = "user"
return messages, system_prompt
def get_image_from_url(image_url: str) -> PIL.Image:
try:
response = requests.get(image_url)
response.raise_for_status() # Check if the request was successful
return PIL.Image.open(BytesIO(response.content))
except requests.exceptions.RequestException as e:
logger.error(f"Failed to get image from URL {image_url}: {e}")
return None
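Taken together, the reworked loop and get_image_from_url turn chatml-style content into the plain part lists Gemini expects: image URLs are fetched into PIL images and moved to the front, bare strings become single-element lists, and assistant turns are relabelled "model". A rough walk-through with invented inputs, assuming ChatMessage is the same message class used in the hunks above:

# Hypothetical walk-through of format_messages_for_gemini; inputs are invented.
messages = [
    ChatMessage(role="user", content=[
        {"type": "text", "text": "What is in this photo?"},
        {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
    ]),
    ChatMessage(role="assistant", content="It looks like a cat."),
]
formatted, system = format_messages_for_gemini(messages, system_prompt="You are Khoj.")
# formatted[0].content -> [<PIL.Image.Image of cat.png>, "What is in this photo?"]
# formatted[1].role    -> "model", content -> ["It looks like a cat."]
# system               -> "You are Khoj."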


@@ -152,7 +152,7 @@ def construct_structured_message(message: str, images: list[str], model_type: st
if not images or not vision_enabled:
return message
if model_type == ChatModelOptions.ModelType.OPENAI:
if model_type in [ChatModelOptions.ModelType.OPENAI, ChatModelOptions.ModelType.GOOGLE]:
return [
{"type": "text", "text": message},
*[{"type": "image_url", "image_url": {"url": image}} for image in images],


@@ -452,12 +452,14 @@ async def extract_references_and_questions(
chat_model = conversation_config.chat_model
inferred_queries = extract_questions_gemini(
defiltered_query,
query_images=query_images,
model=chat_model,
api_key=api_key,
conversation_log=meta_log,
location_data=location_data,
max_tokens=conversation_config.max_prompt_size,
user=user,
vision_enabled=vision_enabled,
personality_context=personality_context,
)


@@ -995,8 +995,9 @@ def generate_chat_response(
chat_response = converse_gemini(
compiled_references,
q,
online_results,
meta_log,
query_images=query_images,
online_results=online_results,
conversation_log=meta_log,
model=conversation_config.chat_model,
api_key=api_key,
completion_func=partial_completion,
@@ -1006,6 +1007,7 @@ def generate_chat_response(
location_data=location_data,
user_name=user_name,
agent=agent,
vision_available=vision_available,
)
metadata.update({"chat_model": conversation_config.chat_model})