mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 15:38:55 +01:00
Allow Gemini API calling functions to save conversation traces
This commit is contained in:
parent
384f394336
commit
6fcd6a5659
2 changed files with 42 additions and 12 deletions
|
@ -35,6 +35,7 @@ def extract_questions_gemini(
|
|||
query_images: Optional[list[str]] = None,
|
||||
vision_enabled: bool = False,
|
||||
personality_context: Optional[str] = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
"""
|
||||
Infer search queries to retrieve relevant notes to answer user query
|
||||
|
@ -85,7 +86,7 @@ def extract_questions_gemini(
|
|||
messages = [ChatMessage(content=prompt, role="user"), ChatMessage(content=system_prompt, role="system")]
|
||||
|
||||
response = gemini_send_message_to_model(
|
||||
messages, api_key, model, response_type="json_object", temperature=temperature
|
||||
messages, api_key, model, response_type="json_object", temperature=temperature, tracer=tracer
|
||||
)
|
||||
|
||||
# Extract, Clean Message from Gemini's Response
|
||||
|
@ -107,7 +108,9 @@ def extract_questions_gemini(
|
|||
return questions
|
||||
|
||||
|
||||
def gemini_send_message_to_model(messages, api_key, model, response_type="text", temperature=0, model_kwargs=None):
|
||||
def gemini_send_message_to_model(
|
||||
messages, api_key, model, response_type="text", temperature=0, model_kwargs=None, tracer={}
|
||||
):
|
||||
"""
|
||||
Send message to model
|
||||
"""
|
||||
|
@ -125,6 +128,7 @@ def gemini_send_message_to_model(messages, api_key, model, response_type="text",
|
|||
api_key=api_key,
|
||||
temperature=temperature,
|
||||
model_kwargs=model_kwargs,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
|
||||
|
@ -145,6 +149,7 @@ def converse_gemini(
|
|||
agent: Agent = None,
|
||||
query_images: Optional[list[str]] = None,
|
||||
vision_available: bool = False,
|
||||
tracer={},
|
||||
):
|
||||
"""
|
||||
Converse with user using Google's Gemini
|
||||
|
@ -219,4 +224,5 @@ def converse_gemini(
|
|||
api_key=api_key,
|
||||
system_prompt=system_prompt,
|
||||
completion_func=completion_func,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
|
|
@ -19,8 +19,13 @@ from tenacity import (
|
|||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from khoj.processor.conversation.utils import ThreadedGenerator, get_image_from_url
|
||||
from khoj.utils.helpers import is_none_or_empty
|
||||
from khoj.processor.conversation.utils import (
|
||||
ThreadedGenerator,
|
||||
commit_conversation_trace,
|
||||
get_image_from_url,
|
||||
)
|
||||
from khoj.utils import state
|
||||
from khoj.utils.helpers import in_debug_mode, is_none_or_empty
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -35,7 +40,7 @@ MAX_OUTPUT_TOKENS_GEMINI = 8192
|
|||
reraise=True,
|
||||
)
|
||||
def gemini_completion_with_backoff(
|
||||
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None
|
||||
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, tracer={}
|
||||
) -> str:
|
||||
genai.configure(api_key=api_key)
|
||||
model_kwargs = model_kwargs or dict()
|
||||
|
@ -60,16 +65,23 @@ def gemini_completion_with_backoff(
|
|||
|
||||
try:
|
||||
# Generate the response. The last message is considered to be the current prompt
|
||||
aggregated_response = chat_session.send_message(formatted_messages[-1]["parts"])
|
||||
return aggregated_response.text
|
||||
response = chat_session.send_message(formatted_messages[-1]["parts"])
|
||||
response_text = response.text
|
||||
except StopCandidateException as e:
|
||||
response_message, _ = handle_gemini_response(e.args)
|
||||
response_text, _ = handle_gemini_response(e.args)
|
||||
# Respond with reason for stopping
|
||||
logger.warning(
|
||||
f"LLM Response Prevented for {model_name}: {response_message}.\n"
|
||||
f"LLM Response Prevented for {model_name}: {response_text}.\n"
|
||||
+ f"Last Message by {messages[-1].role}: {messages[-1].content}"
|
||||
)
|
||||
return response_message
|
||||
|
||||
# Save conversation trace
|
||||
tracer["chat_model"] = model_name
|
||||
tracer["temperature"] = temperature
|
||||
if in_debug_mode() or state.verbose > 1:
|
||||
commit_conversation_trace(messages, response_text, tracer)
|
||||
|
||||
return response_text
|
||||
|
||||
|
||||
@retry(
|
||||
|
@ -88,17 +100,20 @@ def gemini_chat_completion_with_backoff(
|
|||
system_prompt,
|
||||
completion_func=None,
|
||||
model_kwargs=None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
|
||||
t = Thread(
|
||||
target=gemini_llm_thread,
|
||||
args=(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs),
|
||||
args=(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs, tracer),
|
||||
)
|
||||
t.start()
|
||||
return g
|
||||
|
||||
|
||||
def gemini_llm_thread(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs=None):
|
||||
def gemini_llm_thread(
|
||||
g, messages, system_prompt, model_name, temperature, api_key, model_kwargs=None, tracer: dict = {}
|
||||
):
|
||||
try:
|
||||
genai.configure(api_key=api_key)
|
||||
model_kwargs = model_kwargs or dict()
|
||||
|
@ -117,16 +132,25 @@ def gemini_llm_thread(g, messages, system_prompt, model_name, temperature, api_k
|
|||
},
|
||||
)
|
||||
|
||||
aggregated_response = ""
|
||||
formatted_messages = [{"role": message.role, "parts": message.content} for message in messages]
|
||||
|
||||
# all messages up to the last are considered to be part of the chat history
|
||||
chat_session = model.start_chat(history=formatted_messages[0:-1])
|
||||
# the last message is considered to be the current prompt
|
||||
for chunk in chat_session.send_message(formatted_messages[-1]["parts"], stream=True):
|
||||
message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
|
||||
message = message or chunk.text
|
||||
aggregated_response += message
|
||||
g.send(message)
|
||||
if stopped:
|
||||
raise StopCandidateException(message)
|
||||
|
||||
# Save conversation trace
|
||||
tracer["chat_model"] = model_name
|
||||
tracer["temperature"] = temperature
|
||||
if in_debug_mode() or state.verbose > 1:
|
||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||
except StopCandidateException as e:
|
||||
logger.warning(
|
||||
f"LLM Response Prevented for {model_name}: {e.args[0]}.\n"
|
||||
|
|
Loading…
Reference in a new issue