mirror of
https://github.com/khoj-ai/khoj.git
synced 2025-02-17 08:04:21 +00:00
Reuse a single function to format conversations for Gemini
This deduplicates code and prevents the formatting logic from diverging across the Gemini chat actors
This commit is contained in:
parent
452e360175
commit
2b8f7f3efb
2 changed files with 25 additions and 17 deletions
|
@ -9,6 +9,7 @@ from langchain.schema import ChatMessage
|
|||
from khoj.database.models import Agent, KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.google.utils import (
|
||||
format_messages_for_gemini,
|
||||
gemini_chat_completion_with_backoff,
|
||||
gemini_completion_with_backoff,
|
||||
)
|
||||
|
@ -105,15 +106,7 @@ def gemini_send_message_to_model(messages, api_key, model, response_type="text")
|
|||
"""
|
||||
Send message to model
|
||||
"""
|
||||
system_prompt = None
|
||||
if len(messages) == 1:
|
||||
messages[0].role = "user"
|
||||
else:
|
||||
system_prompt = ""
|
||||
for message in messages.copy():
|
||||
if message.role == "system":
|
||||
system_prompt += message.content
|
||||
messages.remove(message)
|
||||
messages, system_prompt = format_messages_for_gemini(messages)
|
||||
|
||||
model_kwargs = {}
|
||||
if response_type == "json_object":
|
||||
|
@ -195,14 +188,7 @@ def converse_gemini(
|
|||
tokenizer_name=tokenizer_name,
|
||||
)
|
||||
|
||||
for message in messages:
|
||||
if message.role == "assistant":
|
||||
message.role = "model"
|
||||
|
||||
for message in messages.copy():
|
||||
if message.role == "system":
|
||||
system_prompt += message.content
|
||||
messages.remove(message)
|
||||
messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
|
||||
|
||||
truncated_messages = "\n".join({f"{message.content[:40]}..." for message in messages})
|
||||
logger.debug(f"Conversation Context for Gemini: {truncated_messages}")
|
||||
|
|
|
@ -10,6 +10,7 @@ from google.generativeai.types.safety_types import (
|
|||
HarmCategory,
|
||||
HarmProbability,
|
||||
)
|
||||
from langchain.schema import ChatMessage
|
||||
from tenacity import (
|
||||
before_sleep_log,
|
||||
retry,
|
||||
|
@ -19,6 +20,7 @@ from tenacity import (
|
|||
)
|
||||
|
||||
from khoj.processor.conversation.utils import ThreadedGenerator
|
||||
from khoj.utils.helpers import is_none_or_empty
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -182,3 +184,23 @@ def generate_safety_response(safety_ratings):
|
|||
return safety_response_choice.format(
|
||||
category=max_safety_category, probability=max_safety_rating.probability.name, discomfort_level=discomfort_level
|
||||
)
|
||||
|
||||
|
||||
def format_messages_for_gemini(
    messages: "list[ChatMessage]", system_prompt: "str | None" = None
) -> "tuple[list[ChatMessage], str | None]":
    """Normalize a conversation, in place, into the shape the Gemini API expects.

    - A lone message is always sent with the ``user`` role.
    - ``assistant`` roles are renamed to Gemini's ``model`` role.
    - ``system`` messages are removed from the list and their contents are
      appended (after any caller-provided ``system_prompt``) into the returned
      system prompt, since Gemini takes system instructions separately.

    Args:
        messages: Conversation messages; mutated in place and also returned.
        system_prompt: Optional seed text to prepend to extracted system messages.

    Returns:
        The (mutated) message list and the combined system prompt, which is
        ``None`` when no system instructions were found.
    """
    # A single message must come from the user; no system extraction is done.
    if len(messages) == 1:
        messages[0].role = "user"
        return messages, system_prompt

    # Gemini uses "model" where OpenAI-style conversations use "assistant".
    for message in messages:
        if message.role == "assistant":
            message.role = "model"

    # Extract system message(s) into a single prompt and drop them from the
    # conversation. Iterate a copy because we remove from the original list.
    system_prompt = system_prompt or ""
    for message in messages.copy():
        if message.role == "system":
            system_prompt += message.content
            messages.remove(message)
    system_prompt = None if is_none_or_empty(system_prompt) else system_prompt

    return messages, system_prompt
|
||||
|
|
Loading…
Add table
Reference in a new issue