Allow Gemini API calling functions to save conversation traces

Debanjum Singh Solanky 2024-10-24 14:15:53 -07:00
parent 384f394336
commit 6fcd6a5659
2 changed files with 42 additions and 12 deletions
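For context, the diffs below thread an optional tracer dict through each Gemini entry point. A minimal usage sketch, assuming the caller supplies its own messages, API key, and model; the tracer keys shown are illustrative and not part of this commit:

# Illustrative sketch only: calling the updated helper with a shared tracer dict.
# The "uid"/"sid" keys are hypothetical caller-side metadata; this commit itself only
# adds the passthrough plus the "chat_model" and "temperature" entries set downstream.
tracer = {"uid": "user-123", "sid": "session-456"}
response = gemini_send_message_to_model(
    messages,            # list[ChatMessage], built by the caller
    api_key,             # Gemini API key
    model,               # Gemini chat model name configured for the request
    response_type="json_object",
    temperature=0.2,
    tracer=tracer,
)
# Downstream, gemini_completion_with_backoff mutates the same dict and commits the
# trace when debug mode or verbose logging (> 1) is enabled.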

View file

@@ -35,6 +35,7 @@ def extract_questions_gemini(
     query_images: Optional[list[str]] = None,
     vision_enabled: bool = False,
     personality_context: Optional[str] = None,
+    tracer: dict = {},
 ):
     """
     Infer search queries to retrieve relevant notes to answer user query
@@ -85,7 +86,7 @@ def extract_questions_gemini(
     messages = [ChatMessage(content=prompt, role="user"), ChatMessage(content=system_prompt, role="system")]

     response = gemini_send_message_to_model(
-        messages, api_key, model, response_type="json_object", temperature=temperature
+        messages, api_key, model, response_type="json_object", temperature=temperature, tracer=tracer
     )

     # Extract, Clean Message from Gemini's Response
@@ -107,7 +108,9 @@ def extract_questions_gemini(
     return questions


-def gemini_send_message_to_model(messages, api_key, model, response_type="text", temperature=0, model_kwargs=None):
+def gemini_send_message_to_model(
+    messages, api_key, model, response_type="text", temperature=0, model_kwargs=None, tracer={}
+):
     """
     Send message to model
     """
@@ -125,6 +128,7 @@ def gemini_send_message_to_model(messages, api_key, model, response_type="text",
         api_key=api_key,
         temperature=temperature,
         model_kwargs=model_kwargs,
+        tracer=tracer,
     )
@@ -145,6 +149,7 @@ def converse_gemini(
     agent: Agent = None,
     query_images: Optional[list[str]] = None,
     vision_available: bool = False,
+    tracer={},
 ):
     """
     Converse with user using Google's Gemini
@@ -219,4 +224,5 @@ def converse_gemini(
         api_key=api_key,
         system_prompt=system_prompt,
         completion_func=completion_func,
+        tracer=tracer,
     )
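Both converse_gemini and extract_questions_gemini forward the same dict by reference, so metadata the lower-level helpers in the second file attach, such as chat_model and temperature, is visible to the caller after the call returns. A self-contained sketch of that in-place mutation pattern, using a stand-in function rather than the real API:

# Stand-in for the Gemini helpers above, just to show the in-place mutation pattern.
def fake_gemini_call(messages, tracer={}):
    tracer["chat_model"] = "gemini-flash"   # hypothetical model name
    tracer["temperature"] = 0.2
    return "ok"

trace = {"session": "abc123"}               # hypothetical caller-side metadata
fake_gemini_call(["hello"], tracer=trace)
print(trace)  # {'session': 'abc123', 'chat_model': 'gemini-flash', 'temperature': 0.2}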

View file

@@ -19,8 +19,13 @@ from tenacity import (
     wait_random_exponential,
 )

-from khoj.processor.conversation.utils import ThreadedGenerator, get_image_from_url
-from khoj.utils.helpers import is_none_or_empty
+from khoj.processor.conversation.utils import (
+    ThreadedGenerator,
+    commit_conversation_trace,
+    get_image_from_url,
+)
+from khoj.utils import state
+from khoj.utils.helpers import in_debug_mode, is_none_or_empty

 logger = logging.getLogger(__name__)
@@ -35,7 +40,7 @@ MAX_OUTPUT_TOKENS_GEMINI = 8192
     reraise=True,
 )
 def gemini_completion_with_backoff(
-    messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None
+    messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, tracer={}
 ) -> str:
     genai.configure(api_key=api_key)
     model_kwargs = model_kwargs or dict()
@@ -60,16 +65,23 @@ def gemini_completion_with_backoff(
     try:
         # Generate the response. The last message is considered to be the current prompt
-        aggregated_response = chat_session.send_message(formatted_messages[-1]["parts"])
-        return aggregated_response.text
+        response = chat_session.send_message(formatted_messages[-1]["parts"])
+        response_text = response.text
     except StopCandidateException as e:
-        response_message, _ = handle_gemini_response(e.args)
+        response_text, _ = handle_gemini_response(e.args)
         # Respond with reason for stopping
         logger.warning(
-            f"LLM Response Prevented for {model_name}: {response_message}.\n"
+            f"LLM Response Prevented for {model_name}: {response_text}.\n"
             + f"Last Message by {messages[-1].role}: {messages[-1].content}"
         )
-        return response_message
+
+    # Save conversation trace
+    tracer["chat_model"] = model_name
+    tracer["temperature"] = temperature
+    if in_debug_mode() or state.verbose > 1:
+        commit_conversation_trace(messages, response_text, tracer)
+
+    return response_text


 @retry(
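commit_conversation_trace is imported from khoj.processor.conversation.utils and its implementation is not part of this diff. Purely as an illustration of the kind of helper being called here (an assumption, not the project's actual code), a minimal version might append each exchange plus the tracer metadata to a local JSONL file:

import json
from datetime import datetime, timezone
from pathlib import Path

def commit_conversation_trace_sketch(messages, response, tracer, trace_dir="~/.khoj_traces"):
    """Hypothetical stand-in for commit_conversation_trace; the path and record shape are assumptions."""
    record = {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "chat_model": tracer.get("chat_model"),
        "temperature": tracer.get("temperature"),
        "messages": [{"role": m.role, "content": m.content} for m in messages],
        "response": response,
    }
    trace_path = Path(trace_dir).expanduser()
    trace_path.mkdir(parents=True, exist_ok=True)
    with open(trace_path / "traces.jsonl", "a") as f:
        f.write(json.dumps(record) + "\n")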
@@ -88,17 +100,20 @@ def gemini_chat_completion_with_backoff(
     system_prompt,
     completion_func=None,
     model_kwargs=None,
+    tracer: dict = {},
 ):
     g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
     t = Thread(
         target=gemini_llm_thread,
-        args=(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs),
+        args=(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs, tracer),
     )
     t.start()
     return g


-def gemini_llm_thread(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs=None):
+def gemini_llm_thread(
+    g, messages, system_prompt, model_name, temperature, api_key, model_kwargs=None, tracer: dict = {}
+):
     try:
         genai.configure(api_key=api_key)
         model_kwargs = model_kwargs or dict()
@@ -117,16 +132,25 @@ def gemini_llm_thread(g, messages, system_prompt, model_name, temperature, api_k
             },
         )

+        aggregated_response = ""
         formatted_messages = [{"role": message.role, "parts": message.content} for message in messages]
         # all messages up to the last are considered to be part of the chat history
         chat_session = model.start_chat(history=formatted_messages[0:-1])
         # the last message is considered to be the current prompt
         for chunk in chat_session.send_message(formatted_messages[-1]["parts"], stream=True):
             message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
             message = message or chunk.text
+            aggregated_response += message
             g.send(message)
             if stopped:
                 raise StopCandidateException(message)
+
+        # Save conversation trace
+        tracer["chat_model"] = model_name
+        tracer["temperature"] = temperature
+        if in_debug_mode() or state.verbose > 1:
+            commit_conversation_trace(messages, aggregated_response, tracer)
+
     except StopCandidateException as e:
         logger.warning(
             f"LLM Response Prevented for {model_name}: {e.args[0]}.\n"