Reuse a single function to format conversations for Gemini

This deduplicates code and prevents the logic from deviating across Gemini
chat actors.
This commit is contained in:
Debanjum Singh Solanky 2024-10-06 15:42:42 -07:00
parent 452e360175
commit 2b8f7f3efb
2 changed files with 25 additions and 17 deletions

View file

@ -9,6 +9,7 @@ from langchain.schema import ChatMessage
from khoj.database.models import Agent, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.google.utils import (
format_messages_for_gemini,
gemini_chat_completion_with_backoff,
gemini_completion_with_backoff,
)
@ -105,15 +106,7 @@ def gemini_send_message_to_model(messages, api_key, model, response_type="text")
"""
Send message to model
"""
system_prompt = None
if len(messages) == 1:
messages[0].role = "user"
else:
system_prompt = ""
for message in messages.copy():
if message.role == "system":
system_prompt += message.content
messages.remove(message)
messages, system_prompt = format_messages_for_gemini(messages)
model_kwargs = {}
if response_type == "json_object":
@ -195,14 +188,7 @@ def converse_gemini(
tokenizer_name=tokenizer_name,
)
for message in messages:
if message.role == "assistant":
message.role = "model"
for message in messages.copy():
if message.role == "system":
system_prompt += message.content
messages.remove(message)
messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
truncated_messages = "\n".join({f"{message.content[:40]}..." for message in messages})
logger.debug(f"Conversation Context for Gemini: {truncated_messages}")

View file

@ -10,6 +10,7 @@ from google.generativeai.types.safety_types import (
HarmCategory,
HarmProbability,
)
from langchain.schema import ChatMessage
from tenacity import (
before_sleep_log,
retry,
@ -19,6 +20,7 @@ from tenacity import (
)
from khoj.processor.conversation.utils import ThreadedGenerator
from khoj.utils.helpers import is_none_or_empty
logger = logging.getLogger(__name__)
@ -182,3 +184,23 @@ def generate_safety_response(safety_ratings):
return safety_response_choice.format(
category=max_safety_category, probability=max_safety_rating.probability.name, discomfort_level=discomfort_level
)
def format_messages_for_gemini(
    messages: list[ChatMessage], system_prompt: str | None = None
) -> tuple[list[ChatMessage], str | None]:
    """Normalize a chat history into the role/message shape Gemini expects.

    Renames the "assistant" role to Gemini's "model" role, folds system
    messages into ``system_prompt``, and forces a lone message to the
    "user" role. Mutates ``messages`` in place and also returns it.

    Args:
        messages: Conversation history; each message has ``role`` and ``content``.
        system_prompt: Existing system prompt to append extracted system
            message content onto, if any.

    Returns:
        Tuple of (normalized messages, combined system prompt or ``None``).
    """
    if len(messages) == 1:
        # A single message is always sent as the user turn.
        messages[0].role = "user"
    else:
        # Gemini uses "model" where OpenAI-style histories use "assistant".
        for message in messages:
            if message.role == "assistant":
                message.role = "model"
        # Extract system messages into the system prompt; Gemini takes the
        # system instruction separately from the message list. Iterate a copy
        # because we remove from the list while walking it.
        system_prompt = system_prompt or ""
        for message in messages.copy():
            if message.role == "system":
                system_prompt += message.content
                messages.remove(message)
    # Consistently return None (never "") when there is no system prompt,
    # regardless of which branch ran above.
    system_prompt = None if is_none_or_empty(system_prompt) else system_prompt
    return messages, system_prompt