Add vision support for Gemini models in Khoj
parent 0d6a54c10f
commit 3e39fac455
5 changed files with 67 additions and 27 deletions
@@ -13,7 +13,10 @@ from khoj.processor.conversation.google.utils import (
     gemini_chat_completion_with_backoff,
     gemini_completion_with_backoff,
 )
-from khoj.processor.conversation.utils import generate_chatml_messages_with_context
+from khoj.processor.conversation.utils import (
+    construct_structured_message,
+    generate_chatml_messages_with_context,
+)
 from khoj.utils.helpers import ConversationCommand, is_none_or_empty
 from khoj.utils.rawconfig import LocationData

@@ -29,6 +32,8 @@ def extract_questions_gemini(
     max_tokens=None,
     location_data: LocationData = None,
     user: KhojUser = None,
+    query_images: Optional[list[str]] = None,
+    vision_enabled: bool = False,
     personality_context: Optional[str] = None,
 ):
     """

@@ -70,17 +75,17 @@ def extract_questions_gemini(
         text=text,
     )
 
-    messages = [ChatMessage(content=prompt, role="user")]
+    prompt = construct_structured_message(
+        message=prompt,
+        images=query_images,
+        model_type=ChatModelOptions.ModelType.GOOGLE,
+        vision_enabled=vision_enabled,
+    )
 
-    model_kwargs = {"response_mime_type": "application/json"}
-    response = gemini_completion_with_backoff(
-        messages=messages,
-        system_prompt=system_prompt,
-        model_name=model,
-        temperature=temperature,
-        api_key=api_key,
-        model_kwargs=model_kwargs,
+    messages = [ChatMessage(content=prompt, role="user"), ChatMessage(content=system_prompt, role="system")]
+
+    response = gemini_send_message_to_model(
+        messages, api_key, model, response_type="json_object", temperature=temperature
     )
 
     # Extract, Clean Message from Gemini's Response
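
With this change, extract_questions_gemini packs the rendered prompt and any query images into a single structured user message and carries the system prompt along as a second chat message, leaving the JSON-mode plumbing to gemini_send_message_to_model. A rough sketch of the resulting messages value, assuming langchain's ChatMessage (which these khoj conversation modules appear to use) and placeholder text and image URL:

# Illustrative sketch only; the ChatMessage import and all values are assumptions.
from langchain.schema import ChatMessage

structured_prompt = [
    {"type": "text", "text": "Chat history and current query rendered as text"},
    {"type": "image_url", "image_url": {"url": "https://example.com/photo.png"}},  # placeholder
]
system_prompt = "Return the inferred search questions as a JSON list."  # placeholder

messages = [
    ChatMessage(content=structured_prompt, role="user"),
    ChatMessage(content=system_prompt, role="system"),
]
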
@@ -102,7 +107,7 @@ def extract_questions_gemini(
     return questions
 
 
-def gemini_send_message_to_model(messages, api_key, model, response_type="text"):
+def gemini_send_message_to_model(messages, api_key, model, response_type="text", temperature=0, model_kwargs=None):
     """
     Send message to model
     """

@@ -114,7 +119,12 @@ def gemini_send_message_to_model(messages, api_key, model, response_type="text")
 
     # Get Response from Gemini
     return gemini_completion_with_backoff(
-        messages=messages, system_prompt=system_prompt, model_name=model, api_key=api_key, model_kwargs=model_kwargs
+        messages=messages,
+        system_prompt=system_prompt,
+        model_name=model,
+        api_key=api_key,
+        temperature=temperature,
+        model_kwargs=model_kwargs,
     )
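
The helper now accepts temperature and model_kwargs and forwards them to gemini_completion_with_backoff, so callers like extract_questions_gemini no longer build the response-mime-type kwargs inline. A hedged call sketch, reusing the messages list from the sketch above; the import path and model name are assumptions, not taken from this diff:

# Assumed module path; adjust to wherever gemini_send_message_to_model lives in khoj.
from khoj.processor.conversation.google.gemini_chat import gemini_send_message_to_model

response = gemini_send_message_to_model(
    messages,                     # list[ChatMessage]; system messages are folded into the system prompt
    api_key="GEMINI_API_KEY",     # placeholder
    model="gemini-1.5-flash",     # placeholder model name
    response_type="json_object",  # stands in for the old inline {"response_mime_type": "application/json"} kwargs
    temperature=0.2,
)
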
@@ -133,6 +143,8 @@ def converse_gemini(
     location_data: LocationData = None,
     user_name: str = None,
     agent: Agent = None,
+    query_images: Optional[list[str]] = None,
+    vision_available: bool = False,
 ):
     """
     Converse with user using Google's Gemini

@@ -187,6 +199,8 @@ def converse_gemini(
         model_name=model,
         max_prompt_size=max_prompt_size,
         tokenizer_name=tokenizer_name,
+        query_images=query_images,
+        vision_enabled=vision_available,
         model_type=ChatModelOptions.ModelType.GOOGLE,
     )

@@ -1,8 +1,11 @@
 import logging
 import random
+from io import BytesIO
 from threading import Thread
 
 import google.generativeai as genai
+import PIL.Image
+import requests
 from google.generativeai.types.answer_types import FinishReason
 from google.generativeai.types.generation_types import StopCandidateException
 from google.generativeai.types.safety_types import (

@@ -53,14 +56,14 @@ def gemini_completion_with_backoff(
         },
     )
 
-    formatted_messages = [{"role": message.role, "parts": [message.content]} for message in messages]
+    formatted_messages = [{"role": message.role, "parts": message.content} for message in messages]
 
     # Start chat session. All messages up to the last are considered to be part of the chat history
     chat_session = model.start_chat(history=formatted_messages[0:-1])
 
     try:
         # Generate the response. The last message is considered to be the current prompt
-        aggregated_response = chat_session.send_message(formatted_messages[-1]["parts"][0])
+        aggregated_response = chat_session.send_message(formatted_messages[-1]["parts"])
         return aggregated_response.text
     except StopCandidateException as e:
         response_message, _ = handle_gemini_response(e.args)
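
Because format_messages_for_gemini (changed below) now normalizes each message's content into a list of text strings and PIL images, that list maps directly onto the parts of a google.generativeai chat message, and the last message's whole parts list, rather than just its first element, is sent as the prompt. A small sketch of the resulting structure, using an in-memory placeholder image:

# Sketch of formatted_messages after this change; values are illustrative.
import PIL.Image

placeholder_image = PIL.Image.new("RGB", (4, 4))  # stands in for a fetched query image

formatted_messages = [
    {"role": "user", "parts": ["Earlier question from the chat history"]},
    {"role": "model", "parts": ["Earlier answer from the chat history"]},
    {"role": "user", "parts": [placeholder_image, "What is shown in this picture?"]},  # images come first
]
# history = formatted_messages[0:-1]; the final parts list is sent as the current prompt.
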
@@ -117,11 +120,11 @@ def gemini_llm_thread(g, messages, system_prompt, model_name, temperature, api_k
         },
     )
 
-    formatted_messages = [{"role": message.role, "parts": [message.content]} for message in messages]
+    formatted_messages = [{"role": message.role, "parts": message.content} for message in messages]
     # all messages up to the last are considered to be part of the chat history
     chat_session = model.start_chat(history=formatted_messages[0:-1])
     # the last message is considered to be the current prompt
-    for chunk in chat_session.send_message(formatted_messages[-1]["parts"][0], stream=True):
+    for chunk in chat_session.send_message(formatted_messages[-1]["parts"], stream=True):
         message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
         message = message or chunk.text
         g.send(message)

@@ -191,14 +194,6 @@ def generate_safety_response(safety_ratings):
 
 
 def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str = None) -> tuple[list[str], str]:
-    if len(messages) == 1:
-        messages[0].role = "user"
-        return messages, system_prompt
-
-    for message in messages:
-        if message.role == "assistant":
-            message.role = "model"
-
     # Extract system message
     system_prompt = system_prompt or ""
     for message in messages.copy():

@@ -207,4 +202,31 @@ def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str =
             messages.remove(message)
     system_prompt = None if is_none_or_empty(system_prompt) else system_prompt
 
+    for message in messages:
+        # Convert message content to string list from chatml dictionary list
+        if isinstance(message.content, list):
+            # Convert image_urls to PIL.Image and place them at beginning of list (better for Gemini)
+            message.content = [
+                get_image_from_url(item["image_url"]["url"]) if item["type"] == "image_url" else item["text"]
+                for item in sorted(message.content, key=lambda x: 0 if x["type"] == "image_url" else 1)
+            ]
+        elif isinstance(message.content, str):
+            message.content = [message.content]
+
+        if message.role == "assistant":
+            message.role = "model"
+
+    if len(messages) == 1:
+        messages[0].role = "user"
+
     return messages, system_prompt
+
+
+def get_image_from_url(image_url: str) -> PIL.Image:
+    try:
+        response = requests.get(image_url)
+        response.raise_for_status()  # Check if the request was successful
+        return PIL.Image.open(BytesIO(response.content))
+    except requests.exceptions.RequestException as e:
+        logger.error(f"Failed to get image from URL {image_url}: {e}")
+        return None
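
The new conversion loop in format_messages_for_gemini flattens chatml-style content (text and image_url parts) into the list of strings and PIL images that Gemini expects, placing images before text and renaming the assistant role to model. A standalone sketch of that transformation, with the HTTP fetch done by get_image_from_url replaced by an in-memory placeholder so the example needs no network access:

import PIL.Image

def fake_get_image_from_url(image_url: str) -> PIL.Image.Image:
    # Stand-in for get_image_from_url; avoids a real requests.get in this sketch.
    return PIL.Image.new("RGB", (4, 4))

content = [
    {"type": "text", "text": "What is in this image?"},
    {"type": "image_url", "image_url": {"url": "https://example.com/photo.png"}},  # placeholder URL
]

# Same ordering and conversion as the loop above: images first, then text.
converted = [
    fake_get_image_from_url(item["image_url"]["url"]) if item["type"] == "image_url" else item["text"]
    for item in sorted(content, key=lambda x: 0 if x["type"] == "image_url" else 1)
]
# converted == [<PIL.Image.Image object>, "What is in this image?"]
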
@@ -152,7 +152,7 @@ def construct_structured_message(message: str, images: list[str], model_type: st
     if not images or not vision_enabled:
         return message
 
-    if model_type == ChatModelOptions.ModelType.OPENAI:
+    if model_type in [ChatModelOptions.ModelType.OPENAI, ChatModelOptions.ModelType.GOOGLE]:
         return [
             {"type": "text", "text": message},
             *[{"type": "image_url", "image_url": {"url": image}} for image in images],
@@ -452,12 +452,14 @@ async def extract_references_and_questions(
             chat_model = conversation_config.chat_model
             inferred_queries = extract_questions_gemini(
                 defiltered_query,
+                query_images=query_images,
                 model=chat_model,
                 api_key=api_key,
                 conversation_log=meta_log,
                 location_data=location_data,
                 max_tokens=conversation_config.max_prompt_size,
                 user=user,
+                vision_enabled=vision_enabled,
                 personality_context=personality_context,
             )

@@ -995,8 +995,9 @@ def generate_chat_response(
             chat_response = converse_gemini(
                 compiled_references,
                 q,
-                online_results,
-                meta_log,
+                query_images=query_images,
+                online_results=online_results,
+                conversation_log=meta_log,
                 model=conversation_config.chat_model,
                 api_key=api_key,
                 completion_func=partial_completion,

@@ -1006,6 +1007,7 @@ def generate_chat_response(
                 location_data=location_data,
                 user_name=user_name,
                 agent=agent,
+                vision_available=vision_available,
             )
 
             metadata.update({"chat_model": conversation_config.chat_model})