Allow OpenAI API calling functions to save conversation traces

Debanjum Singh Solanky 2024-10-23 20:01:06 -07:00
parent 10c8fd3b2a
commit 384f394336
2 changed files with 57 additions and 11 deletions

@@ -33,6 +33,7 @@ def extract_questions(
     query_images: Optional[list[str]] = None,
     vision_enabled: bool = False,
     personality_context: Optional[str] = None,
+    tracer: dict = {},
 ):
     """
     Infer search queries to retrieve relevant notes to answer user query
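
The tracer rides along as a plain dict: each layer annotates it in place, and because dicts are passed by reference, the caller sees whatever the callee recorded. A minimal sketch of that pattern (illustrative, not khoj code):

```python
# Sketch of the shared-tracer pattern: the callee annotates the caller's dict.
def send(messages, model, temperature=0, tracer: dict = {}):
    tracer["chat_model"] = model        # mutates the caller's dict in place
    tracer["temperature"] = temperature
    return "ok"

trace = {"session": "abc123"}           # hypothetical caller-side metadata
send(["hi"], "gpt-4o-mini", tracer=trace)
assert trace["chat_model"] == "gpt-4o-mini"  # callee's writes are visible
```

One caveat of `tracer: dict = {}` as a default: Python evaluates it once, so every call that omits the argument shares the same dict. That only matters if callers skip passing their own tracer.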
@@ -82,7 +83,13 @@ def extract_questions(
     messages = [ChatMessage(content=prompt, role="user")]

     response = send_message_to_model(
-        messages, api_key, model, response_type="json_object", api_base_url=api_base_url, temperature=temperature
+        messages,
+        api_key,
+        model,
+        response_type="json_object",
+        api_base_url=api_base_url,
+        temperature=temperature,
+        tracer=tracer,
     )

     # Extract, Clean Message from GPT's Response
@@ -103,7 +110,9 @@ def extract_questions(
     return questions


-def send_message_to_model(messages, api_key, model, response_type="text", api_base_url=None, temperature=0):
+def send_message_to_model(
+    messages, api_key, model, response_type="text", api_base_url=None, temperature=0, tracer: dict = {}
+):
     """
     Send message to model
     """
@@ -116,6 +125,7 @@ def send_message_to_model(messages, api_key, model, response_type="text", api_base_url=None, temperature=0):
         temperature=temperature,
         api_base_url=api_base_url,
         model_kwargs={"response_format": {"type": response_type}},
+        tracer=tracer,
     )
@@ -137,6 +147,7 @@ def converse(
     agent: Agent = None,
     query_images: Optional[list[str]] = None,
     vision_available: bool = False,
+    tracer: dict = {},
 ):
     """
     Converse with user using OpenAI's ChatGPT
@@ -209,4 +220,5 @@ def converse(
         api_base_url=api_base_url,
         completion_func=completion_func,
         model_kwargs={"stop": ["Notes:\n["]},
+        tracer=tracer,
     )
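
With that, every OpenAI entry point in this file accepts and forwards a tracer. A hypothetical call against the updated `send_message_to_model` (keyword names come from the diff; the values and the `ChatMessage` import path are assumptions):

```python
from langchain.schema import ChatMessage  # assumed import path for ChatMessage

tracer = {"session": "abc123"}  # illustrative caller-side metadata
response = send_message_to_model(
    messages=[ChatMessage(content="What changed in this commit?", role="user")],
    api_key="sk-...",           # placeholder API key
    model="gpt-4o-mini",        # placeholder model name
    tracer=tracer,
)
# By the time this returns, completion_with_backoff (below) has stamped
# tracer["chat_model"] and tracer["temperature"].
```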

@@ -12,7 +12,12 @@ from tenacity import (
     wait_random_exponential,
 )

-from khoj.processor.conversation.utils import ThreadedGenerator
+from khoj.processor.conversation.utils import (
+    ThreadedGenerator,
+    commit_conversation_trace,
+)
+from khoj.utils import state
+from khoj.utils.helpers import in_debug_mode

 logger = logging.getLogger(__name__)
@ -33,7 +38,7 @@ openai_clients: Dict[str, openai.OpenAI] = {}
reraise=True,
)
def completion_with_backoff(
messages, model, temperature=0, openai_api_key=None, api_base_url=None, model_kwargs=None
messages, model, temperature=0, openai_api_key=None, api_base_url=None, model_kwargs=None, tracer: dict = {}
) -> str:
client_key = f"{openai_api_key}--{api_base_url}"
client: openai.OpenAI | None = openai_clients.get(client_key)
@ -77,6 +82,12 @@ def completion_with_backoff(
elif delta_chunk.content:
aggregated_response += delta_chunk.content
# Save conversation trace
tracer["chat_model"] = model
tracer["temperature"] = temperature
if in_debug_mode() or state.verbose > 1:
commit_conversation_trace(messages, aggregated_response, tracer)
return aggregated_response
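
`commit_conversation_trace` is imported from `khoj.processor.conversation.utils` and is not part of this diff. To make the data flow concrete, a stand-in with the same call shape might look like this (the real implementation may differ entirely):

```python
import json
import time

def commit_conversation_trace(messages, response: str, tracer: dict) -> None:
    # Stand-in only: append the exchange plus tracer metadata as one JSON line.
    trace = {
        "timestamp": time.time(),
        "messages": [{"role": m.role, "content": m.content} for m in messages],
        "response": response,
        **tracer,  # chat_model, temperature, plus any caller-supplied keys
    }
    with open("conversation_traces.jsonl", "a") as log:
        log.write(json.dumps(trace) + "\n")
```

Note the gating: `chat_model` and `temperature` are recorded on the tracer unconditionally, but the trace is only committed when `in_debug_mode()` or `state.verbose > 1`, keeping tracing off in normal operation.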
@@ -103,26 +114,37 @@ def chat_completion_with_backoff(
     api_base_url=None,
     completion_func=None,
     model_kwargs=None,
+    tracer: dict = {},
 ):
     g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
     t = Thread(
-        target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key, api_base_url, model_kwargs)
+        target=llm_thread,
+        args=(g, messages, model_name, temperature, openai_api_key, api_base_url, model_kwargs, tracer),
     )
     t.start()
     return g


-def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_base_url=None, model_kwargs=None):
+def llm_thread(
+    g,
+    messages,
+    model_name,
+    temperature,
+    openai_api_key=None,
+    api_base_url=None,
+    model_kwargs=None,
+    tracer: dict = {},
+):
     try:
         client_key = f"{openai_api_key}--{api_base_url}"
         if client_key not in openai_clients:
-            client: openai.OpenAI = openai.OpenAI(
+            client = openai.OpenAI(
                 api_key=openai_api_key,
                 base_url=api_base_url,
             )
             openai_clients[client_key] = client
         else:
-            client: openai.OpenAI = openai_clients[client_key]
+            client = openai_clients[client_key]

         formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
         stream = True
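
`chat_completion_with_backoff` returns the `ThreadedGenerator` immediately and lets `llm_thread` feed it from a worker thread. The class lives in `khoj.processor.conversation.utils`, outside this diff; a minimal queue-backed equivalent of the pattern (a sketch, not the khoj implementation):

```python
import queue

class MiniThreadedGenerator:
    """Generator a worker thread can feed via send(); None ends iteration."""

    def __init__(self):
        self._queue: queue.Queue = queue.Queue()

    def __iter__(self):
        # Block until the worker pushes the next chunk.
        while (chunk := self._queue.get()) is not None:
            yield chunk

    def send(self, data):
        self._queue.put(data)

    def close(self):
        self._queue.put(None)  # sentinel: unblocks and ends the consumer
```

The consumer iterates on its own thread while `llm_thread` calls `send()` per chunk; `close()` (presumably invoked in the `finally:` block below) releases the consumer.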
@@ -144,17 +166,29 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_base_url=None, model_kwargs=None):
             **(model_kwargs or dict()),
         )

+        aggregated_response = ""
         if not stream:
-            g.send(chat.choices[0].message.content)
+            aggregated_response = chat.choices[0].message.content
+            g.send(aggregated_response)
         else:
             for chunk in chat:
                 if len(chunk.choices) == 0:
                     continue
                 delta_chunk = chunk.choices[0].delta
+                text_chunk = ""
                 if isinstance(delta_chunk, str):
-                    g.send(delta_chunk)
+                    text_chunk = delta_chunk
                 elif delta_chunk.content:
-                    g.send(delta_chunk.content)
+                    text_chunk = delta_chunk.content
+                if text_chunk:
+                    aggregated_response += text_chunk
+                    g.send(text_chunk)
+
+        # Save conversation trace
+        tracer["chat_model"] = model_name
+        tracer["temperature"] = temperature
+        if in_debug_mode() or state.verbose > 1:
+            commit_conversation_trace(messages, aggregated_response, tracer)
     except Exception as e:
         logger.error(f"Error in llm_thread: {e}", exc_info=True)
     finally:
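
End to end, a caller shares one tracer dict with the worker thread and reads it back after draining the stream; `llm_thread` writes the tracer keys before its `finally:` block runs, so the metadata is in place once iteration ends. A hypothetical usage (keyword names from the diff, values made up):

```python
tracer: dict = {}
g = chat_completion_with_backoff(
    messages=messages,           # a prepared ChatMessage list, as in gpt.py
    compiled_references=[],
    online_results={},
    model_name="gpt-4o-mini",    # placeholder
    temperature=0.2,
    openai_api_key="sk-...",     # placeholder
    tracer=tracer,
)
response = "".join(g)            # drain the stream on the consumer thread
print(tracer.get("chat_model"), tracer.get("temperature"))
```

Reading the tracer mid-stream would race with the worker thread; inspect it only after the generator is exhausted.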