Mirror of https://github.com/khoj-ai/khoj.git (synced 2024-11-27 17:35:07 +01:00)

Commit 3e39fac455: Add vision support for Gemini models in Khoj
Parent commit: 0d6a54c10f
5 changed files with 67 additions and 27 deletions

@@ -13,7 +13,10 @@ from khoj.processor.conversation.google.utils import (
     gemini_chat_completion_with_backoff,
     gemini_completion_with_backoff,
 )
-from khoj.processor.conversation.utils import generate_chatml_messages_with_context
+from khoj.processor.conversation.utils import (
+    construct_structured_message,
+    generate_chatml_messages_with_context,
+)
 from khoj.utils.helpers import ConversationCommand, is_none_or_empty
 from khoj.utils.rawconfig import LocationData

@@ -29,6 +32,8 @@ def extract_questions_gemini(
     max_tokens=None,
     location_data: LocationData = None,
     user: KhojUser = None,
+    query_images: Optional[list[str]] = None,
+    vision_enabled: bool = False,
     personality_context: Optional[str] = None,
 ):
     """

@@ -70,17 +75,17 @@ def extract_questions_gemini(
         text=text,
     )

-    messages = [ChatMessage(content=prompt, role="user")]
+    prompt = construct_structured_message(
+        message=prompt,
+        images=query_images,
+        model_type=ChatModelOptions.ModelType.GOOGLE,
+        vision_enabled=vision_enabled,
+    )

-    model_kwargs = {"response_mime_type": "application/json"}
+    messages = [ChatMessage(content=prompt, role="user"), ChatMessage(content=system_prompt, role="system")]

-    response = gemini_completion_with_backoff(
-        messages=messages,
-        system_prompt=system_prompt,
-        model_name=model,
-        temperature=temperature,
-        api_key=api_key,
-        model_kwargs=model_kwargs,
+    response = gemini_send_message_to_model(
+        messages, api_key, model, response_type="json_object", temperature=temperature
     )

     # Extract, Clean Message from Gemini's Response

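Reviewer note: the structured prompt only changes shape when images are attached. As the construct_structured_message hunk further below shows, the helper returns the plain prompt string when images is empty or vision is disabled, so text-only question extraction behaves as before. A minimal sketch of that fallback, with a made-up query and import paths assumed from the repo layout:

    from khoj.database.models import ChatModelOptions
    from khoj.processor.conversation.utils import construct_structured_message

    # Hypothetical text-only call: no images and vision disabled, so the prompt passes through unchanged
    prompt = construct_structured_message(
        message="Where did I go for my last holiday?",
        images=None,
        model_type=ChatModelOptions.ModelType.GOOGLE,
        vision_enabled=False,
    )
    assert prompt == "Where did I go for my last holiday?"
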
@@ -102,7 +107,7 @@ def extract_questions_gemini(
     return questions


-def gemini_send_message_to_model(messages, api_key, model, response_type="text"):
+def gemini_send_message_to_model(messages, api_key, model, response_type="text", temperature=0, model_kwargs=None):
     """
     Send message to model
     """

@@ -114,7 +119,12 @@ def gemini_send_message_to_model(messages, api_key, model, response_type="text")

     # Get Response from Gemini
     return gemini_completion_with_backoff(
-        messages=messages, system_prompt=system_prompt, model_name=model, api_key=api_key, model_kwargs=model_kwargs
+        messages=messages,
+        system_prompt=system_prompt,
+        model_name=model,
+        api_key=api_key,
+        temperature=temperature,
+        model_kwargs=model_kwargs,
     )

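Reviewer note: the widened signature means temperature and any extra model_kwargs are now forwarded to gemini_completion_with_backoff, so callers such as extract_questions_gemini above can control them. A hedged usage sketch; the model name and temperature below are illustrative values, not taken from this commit:

    # Hypothetical call site mirroring extract_questions_gemini's new usage
    raw_json = gemini_send_message_to_model(
        messages,
        api_key,
        "gemini-1.5-flash",           # placeholder model name
        response_type="json_object",  # JSON output requested, as in extract_questions_gemini above
        temperature=0.2,              # now forwarded instead of being unavailable
    )
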
@@ -133,6 +143,8 @@ def converse_gemini(
     location_data: LocationData = None,
     user_name: str = None,
     agent: Agent = None,
+    query_images: Optional[list[str]] = None,
+    vision_available: bool = False,
 ):
     """
     Converse with user using Google's Gemini

@@ -187,6 +199,8 @@ def converse_gemini(
         model_name=model,
         max_prompt_size=max_prompt_size,
         tokenizer_name=tokenizer_name,
+        query_images=query_images,
+        vision_enabled=vision_available,
         model_type=ChatModelOptions.ModelType.GOOGLE,
     )

@@ -1,8 +1,11 @@
 import logging
 import random
+from io import BytesIO
 from threading import Thread

 import google.generativeai as genai
+import PIL.Image
+import requests
 from google.generativeai.types.answer_types import FinishReason
 from google.generativeai.types.generation_types import StopCandidateException
 from google.generativeai.types.safety_types import (

@@ -53,14 +56,14 @@ def gemini_completion_with_backoff(
         },
     )

-    formatted_messages = [{"role": message.role, "parts": [message.content]} for message in messages]
+    formatted_messages = [{"role": message.role, "parts": message.content} for message in messages]

     # Start chat session. All messages up to the last are considered to be part of the chat history
     chat_session = model.start_chat(history=formatted_messages[0:-1])

     try:
         # Generate the response. The last message is considered to be the current prompt
-        aggregated_response = chat_session.send_message(formatted_messages[-1]["parts"][0])
+        aggregated_response = chat_session.send_message(formatted_messages[-1]["parts"])
         return aggregated_response.text
     except StopCandidateException as e:
         response_message, _ = handle_gemini_response(e.args)

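Reviewer note: message.content is no longer a single string here. format_messages_for_gemini (changed later in this file) now returns each message's content as a list of parts, mixing text strings with PIL.Image objects, so the whole parts list is passed to the chat session instead of only its first element. A rough, illustrative sketch of the new shape; the image below is a blank placeholder rather than one fetched by get_image_from_url:

    import PIL.Image

    # Placeholder standing in for an image downloaded by get_image_from_url()
    image_part = PIL.Image.new("RGB", (64, 64))
    last_message = {"role": "user", "parts": [image_part, "What is shown in this image?"]}

    # chat_session.send_message(last_message["parts"]) now receives both parts,
    # where the old code sent only last_message["parts"][0].
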
@@ -117,11 +120,11 @@ def gemini_llm_thread(g, messages, system_prompt, model_name, temperature, api_k
             },
         )

-        formatted_messages = [{"role": message.role, "parts": [message.content]} for message in messages]
+        formatted_messages = [{"role": message.role, "parts": message.content} for message in messages]
         # all messages up to the last are considered to be part of the chat history
         chat_session = model.start_chat(history=formatted_messages[0:-1])
         # the last message is considered to be the current prompt
-        for chunk in chat_session.send_message(formatted_messages[-1]["parts"][0], stream=True):
+        for chunk in chat_session.send_message(formatted_messages[-1]["parts"], stream=True):
             message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
             message = message or chunk.text
             g.send(message)

@@ -191,14 +194,6 @@ def generate_safety_response(safety_ratings):


 def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str = None) -> tuple[list[str], str]:
-    if len(messages) == 1:
-        messages[0].role = "user"
-        return messages, system_prompt
-
-    for message in messages:
-        if message.role == "assistant":
-            message.role = "model"
-
     # Extract system message
     system_prompt = system_prompt or ""
     for message in messages.copy():

@@ -207,4 +202,31 @@ def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str =
             messages.remove(message)
     system_prompt = None if is_none_or_empty(system_prompt) else system_prompt

+    for message in messages:
+        # Convert message content to string list from chatml dictionary list
+        if isinstance(message.content, list):
+            # Convert image_urls to PIL.Image and place them at beginning of list (better for Gemini)
+            message.content = [
+                get_image_from_url(item["image_url"]["url"]) if item["type"] == "image_url" else item["text"]
+                for item in sorted(message.content, key=lambda x: 0 if x["type"] == "image_url" else 1)
+            ]
+        elif isinstance(message.content, str):
+            message.content = [message.content]
+
+        if message.role == "assistant":
+            message.role = "model"
+
+    if len(messages) == 1:
+        messages[0].role = "user"
+
     return messages, system_prompt
+
+
+def get_image_from_url(image_url: str) -> PIL.Image:
+    try:
+        response = requests.get(image_url)
+        response.raise_for_status()  # Check if the request was successful
+        return PIL.Image.open(BytesIO(response.content))
+    except requests.exceptions.RequestException as e:
+        logger.error(f"Failed to get image from URL {image_url}: {e}")
+        return None

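Reviewer note: this loop bridges the chatml-style content produced by construct_structured_message and the flat parts list Gemini expects. A small sketch of the conversion, with placeholder text and URL (not values from the commit):

    # Input: chatml-style dictionaries on message.content
    content = [
        {"type": "text", "text": "Describe this picture"},
        {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
    ]
    # After the loop, message.content becomes (conceptually):
    #   [<PIL.Image loaded from cat.png>, "Describe this picture"]
    # Images are fetched via get_image_from_url and sorted to the front, plain strings
    # are wrapped into a one-element list, and a failed download yields None in the list.
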
@@ -152,7 +152,7 @@ def construct_structured_message(message: str, images: list[str], model_type: st
     if not images or not vision_enabled:
         return message

-    if model_type == ChatModelOptions.ModelType.OPENAI:
+    if model_type in [ChatModelOptions.ModelType.OPENAI, ChatModelOptions.ModelType.GOOGLE]:
         return [
             {"type": "text", "text": message},
             *[{"type": "image_url", "image_url": {"url": image}} for image in images],

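Reviewer note: this is the producer side of the parts conversion sketched above; Google models now receive the same chatml-style structure already built for OpenAI. For a single attached image the helper returns roughly the following (query text and URL are placeholders):

    structured_message = [
        {"type": "text", "text": "What breed is this dog?"},
        {"type": "image_url", "image_url": {"url": "https://example.com/dog.jpg"}},
    ]
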
@@ -452,12 +452,14 @@ async def extract_references_and_questions(
             chat_model = conversation_config.chat_model
             inferred_queries = extract_questions_gemini(
                 defiltered_query,
+                query_images=query_images,
                 model=chat_model,
                 api_key=api_key,
                 conversation_log=meta_log,
                 location_data=location_data,
                 max_tokens=conversation_config.max_prompt_size,
                 user=user,
+                vision_enabled=vision_enabled,
                 personality_context=personality_context,
             )

@@ -995,8 +995,9 @@ def generate_chat_response(
             chat_response = converse_gemini(
                 compiled_references,
                 q,
-                online_results,
-                meta_log,
+                query_images=query_images,
+                online_results=online_results,
+                conversation_log=meta_log,
                 model=conversation_config.chat_model,
                 api_key=api_key,
                 completion_func=partial_completion,

@@ -1006,6 +1007,7 @@ def generate_chat_response(
                 location_data=location_data,
                 user_name=user_name,
                 agent=agent,
+                vision_available=vision_available,
             )

         metadata.update({"chat_model": conversation_config.chat_model})