mirror of https://github.com/khoj-ai/khoj.git

Allow OpenAI API calling functions to save conversation traces

commit 384f394336 (parent 10c8fd3b2a)
2 changed files with 57 additions and 11 deletions
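
This commit threads an optional tracer dict through the OpenAI-facing chat functions. Callers pass the same dict down the call chain (extract_questions and converse hand it to send_message_to_model and chat_completion_with_backoff, which forward it to completion_with_backoff and llm_thread); the lowest layer stamps the chat model and temperature onto it and, when debug mode or verbose logging is active, commits the accumulated trace via commit_conversation_trace. File paths are not preserved in this mirror view; the hunk headers restart at @@ -12 below, where the second of the two changed files begins.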
@@ -33,6 +33,7 @@ def extract_questions(
     query_images: Optional[list[str]] = None,
     vision_enabled: bool = False,
     personality_context: Optional[str] = None,
+    tracer: dict = {},
 ):
     """
     Infer search queries to retrieve relevant notes to answer user query
@@ -82,7 +83,13 @@ def extract_questions(
     messages = [ChatMessage(content=prompt, role="user")]

     response = send_message_to_model(
-        messages, api_key, model, response_type="json_object", api_base_url=api_base_url, temperature=temperature
+        messages,
+        api_key,
+        model,
+        response_type="json_object",
+        api_base_url=api_base_url,
+        temperature=temperature,
+        tracer=tracer,
     )

     # Extract, Clean Message from GPT's Response
@@ -103,7 +110,9 @@ def extract_questions(
     return questions


-def send_message_to_model(messages, api_key, model, response_type="text", api_base_url=None, temperature=0):
+def send_message_to_model(
+    messages, api_key, model, response_type="text", api_base_url=None, temperature=0, tracer: dict = {}
+):
     """
     Send message to model
     """
@@ -116,6 +125,7 @@ def send_message_to_model(messages, api_key, model, response_type="text", api_ba
         temperature=temperature,
         api_base_url=api_base_url,
         model_kwargs={"response_format": {"type": response_type}},
+        tracer=tracer,
     )

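
Taken together, these hunks show the intended calling pattern: the caller owns the tracer dict, and every layer mutates that same dict in place, so metadata recorded deep in the stack is visible to the caller afterwards. A minimal usage sketch, assuming messages, api_key, and model are defined as in extract_questions above (illustrative, not part of the commit):

    tracer: dict = {}
    response = send_message_to_model(
        messages,
        api_key,
        model,
        response_type="json_object",
        api_base_url=None,
        temperature=0,
        tracer=tracer,
    )
    # completion_with_backoff (second file, below) stamps metadata into the
    # same dict, so afterwards: tracer == {"chat_model": model, "temperature": 0}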
@@ -137,6 +147,7 @@ def converse(
     agent: Agent = None,
     query_images: Optional[list[str]] = None,
     vision_available: bool = False,
+    tracer: dict = {},
 ):
     """
     Converse with user using OpenAI's ChatGPT
@@ -209,4 +220,5 @@ def converse(
         api_base_url=api_base_url,
         completion_func=completion_func,
         model_kwargs={"stop": ["Notes:\n["]},
+        tracer=tracer,
     )

@@ -12,7 +12,12 @@ from tenacity import (
     wait_random_exponential,
 )

-from khoj.processor.conversation.utils import ThreadedGenerator
+from khoj.processor.conversation.utils import (
+    ThreadedGenerator,
+    commit_conversation_trace,
+)
+from khoj.utils import state
+from khoj.utils.helpers import in_debug_mode

 logger = logging.getLogger(__name__)

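
commit_conversation_trace itself is imported from khoj.processor.conversation.utils and is not part of this diff; going only by the call sites below, it receives the message history, the final aggregated response, and the tracer metadata dict. A hypothetical stand-in with the same shape, useful for reading the diff in isolation (the real implementation may persist traces entirely differently):

    import json
    import logging

    logger = logging.getLogger(__name__)


    def commit_conversation_trace(messages, response: str, tracer: dict) -> None:
        # Hypothetical stand-in, not khoj code: serialize the exchange plus
        # tracer metadata so the trace can be inspected in debug logs.
        trace = {
            "messages": [{"role": m.role, "content": m.content} for m in messages],
            "response": response,
            **tracer,
        }
        logger.debug("conversation trace: %s", json.dumps(trace, default=str))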
@@ -33,7 +38,7 @@ openai_clients: Dict[str, openai.OpenAI] = {}
     reraise=True,
 )
 def completion_with_backoff(
-    messages, model, temperature=0, openai_api_key=None, api_base_url=None, model_kwargs=None
+    messages, model, temperature=0, openai_api_key=None, api_base_url=None, model_kwargs=None, tracer: dict = {}
 ) -> str:
     client_key = f"{openai_api_key}--{api_base_url}"
     client: openai.OpenAI | None = openai_clients.get(client_key)
@@ -77,6 +82,12 @@ def completion_with_backoff(
         elif delta_chunk.content:
             aggregated_response += delta_chunk.content

+    # Save conversation trace
+    tracer["chat_model"] = model
+    tracer["temperature"] = temperature
+    if in_debug_mode() or state.verbose > 1:
+        commit_conversation_trace(messages, aggregated_response, tracer)
+
     return aggregated_response

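
One Python detail worth keeping in mind when reading these signatures: tracer: dict = {} is a mutable default argument, so every call that omits tracer shares a single dict across calls. A minimal demonstration of that language behavior (not khoj code):

    def stamp(tracer: dict = {}):
        tracer["temperature"] = 0  # mutates the shared default when tracer is omitted
        return tracer

    a = stamp()
    b = stamp()
    assert a is b  # both calls returned the very same default dict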
@@ -103,26 +114,37 @@ def chat_completion_with_backoff(
     api_base_url=None,
     completion_func=None,
     model_kwargs=None,
+    tracer: dict = {},
 ):
     g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
     t = Thread(
-        target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key, api_base_url, model_kwargs)
+        target=llm_thread,
+        args=(g, messages, model_name, temperature, openai_api_key, api_base_url, model_kwargs, tracer),
     )
     t.start()
     return g


-def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_base_url=None, model_kwargs=None):
+def llm_thread(
+    g,
+    messages,
+    model_name,
+    temperature,
+    openai_api_key=None,
+    api_base_url=None,
+    model_kwargs=None,
+    tracer: dict = {},
+):
     try:
         client_key = f"{openai_api_key}--{api_base_url}"
         if client_key not in openai_clients:
-            client: openai.OpenAI = openai.OpenAI(
+            client = openai.OpenAI(
                 api_key=openai_api_key,
                 base_url=api_base_url,
             )
             openai_clients[client_key] = client
         else:
-            client: openai.OpenAI = openai_clients[client_key]
+            client = openai_clients[client_key]

         formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
         stream = True
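
chat_completion_with_backoff hands the tracer dict to a worker thread, relying on the fact that in-place dict mutation is visible across threads once the worker has run. A self-contained sketch of that pattern, with join() standing in for draining the ThreadedGenerator (stand-in names, not khoj code):

    from threading import Thread


    def worker(tracer: dict):
        tracer["chat_model"] = "example-model"  # illustrative value


    tracer: dict = {}
    t = Thread(target=worker, args=(tracer,))
    t.start()
    t.join()  # once the worker finishes, its writes are visible to the caller
    assert tracer == {"chat_model": "example-model"}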
@@ -144,17 +166,29 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_ba
             **(model_kwargs or dict()),
         )

+        aggregated_response = ""
         if not stream:
-            g.send(chat.choices[0].message.content)
+            aggregated_response = chat.choices[0].message.content
+            g.send(aggregated_response)
         else:
             for chunk in chat:
                 if len(chunk.choices) == 0:
                     continue
                 delta_chunk = chunk.choices[0].delta
+                text_chunk = ""
                 if isinstance(delta_chunk, str):
-                    g.send(delta_chunk)
+                    text_chunk = delta_chunk
                 elif delta_chunk.content:
-                    g.send(delta_chunk.content)
+                    text_chunk = delta_chunk.content
+                if text_chunk:
+                    aggregated_response += text_chunk
+                    g.send(text_chunk)
+
+        # Save conversation trace
+        tracer["chat_model"] = model_name
+        tracer["temperature"] = temperature
+        if in_debug_mode() or state.verbose > 1:
+            commit_conversation_trace(messages, aggregated_response, tracer)
     except Exception as e:
         logger.error(f"Error in llm_thread: {e}", exc_info=True)
     finally:
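
The core behavioral change in llm_thread: streamed chunks used to be forwarded to the generator and discarded, whereas now each text_chunk is appended to aggregated_response before being sent, so the complete reply exists for the trace commit once the stream ends. The aggregation logic, reduced to a runnable sketch:

    chunks = ["Hello", ", ", "world"]  # stand-in for delta chunks from the API
    aggregated_response = ""
    for text_chunk in chunks:
        if text_chunk:  # skip empty deltas
            aggregated_response += text_chunk
            print(text_chunk, end="")  # stand-in for g.send(text_chunk)
    assert aggregated_response == "Hello, world"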