Allow Gemini API calling functions to save conversation traces

This commit is contained in:
Debanjum Singh Solanky 2024-10-24 14:15:53 -07:00
parent 384f394336
commit 6fcd6a5659
2 changed files with 42 additions and 12 deletions

View file

@ -35,6 +35,7 @@ def extract_questions_gemini(
query_images: Optional[list[str]] = None,
vision_enabled: bool = False,
personality_context: Optional[str] = None,
tracer: dict = {},
):
"""
Infer search queries to retrieve relevant notes to answer user query
@ -85,7 +86,7 @@ def extract_questions_gemini(
messages = [ChatMessage(content=prompt, role="user"), ChatMessage(content=system_prompt, role="system")]
response = gemini_send_message_to_model(
messages, api_key, model, response_type="json_object", temperature=temperature
messages, api_key, model, response_type="json_object", temperature=temperature, tracer=tracer
)
# Extract, Clean Message from Gemini's Response
@ -107,7 +108,9 @@ def extract_questions_gemini(
return questions
def gemini_send_message_to_model(messages, api_key, model, response_type="text", temperature=0, model_kwargs=None):
def gemini_send_message_to_model(
messages, api_key, model, response_type="text", temperature=0, model_kwargs=None, tracer={}
):
"""
Send message to model
"""
@ -125,6 +128,7 @@ def gemini_send_message_to_model(messages, api_key, model, response_type="text",
api_key=api_key,
temperature=temperature,
model_kwargs=model_kwargs,
tracer=tracer,
)
@ -145,6 +149,7 @@ def converse_gemini(
agent: Agent = None,
query_images: Optional[list[str]] = None,
vision_available: bool = False,
tracer={},
):
"""
Converse with user using Google's Gemini
@ -219,4 +224,5 @@ def converse_gemini(
api_key=api_key,
system_prompt=system_prompt,
completion_func=completion_func,
tracer=tracer,
)

View file

@ -19,8 +19,13 @@ from tenacity import (
wait_random_exponential,
)
from khoj.processor.conversation.utils import ThreadedGenerator, get_image_from_url
from khoj.utils.helpers import is_none_or_empty
from khoj.processor.conversation.utils import (
ThreadedGenerator,
commit_conversation_trace,
get_image_from_url,
)
from khoj.utils import state
from khoj.utils.helpers import in_debug_mode, is_none_or_empty
logger = logging.getLogger(__name__)
@ -35,7 +40,7 @@ MAX_OUTPUT_TOKENS_GEMINI = 8192
reraise=True,
)
def gemini_completion_with_backoff(
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, tracer={}
) -> str:
genai.configure(api_key=api_key)
model_kwargs = model_kwargs or dict()
@ -60,16 +65,23 @@ def gemini_completion_with_backoff(
try:
# Generate the response. The last message is considered to be the current prompt
aggregated_response = chat_session.send_message(formatted_messages[-1]["parts"])
return aggregated_response.text
response = chat_session.send_message(formatted_messages[-1]["parts"])
response_text = response.text
except StopCandidateException as e:
response_message, _ = handle_gemini_response(e.args)
response_text, _ = handle_gemini_response(e.args)
# Respond with reason for stopping
logger.warning(
f"LLM Response Prevented for {model_name}: {response_message}.\n"
f"LLM Response Prevented for {model_name}: {response_text}.\n"
+ f"Last Message by {messages[-1].role}: {messages[-1].content}"
)
return response_message
# Save conversation trace
tracer["chat_model"] = model_name
tracer["temperature"] = temperature
if in_debug_mode() or state.verbose > 1:
commit_conversation_trace(messages, response_text, tracer)
return response_text
@retry(
@ -88,17 +100,20 @@ def gemini_chat_completion_with_backoff(
system_prompt,
completion_func=None,
model_kwargs=None,
tracer: dict = {},
):
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
t = Thread(
target=gemini_llm_thread,
args=(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs),
args=(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs, tracer),
)
t.start()
return g
def gemini_llm_thread(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs=None):
def gemini_llm_thread(
g, messages, system_prompt, model_name, temperature, api_key, model_kwargs=None, tracer: dict = {}
):
try:
genai.configure(api_key=api_key)
model_kwargs = model_kwargs or dict()
@ -117,16 +132,25 @@ def gemini_llm_thread(g, messages, system_prompt, model_name, temperature, api_k
},
)
aggregated_response = ""
formatted_messages = [{"role": message.role, "parts": message.content} for message in messages]
# all messages up to the last are considered to be part of the chat history
chat_session = model.start_chat(history=formatted_messages[0:-1])
# the last message is considered to be the current prompt
for chunk in chat_session.send_message(formatted_messages[-1]["parts"], stream=True):
message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
message = message or chunk.text
aggregated_response += message
g.send(message)
if stopped:
raise StopCandidateException(message)
# Save conversation trace
tracer["chat_model"] = model_name
tracer["temperature"] = temperature
if in_debug_mode() or state.verbose > 1:
commit_conversation_trace(messages, aggregated_response, tracer)
except StopCandidateException as e:
logger.warning(
f"LLM Response Prevented for {model_name}: {e.args[0]}.\n"