Mirror of https://github.com/khoj-ai/khoj.git, synced 2024-11-23 15:38:55 +01:00
Allow OpenAI API calling functions to save conversation traces
parent 10c8fd3b2a
commit 384f394336
2 changed files with 57 additions and 11 deletions
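The change threads a mutable `tracer` dict through the OpenAI call path so that the chat model, temperature, and final response can be recorded via `commit_conversation_trace` when debug mode or verbose logging is enabled. Below is a minimal, self-contained sketch of that pattern; it is not the khoj code itself. `save_trace`, `DEBUG`, and the stand-in `completion`/`converse` functions are placeholders for khoj's `commit_conversation_trace`, the `in_debug_mode()`/`state.verbose` check, and the functions changed in this commit.

import json
from datetime import datetime, timezone

DEBUG = True  # stands in for khoj's `in_debug_mode() or state.verbose > 1` check


def save_trace(messages: list[dict], response: str, tracer: dict) -> None:
    # Placeholder for commit_conversation_trace(): append the exchange plus tracer metadata to a JSONL file.
    record = {
        "messages": messages,
        "response": response,
        "timestamp": datetime.now(timezone.utc).isoformat(),
        **tracer,
    }
    with open("conversation_traces.jsonl", "a") as f:
        f.write(json.dumps(record) + "\n")


def completion(messages: list[dict], model: str, temperature: float = 0, tracer: dict = {}) -> str:
    # Stand-in for completion_with_backoff(): produce a response, record call metadata, save the trace.
    response = f"echo: {messages[-1]['content']}"  # a real OpenAI API call would go here
    tracer["chat_model"] = model  # same keys the diff writes
    tracer["temperature"] = temperature
    if DEBUG:
        save_trace(messages, response, tracer)
    return response


def converse(query: str, tracer: dict = {}) -> str:
    # Higher-level caller, mirroring how converse()/extract_questions() forward the same dict downward.
    messages = [{"role": "user", "content": query}]
    return completion(messages, model="gpt-4o-mini", temperature=0.2, tracer=tracer)


tracer = {"session": "demo"}  # caller-supplied metadata; the keys here are illustrative
print(converse("What's on my calendar?", tracer=tracer))
print(tracer)  # now also contains chat_model and temperature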
@@ -33,6 +33,7 @@ def extract_questions(
     query_images: Optional[list[str]] = None,
     vision_enabled: bool = False,
     personality_context: Optional[str] = None,
+    tracer: dict = {},
 ):
     """
     Infer search queries to retrieve relevant notes to answer user query
@@ -82,7 +83,13 @@ def extract_questions(
     messages = [ChatMessage(content=prompt, role="user")]
 
     response = send_message_to_model(
-        messages, api_key, model, response_type="json_object", api_base_url=api_base_url, temperature=temperature
+        messages,
+        api_key,
+        model,
+        response_type="json_object",
+        api_base_url=api_base_url,
+        temperature=temperature,
+        tracer=tracer,
     )
 
     # Extract, Clean Message from GPT's Response
@@ -103,7 +110,9 @@ def extract_questions(
     return questions
 
 
-def send_message_to_model(messages, api_key, model, response_type="text", api_base_url=None, temperature=0):
+def send_message_to_model(
+    messages, api_key, model, response_type="text", api_base_url=None, temperature=0, tracer: dict = {}
+):
     """
     Send message to model
     """
@@ -116,6 +125,7 @@ def send_message_to_model(messages, api_key, model, response_type="text", api_base_url=None, temperature=0):
         temperature=temperature,
         api_base_url=api_base_url,
         model_kwargs={"response_format": {"type": response_type}},
+        tracer=tracer,
     )
 
 
@@ -137,6 +147,7 @@ def converse(
     agent: Agent = None,
     query_images: Optional[list[str]] = None,
     vision_available: bool = False,
+    tracer: dict = {},
 ):
     """
     Converse with user using OpenAI's ChatGPT
@@ -209,4 +220,5 @@ def converse(
         api_base_url=api_base_url,
         completion_func=completion_func,
         model_kwargs={"stop": ["Notes:\n["]},
+        tracer=tracer,
     )
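Both files give the new parameter a mutable default, `tracer: dict = {}`. Python evaluates that default once at definition time, so every call that omits `tracer` shares the same dict object, which is worth keeping in mind while reading the second diff below. A standalone demonstration of that behaviour (not khoj code):

def record(event: str, tracer: dict = {}) -> dict:
    # The default dict is created once, so calls that omit `tracer` all mutate the same object.
    tracer[event] = tracer.get(event, 0) + 1
    return tracer


print(record("a"))             # {'a': 1}
print(record("b"))             # {'a': 1, 'b': 1}  <- state carried over from the first call
print(record("c", tracer={}))  # {'c': 1}          <- an explicit dict avoids the sharing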
@@ -12,7 +12,12 @@ from tenacity import (
     wait_random_exponential,
 )
 
-from khoj.processor.conversation.utils import ThreadedGenerator
+from khoj.processor.conversation.utils import (
+    ThreadedGenerator,
+    commit_conversation_trace,
+)
 from khoj.utils import state
+from khoj.utils.helpers import in_debug_mode
 
 logger = logging.getLogger(__name__)
 
@@ -33,7 +38,7 @@ openai_clients: Dict[str, openai.OpenAI] = {}
     reraise=True,
 )
 def completion_with_backoff(
-    messages, model, temperature=0, openai_api_key=None, api_base_url=None, model_kwargs=None
+    messages, model, temperature=0, openai_api_key=None, api_base_url=None, model_kwargs=None, tracer: dict = {}
 ) -> str:
     client_key = f"{openai_api_key}--{api_base_url}"
     client: openai.OpenAI | None = openai_clients.get(client_key)
@@ -77,6 +82,12 @@ def completion_with_backoff(
         elif delta_chunk.content:
             aggregated_response += delta_chunk.content
 
+    # Save conversation trace
+    tracer["chat_model"] = model
+    tracer["temperature"] = temperature
+    if in_debug_mode() or state.verbose > 1:
+        commit_conversation_trace(messages, aggregated_response, tracer)
+
     return aggregated_response
 
 
@@ -103,26 +114,37 @@ def chat_completion_with_backoff(
     api_base_url=None,
     completion_func=None,
     model_kwargs=None,
+    tracer: dict = {},
 ):
     g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
     t = Thread(
-        target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key, api_base_url, model_kwargs)
+        target=llm_thread,
+        args=(g, messages, model_name, temperature, openai_api_key, api_base_url, model_kwargs, tracer),
     )
     t.start()
     return g
 
 
-def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_base_url=None, model_kwargs=None):
+def llm_thread(
+    g,
+    messages,
+    model_name,
+    temperature,
+    openai_api_key=None,
+    api_base_url=None,
+    model_kwargs=None,
+    tracer: dict = {},
+):
     try:
         client_key = f"{openai_api_key}--{api_base_url}"
         if client_key not in openai_clients:
-            client: openai.OpenAI = openai.OpenAI(
+            client = openai.OpenAI(
                 api_key=openai_api_key,
                 base_url=api_base_url,
            )
             openai_clients[client_key] = client
         else:
-            client: openai.OpenAI = openai_clients[client_key]
+            client = openai_clients[client_key]
 
         formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
         stream = True
@@ -144,17 +166,29 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_base_url=None, model_kwargs=None):
             **(model_kwargs or dict()),
         )
 
+        aggregated_response = ""
         if not stream:
-            g.send(chat.choices[0].message.content)
+            aggregated_response = chat.choices[0].message.content
+            g.send(aggregated_response)
         else:
             for chunk in chat:
                 if len(chunk.choices) == 0:
                     continue
                 delta_chunk = chunk.choices[0].delta
+                text_chunk = ""
                 if isinstance(delta_chunk, str):
-                    g.send(delta_chunk)
+                    text_chunk = delta_chunk
                 elif delta_chunk.content:
-                    g.send(delta_chunk.content)
+                    text_chunk = delta_chunk.content
+                if text_chunk:
+                    aggregated_response += text_chunk
+                    g.send(text_chunk)
+
+        # Save conversation trace
+        tracer["chat_model"] = model_name
+        tracer["temperature"] = temperature
+        if in_debug_mode() or state.verbose > 1:
+            commit_conversation_trace(messages, aggregated_response, tracer)
     except Exception as e:
         logger.error(f"Error in llm_thread: {e}", exc_info=True)
     finally:
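In the second file, chat_completion_with_backoff hands the same `tracer` dict to the background `llm_thread`, so the caller's dict is populated asynchronously while the response streams. A small stdlib sketch of that hand-off (the stub below is illustrative, not khoj's `llm_thread`):

import threading


def llm_thread_stub(out_chunks: list, messages: list[dict], model_name: str, temperature: float, tracer: dict) -> None:
    # Stand-in for llm_thread(): emit a canned streamed response, then record trace metadata
    # on the dict the caller passed in (mirroring tracer["chat_model"] / tracer["temperature"]).
    response = "hello " + messages[-1]["content"]
    for token in response.split():
        out_chunks.append(token)  # plays the role of g.send(text_chunk)
    tracer["chat_model"] = model_name
    tracer["temperature"] = temperature


tracer: dict = {}
chunks: list = []
t = threading.Thread(
    target=llm_thread_stub,
    args=(chunks, [{"role": "user", "content": "world"}], "gpt-4o-mini", 0.2, tracer),
)
t.start()
t.join()       # the tracer is only guaranteed to be filled in once the worker thread finishes
print(chunks)  # ['hello', 'world']
print(tracer)  # {'chat_model': 'gpt-4o-mini', 'temperature': 0.2}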