Allow OpenAI API calling functions to save conversation traces

Debanjum Singh Solanky 2024-10-23 20:01:06 -07:00
parent 10c8fd3b2a
commit 384f394336
2 changed files with 57 additions and 11 deletions
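This commit threads an optional tracer dict from the high-level OpenAI helpers (extract_questions, send_message_to_model, converse) down into completion_with_backoff and llm_thread. Those low-level functions stamp the dict with the chat model and temperature and, when debug mode is on or state.verbose > 1, persist the exchange via commit_conversation_trace. A minimal usage sketch of the new keyword argument follows; the import paths, API key, and model name are assumptions for illustration, not part of this commit.

# Usage sketch of the new tracer argument (assumed module paths; placeholder key and model).
from langchain.schema import ChatMessage  # assumed source of ChatMessage in Khoj

from khoj.processor.conversation.openai.gpt import send_message_to_model  # assumed module path

tracer = {"uid": "demo-user", "conversation_id": "demo-chat"}  # caller-supplied trace metadata
messages = [ChatMessage(content="What is on my calendar tomorrow?", role="user")]

response = send_message_to_model(
    messages,
    api_key="sk-placeholder",     # placeholder key
    model="gpt-4o-mini",          # placeholder model name
    response_type="json_object",
    tracer=tracer,                # the same dict is mutated further down the call chain
)

# completion_with_backoff stamps the trace with the model settings; the trace itself is
# only committed when debug mode is on or state.verbose > 1.
print(tracer["chat_model"], tracer["temperature"])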

View file

@@ -33,6 +33,7 @@ def extract_questions(
     query_images: Optional[list[str]] = None,
     vision_enabled: bool = False,
     personality_context: Optional[str] = None,
+    tracer: dict = {},
 ):
     """
     Infer search queries to retrieve relevant notes to answer user query
@@ -82,7 +83,13 @@ def extract_questions(
     messages = [ChatMessage(content=prompt, role="user")]

     response = send_message_to_model(
-        messages, api_key, model, response_type="json_object", api_base_url=api_base_url, temperature=temperature
+        messages,
+        api_key,
+        model,
+        response_type="json_object",
+        api_base_url=api_base_url,
+        temperature=temperature,
+        tracer=tracer,
     )

     # Extract, Clean Message from GPT's Response
@@ -103,7 +110,9 @@ def extract_questions(
     return questions


-def send_message_to_model(messages, api_key, model, response_type="text", api_base_url=None, temperature=0):
+def send_message_to_model(
+    messages, api_key, model, response_type="text", api_base_url=None, temperature=0, tracer: dict = {}
+):
     """
     Send message to model
     """
@@ -116,6 +125,7 @@ def send_message_to_model(messages, api_key, model, response_type="text", api_base_url=None, temperature=0):
         temperature=temperature,
         api_base_url=api_base_url,
         model_kwargs={"response_format": {"type": response_type}},
+        tracer=tracer,
     )
@@ -137,6 +147,7 @@ def converse(
     agent: Agent = None,
     query_images: Optional[list[str]] = None,
     vision_available: bool = False,
+    tracer: dict = {},
 ):
     """
     Converse with user using OpenAI's ChatGPT
@@ -209,4 +220,5 @@ def converse(
         api_base_url=api_base_url,
         completion_func=completion_func,
         model_kwargs={"stop": ["Notes:\n["]},
+        tracer=tracer,
     )
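One detail worth noting in the new signatures above: tracer: dict = {} is a mutable default argument, so Python creates that dict once and every call that omits tracer shares it. The call sites updated in this commit always pass tracer=tracer explicitly, which avoids the issue; a toy illustration (not Khoj code) of why that matters:

# Toy illustration of the shared-default behaviour behind tracer: dict = {}.
def stamp(tracer: dict = {}):
    tracer["chat_model"] = "gpt-4o-mini"  # placeholder value
    return tracer

first = stamp()
second = stamp()
assert first is second  # same object: independent calls would overwrite each other's trace

# Passing a fresh dict per request, as the updated call sites do with tracer=tracer,
# keeps each conversation's trace separate.
per_request = stamp(tracer={})
assert per_request is not first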

View file

@@ -12,7 +12,12 @@ from tenacity import (
     wait_random_exponential,
 )

-from khoj.processor.conversation.utils import ThreadedGenerator
+from khoj.processor.conversation.utils import (
+    ThreadedGenerator,
+    commit_conversation_trace,
+)
+from khoj.utils import state
+from khoj.utils.helpers import in_debug_mode

 logger = logging.getLogger(__name__)
@@ -33,7 +38,7 @@ openai_clients: Dict[str, openai.OpenAI] = {}
     reraise=True,
 )
 def completion_with_backoff(
-    messages, model, temperature=0, openai_api_key=None, api_base_url=None, model_kwargs=None
+    messages, model, temperature=0, openai_api_key=None, api_base_url=None, model_kwargs=None, tracer: dict = {}
 ) -> str:
     client_key = f"{openai_api_key}--{api_base_url}"
     client: openai.OpenAI | None = openai_clients.get(client_key)
@@ -77,6 +82,12 @@ def completion_with_backoff(
         elif delta_chunk.content:
             aggregated_response += delta_chunk.content

+    # Save conversation trace
+    tracer["chat_model"] = model
+    tracer["temperature"] = temperature
+    if in_debug_mode() or state.verbose > 1:
+        commit_conversation_trace(messages, aggregated_response, tracer)
+
     return aggregated_response
@@ -103,26 +114,37 @@ def chat_completion_with_backoff(
     api_base_url=None,
     completion_func=None,
     model_kwargs=None,
+    tracer: dict = {},
 ):
     g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
     t = Thread(
-        target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key, api_base_url, model_kwargs)
+        target=llm_thread,
+        args=(g, messages, model_name, temperature, openai_api_key, api_base_url, model_kwargs, tracer),
     )
     t.start()
     return g


-def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_base_url=None, model_kwargs=None):
+def llm_thread(
+    g,
+    messages,
+    model_name,
+    temperature,
+    openai_api_key=None,
+    api_base_url=None,
+    model_kwargs=None,
+    tracer: dict = {},
+):
     try:
         client_key = f"{openai_api_key}--{api_base_url}"
         if client_key not in openai_clients:
-            client: openai.OpenAI = openai.OpenAI(
+            client = openai.OpenAI(
                 api_key=openai_api_key,
                 base_url=api_base_url,
             )
             openai_clients[client_key] = client
         else:
-            client: openai.OpenAI = openai_clients[client_key]
+            client = openai_clients[client_key]

         formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
         stream = True
@@ -144,17 +166,29 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_base_url=None, model_kwargs=None):
             **(model_kwargs or dict()),
         )

+        aggregated_response = ""
         if not stream:
-            g.send(chat.choices[0].message.content)
+            aggregated_response = chat.choices[0].message.content
+            g.send(aggregated_response)
         else:
             for chunk in chat:
                 if len(chunk.choices) == 0:
                     continue
                 delta_chunk = chunk.choices[0].delta
+                text_chunk = ""
                 if isinstance(delta_chunk, str):
-                    g.send(delta_chunk)
+                    text_chunk = delta_chunk
                 elif delta_chunk.content:
-                    g.send(delta_chunk.content)
+                    text_chunk = delta_chunk.content
+                if text_chunk:
+                    aggregated_response += text_chunk
+                    g.send(text_chunk)
+
+        # Save conversation trace
+        tracer["chat_model"] = model_name
+        tracer["temperature"] = temperature
+        if in_debug_mode() or state.verbose > 1:
+            commit_conversation_trace(messages, aggregated_response, tracer)
     except Exception as e:
         logger.error(f"Error in llm_thread: {e}", exc_info=True)
     finally:
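The same save pattern appears twice in this file: stamp the tracer with the model name and temperature, then commit the trace only when in_debug_mode() or state.verbose > 1. A hedged sketch of how another completion helper could reuse that convention; call_other_backend is a made-up placeholder, while the imports and the gate mirror the additions above.

# Hypothetical sketch of reusing the trace-saving convention from this diff.
from khoj.processor.conversation.utils import commit_conversation_trace
from khoj.utils import state
from khoj.utils.helpers import in_debug_mode


def call_other_backend(messages, model, temperature):
    # Placeholder for a real model call; returns a canned response for illustration.
    return "stub response"


def other_backend_completion(messages, model, temperature=0, tracer: dict = {}):
    aggregated_response = call_other_backend(messages, model, temperature)

    # Save conversation trace
    tracer["chat_model"] = model
    tracer["temperature"] = temperature
    if in_debug_mode() or state.verbose > 1:
        commit_conversation_trace(messages, aggregated_response, tracer)

    return aggregated_response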