Mirror of https://github.com/khoj-ai/khoj.git, synced 2024-11-23 15:38:55 +01:00
Add Prompt Tracer to Visualize, Analyze and Debug Khoj's Train of Thought (#951)
## Overview

Use git to capture prompt traces of Khoj's train of thought. View, analyze and debug them using your favorite git client (e.g. VS Code, Magit).

- Each commit captures an interaction with an LLM. The commit writes the query, response and system message each to a separate file in the repo. The commit message captures the chat model, Khoj version and other metadata.
- Each conversation turn can have multiple interactions with an LLM (e.g. Khoj's train of thought).
- Each new conversation turn forks from and merges back into its conversation branch.
- Each new conversation branches from the user branch.
- Each new user branches from the root commit on the main branch.

## Usage

1. Set `KHOJ_DEBUG=true` or start Khoj in very verbose mode with `khoj -vv` to turn on prompt tracing.
2. Chat with Khoj as usual.
3. Open the promptrace git repo to view the generated prompt traces using your favorite git porcelain.

The Khoj prompt trace git repo is created at `/tmp/khoj_promptrace` by default. You can configure the prompt trace directory by setting the `PROMPTRACE_DIR` environment variable.

## Implementation

- Add utility functions to capture prompt traces using git (via `gitpython`).
- Make each model provider in Khoj commit their LLM interactions with promptrace.
- Weave chat metadata from the chat API through all chat actors and commit it to the prompt trace.
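For illustration only (this sketch is not part of the change itself), the generated trace repo can also be browsed programmatically with `gitpython`. The repo path, the `u_*`/`c_*` branch naming and the per-commit `query`, `response` and `system_prompt` files below follow the utility functions added by this PR; point `PROMPTRACE_DIR` at wherever your traces are actually written.

```python
# Minimal sketch: browse a Khoj prompt trace repo with gitpython.
# Assumes tracing has already produced a repo at PROMPTRACE_DIR.
import os

from git import Repo

repo_path = os.getenv("PROMPTRACE_DIR", "/tmp/promptrace")  # assumed default path
repo = Repo(repo_path)

# Conversation branches are named c_<conversation id>, user branches u_<user id>.
for branch in repo.branches:
    if not branch.name.startswith("c_"):
        continue
    print(f"Conversation branch: {branch.name}")
    # Each commit on the branch is one LLM interaction or one merged conversation turn.
    for commit in repo.iter_commits(branch.name):
        summary = commit.message.splitlines()[0] if commit.message else ""
        print(f"  {commit.hexsha[:8]} {summary[:80]}")
        # The query, response and system_prompt files hold the traced payloads.
        for filename in ("query", "response", "system_prompt"):
            try:
                blob = commit.tree / filename
            except KeyError:
                continue
            text = blob.data_stream.read().decode("utf-8", errors="replace")
            print(f"    {filename}: {text[:60]!r}")
```

Each commit on a conversation branch corresponds to one LLM interaction or one merged conversation turn, so ordinary git tooling (log, diff, blame) works for comparing prompts across turns.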
Commit: 1c920273dd
14 changed files with 439 additions and 57 deletions
|
@@ -119,6 +119,7 @@ dev = [
    "mypy >= 1.0.1",
    "black >= 23.1.0",
    "pre-commit >= 3.0.4",
    "gitpython ~= 3.1.43",
]

[tool.hatch.version]

@ -34,6 +34,7 @@ def extract_questions_anthropic(
|
|||
query_images: Optional[list[str]] = None,
|
||||
vision_enabled: bool = False,
|
||||
personality_context: Optional[str] = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
"""
|
||||
Infer search queries to retrieve relevant notes to answer user query
|
||||
|
@ -89,6 +90,7 @@ def extract_questions_anthropic(
|
|||
model_name=model,
|
||||
temperature=temperature,
|
||||
api_key=api_key,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
# Extract, Clean Message from Claude's Response
|
||||
|
@ -110,7 +112,7 @@ def extract_questions_anthropic(
|
|||
return questions
|
||||
|
||||
|
||||
def anthropic_send_message_to_model(messages, api_key, model):
|
||||
def anthropic_send_message_to_model(messages, api_key, model, tracer={}):
|
||||
"""
|
||||
Send message to model
|
||||
"""
|
||||
|
@ -122,6 +124,7 @@ def anthropic_send_message_to_model(messages, api_key, model):
|
|||
system_prompt=system_prompt,
|
||||
model_name=model,
|
||||
api_key=api_key,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
|
||||
|
@ -141,6 +144,7 @@ def converse_anthropic(
|
|||
agent: Agent = None,
|
||||
query_images: Optional[list[str]] = None,
|
||||
vision_available: bool = False,
|
||||
tracer: dict = {},
|
||||
):
|
||||
"""
|
||||
Converse with user using Anthropic's Claude
|
||||
|
@ -213,4 +217,5 @@ def converse_anthropic(
|
|||
system_prompt=system_prompt,
|
||||
completion_func=completion_func,
|
||||
max_prompt_size=max_prompt_size,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
|
|
@ -12,8 +12,13 @@ from tenacity import (
|
|||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from khoj.processor.conversation.utils import ThreadedGenerator, get_image_from_url
|
||||
from khoj.utils.helpers import is_none_or_empty
|
||||
from khoj.processor.conversation.utils import (
|
||||
ThreadedGenerator,
|
||||
commit_conversation_trace,
|
||||
get_image_from_url,
|
||||
)
|
||||
from khoj.utils import state
|
||||
from khoj.utils.helpers import in_debug_mode, is_none_or_empty
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -30,7 +35,7 @@ DEFAULT_MAX_TOKENS_ANTHROPIC = 3000
|
|||
reraise=True,
|
||||
)
|
||||
def anthropic_completion_with_backoff(
|
||||
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, max_tokens=None
|
||||
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, max_tokens=None, tracer={}
|
||||
) -> str:
|
||||
if api_key not in anthropic_clients:
|
||||
client: anthropic.Anthropic = anthropic.Anthropic(api_key=api_key)
|
||||
|
@@ -58,6 +63,12 @@ def anthropic_completion_with_backoff(
        for text in stream.text_stream:
            aggregated_response += text

    # Save conversation trace
    tracer["chat_model"] = model_name
    tracer["temperature"] = temperature
    if in_debug_mode() or state.verbose > 1:
        commit_conversation_trace(messages, aggregated_response, tracer)

    return aggregated_response

@ -78,18 +89,19 @@ def anthropic_chat_completion_with_backoff(
|
|||
max_prompt_size=None,
|
||||
completion_func=None,
|
||||
model_kwargs=None,
|
||||
tracer={},
|
||||
):
|
||||
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
|
||||
t = Thread(
|
||||
target=anthropic_llm_thread,
|
||||
args=(g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size, model_kwargs),
|
||||
args=(g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size, model_kwargs, tracer),
|
||||
)
|
||||
t.start()
|
||||
return g
|
||||
|
||||
|
||||
def anthropic_llm_thread(
|
||||
g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size=None, model_kwargs=None
|
||||
g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size=None, model_kwargs=None, tracer={}
|
||||
):
|
||||
try:
|
||||
if api_key not in anthropic_clients:
|
||||
|
@ -102,6 +114,7 @@ def anthropic_llm_thread(
|
|||
anthropic.types.MessageParam(role=message.role, content=message.content) for message in messages
|
||||
]
|
||||
|
||||
aggregated_response = ""
|
||||
with client.messages.stream(
|
||||
messages=formatted_messages,
|
||||
model=model_name, # type: ignore
|
||||
|
@ -112,7 +125,14 @@ def anthropic_llm_thread(
|
|||
**(model_kwargs or dict()),
|
||||
) as stream:
|
||||
for text in stream.text_stream:
|
||||
aggregated_response += text
|
||||
g.send(text)
|
||||
|
||||
# Save conversation trace
|
||||
tracer["chat_model"] = model_name
|
||||
tracer["temperature"] = temperature
|
||||
if in_debug_mode() or state.verbose > 1:
|
||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in anthropic_llm_thread: {e}", exc_info=True)
|
||||
finally:
|
||||
|
|
|
@ -35,6 +35,7 @@ def extract_questions_gemini(
|
|||
query_images: Optional[list[str]] = None,
|
||||
vision_enabled: bool = False,
|
||||
personality_context: Optional[str] = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
"""
|
||||
Infer search queries to retrieve relevant notes to answer user query
|
||||
|
@ -85,7 +86,7 @@ def extract_questions_gemini(
|
|||
messages = [ChatMessage(content=prompt, role="user"), ChatMessage(content=system_prompt, role="system")]
|
||||
|
||||
response = gemini_send_message_to_model(
|
||||
messages, api_key, model, response_type="json_object", temperature=temperature
|
||||
messages, api_key, model, response_type="json_object", temperature=temperature, tracer=tracer
|
||||
)
|
||||
|
||||
# Extract, Clean Message from Gemini's Response
|
||||
|
@ -107,7 +108,9 @@ def extract_questions_gemini(
|
|||
return questions
|
||||
|
||||
|
||||
def gemini_send_message_to_model(messages, api_key, model, response_type="text", temperature=0, model_kwargs=None):
|
||||
def gemini_send_message_to_model(
|
||||
messages, api_key, model, response_type="text", temperature=0, model_kwargs=None, tracer={}
|
||||
):
|
||||
"""
|
||||
Send message to model
|
||||
"""
|
||||
|
@ -125,6 +128,7 @@ def gemini_send_message_to_model(messages, api_key, model, response_type="text",
|
|||
api_key=api_key,
|
||||
temperature=temperature,
|
||||
model_kwargs=model_kwargs,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
|
||||
|
@ -145,6 +149,7 @@ def converse_gemini(
|
|||
agent: Agent = None,
|
||||
query_images: Optional[list[str]] = None,
|
||||
vision_available: bool = False,
|
||||
tracer={},
|
||||
):
|
||||
"""
|
||||
Converse with user using Google's Gemini
|
||||
|
@ -217,4 +222,5 @@ def converse_gemini(
|
|||
api_key=api_key,
|
||||
system_prompt=system_prompt,
|
||||
completion_func=completion_func,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
|
|
@ -19,8 +19,13 @@ from tenacity import (
|
|||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from khoj.processor.conversation.utils import ThreadedGenerator, get_image_from_url
|
||||
from khoj.utils.helpers import is_none_or_empty
|
||||
from khoj.processor.conversation.utils import (
|
||||
ThreadedGenerator,
|
||||
commit_conversation_trace,
|
||||
get_image_from_url,
|
||||
)
|
||||
from khoj.utils import state
|
||||
from khoj.utils.helpers import in_debug_mode, is_none_or_empty
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -35,7 +40,7 @@ MAX_OUTPUT_TOKENS_GEMINI = 8192
|
|||
reraise=True,
|
||||
)
|
||||
def gemini_completion_with_backoff(
|
||||
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None
|
||||
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, tracer={}
|
||||
) -> str:
|
||||
genai.configure(api_key=api_key)
|
||||
model_kwargs = model_kwargs or dict()
|
||||
|
@ -60,16 +65,23 @@ def gemini_completion_with_backoff(
|
|||
|
||||
try:
|
||||
# Generate the response. The last message is considered to be the current prompt
|
||||
aggregated_response = chat_session.send_message(formatted_messages[-1]["parts"])
|
||||
return aggregated_response.text
|
||||
response = chat_session.send_message(formatted_messages[-1]["parts"])
|
||||
response_text = response.text
|
||||
except StopCandidateException as e:
|
||||
response_message, _ = handle_gemini_response(e.args)
|
||||
response_text, _ = handle_gemini_response(e.args)
|
||||
# Respond with reason for stopping
|
||||
logger.warning(
|
||||
f"LLM Response Prevented for {model_name}: {response_message}.\n"
|
||||
f"LLM Response Prevented for {model_name}: {response_text}.\n"
|
||||
+ f"Last Message by {messages[-1].role}: {messages[-1].content}"
|
||||
)
|
||||
return response_message
|
||||
|
||||
# Save conversation trace
|
||||
tracer["chat_model"] = model_name
|
||||
tracer["temperature"] = temperature
|
||||
if in_debug_mode() or state.verbose > 1:
|
||||
commit_conversation_trace(messages, response_text, tracer)
|
||||
|
||||
return response_text
|
||||
|
||||
|
||||
@retry(
|
||||
|
@ -88,17 +100,20 @@ def gemini_chat_completion_with_backoff(
|
|||
system_prompt,
|
||||
completion_func=None,
|
||||
model_kwargs=None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
|
||||
t = Thread(
|
||||
target=gemini_llm_thread,
|
||||
args=(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs),
|
||||
args=(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs, tracer),
|
||||
)
|
||||
t.start()
|
||||
return g
|
||||
|
||||
|
||||
def gemini_llm_thread(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs=None):
|
||||
def gemini_llm_thread(
|
||||
g, messages, system_prompt, model_name, temperature, api_key, model_kwargs=None, tracer: dict = {}
|
||||
):
|
||||
try:
|
||||
genai.configure(api_key=api_key)
|
||||
model_kwargs = model_kwargs or dict()
|
||||
|
@ -117,16 +132,25 @@ def gemini_llm_thread(g, messages, system_prompt, model_name, temperature, api_k
|
|||
},
|
||||
)
|
||||
|
||||
aggregated_response = ""
|
||||
formatted_messages = [{"role": message.role, "parts": message.content} for message in messages]
|
||||
|
||||
# all messages up to the last are considered to be part of the chat history
|
||||
chat_session = model.start_chat(history=formatted_messages[0:-1])
|
||||
# the last message is considered to be the current prompt
|
||||
for chunk in chat_session.send_message(formatted_messages[-1]["parts"], stream=True):
|
||||
message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
|
||||
message = message or chunk.text
|
||||
aggregated_response += message
|
||||
g.send(message)
|
||||
if stopped:
|
||||
raise StopCandidateException(message)
|
||||
|
||||
# Save conversation trace
|
||||
tracer["chat_model"] = model_name
|
||||
tracer["temperature"] = temperature
|
||||
if in_debug_mode() or state.verbose > 1:
|
||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||
except StopCandidateException as e:
|
||||
logger.warning(
|
||||
f"LLM Response Prevented for {model_name}: {e.args[0]}.\n"
|
||||
|
|
|
@ -12,11 +12,12 @@ from khoj.processor.conversation import prompts
|
|||
from khoj.processor.conversation.offline.utils import download_model
|
||||
from khoj.processor.conversation.utils import (
|
||||
ThreadedGenerator,
|
||||
commit_conversation_trace,
|
||||
generate_chatml_messages_with_context,
|
||||
)
|
||||
from khoj.utils import state
|
||||
from khoj.utils.constants import empty_escape_sequences
|
||||
from khoj.utils.helpers import ConversationCommand, is_none_or_empty
|
||||
from khoj.utils.helpers import ConversationCommand, in_debug_mode, is_none_or_empty
|
||||
from khoj.utils.rawconfig import LocationData
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -34,6 +35,7 @@ def extract_questions_offline(
|
|||
max_prompt_size: int = None,
|
||||
temperature: float = 0.7,
|
||||
personality_context: Optional[str] = None,
|
||||
tracer: dict = {},
|
||||
) -> List[str]:
|
||||
"""
|
||||
Infer search queries to retrieve relevant notes to answer user query
|
||||
|
@ -94,6 +96,7 @@ def extract_questions_offline(
|
|||
max_prompt_size=max_prompt_size,
|
||||
temperature=temperature,
|
||||
response_type="json_object",
|
||||
tracer=tracer,
|
||||
)
|
||||
finally:
|
||||
state.chat_lock.release()
|
||||
|
@ -146,6 +149,7 @@ def converse_offline(
|
|||
location_data: LocationData = None,
|
||||
user_name: str = None,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
) -> Union[ThreadedGenerator, Iterator[str]]:
|
||||
"""
|
||||
Converse with user using Llama
|
||||
|
@ -153,8 +157,9 @@ def converse_offline(
|
|||
# Initialize Variables
|
||||
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
|
||||
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
|
||||
compiled_references = "\n\n".join({f"# File: {item['file']}\n## {item['compiled']}\n" for item in references})
|
||||
tracer["chat_model"] = model
|
||||
|
||||
compiled_references = "\n\n".join({f"# File: {item['file']}\n## {item['compiled']}\n" for item in references})
|
||||
current_date = datetime.now()
|
||||
|
||||
if agent and agent.personality:
|
||||
|
@ -215,13 +220,14 @@ def converse_offline(
|
|||
logger.debug(f"Conversation Context for {model}: {truncated_messages}")
|
||||
|
||||
g = ThreadedGenerator(references, online_results, completion_func=completion_func)
|
||||
t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size))
|
||||
t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size, tracer))
|
||||
t.start()
|
||||
return g
|
||||
|
||||
|
||||
def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None):
|
||||
def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None, tracer: dict = {}):
|
||||
stop_phrases = ["<s>", "INST]", "Notes:"]
|
||||
aggregated_response = ""
|
||||
|
||||
state.chat_lock.acquire()
|
||||
try:
|
||||
|
@ -229,7 +235,14 @@ def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int
|
|||
messages, loaded_model=model, stop=stop_phrases, max_prompt_size=max_prompt_size, streaming=True
|
||||
)
|
||||
for response in response_iterator:
|
||||
g.send(response["choices"][0]["delta"].get("content", ""))
|
||||
response_delta = response["choices"][0]["delta"].get("content", "")
|
||||
aggregated_response += response_delta
|
||||
g.send(response_delta)
|
||||
|
||||
# Save conversation trace
|
||||
if in_debug_mode() or state.verbose > 1:
|
||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||
|
||||
finally:
|
||||
state.chat_lock.release()
|
||||
g.close()
|
||||
|
@ -244,6 +257,7 @@ def send_message_to_model_offline(
|
|||
stop=[],
|
||||
max_prompt_size: int = None,
|
||||
response_type: str = "text",
|
||||
tracer: dict = {},
|
||||
):
|
||||
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
|
||||
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
|
||||
|
@ -251,7 +265,17 @@ def send_message_to_model_offline(
|
|||
response = offline_chat_model.create_chat_completion(
|
||||
messages_dict, stop=stop, stream=streaming, temperature=temperature, response_format={"type": response_type}
|
||||
)
|
||||
|
||||
if streaming:
|
||||
return response
|
||||
else:
|
||||
return response["choices"][0]["message"].get("content", "")
|
||||
|
||||
response_text = response["choices"][0]["message"].get("content", "")
|
||||
|
||||
# Save conversation trace for non-streaming responses
|
||||
# Streamed responses need to be saved by the calling function
|
||||
tracer["chat_model"] = model
|
||||
tracer["temperature"] = temperature
|
||||
if in_debug_mode() or state.verbose > 1:
|
||||
commit_conversation_trace(messages, response_text, tracer)
|
||||
|
||||
return response_text
|
||||
|
|
|
@ -33,6 +33,7 @@ def extract_questions(
|
|||
query_images: Optional[list[str]] = None,
|
||||
vision_enabled: bool = False,
|
||||
personality_context: Optional[str] = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
"""
|
||||
Infer search queries to retrieve relevant notes to answer user query
|
||||
|
@ -82,7 +83,13 @@ def extract_questions(
|
|||
messages = [ChatMessage(content=prompt, role="user")]
|
||||
|
||||
response = send_message_to_model(
|
||||
messages, api_key, model, response_type="json_object", api_base_url=api_base_url, temperature=temperature
|
||||
messages,
|
||||
api_key,
|
||||
model,
|
||||
response_type="json_object",
|
||||
api_base_url=api_base_url,
|
||||
temperature=temperature,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
# Extract, Clean Message from GPT's Response
|
||||
|
@ -103,7 +110,9 @@ def extract_questions(
|
|||
return questions
|
||||
|
||||
|
||||
def send_message_to_model(messages, api_key, model, response_type="text", api_base_url=None, temperature=0):
|
||||
def send_message_to_model(
|
||||
messages, api_key, model, response_type="text", api_base_url=None, temperature=0, tracer: dict = {}
|
||||
):
|
||||
"""
|
||||
Send message to model
|
||||
"""
|
||||
|
@ -116,6 +125,7 @@ def send_message_to_model(messages, api_key, model, response_type="text", api_ba
|
|||
temperature=temperature,
|
||||
api_base_url=api_base_url,
|
||||
model_kwargs={"response_format": {"type": response_type}},
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
|
||||
|
@ -137,6 +147,7 @@ def converse(
|
|||
agent: Agent = None,
|
||||
query_images: Optional[list[str]] = None,
|
||||
vision_available: bool = False,
|
||||
tracer: dict = {},
|
||||
):
|
||||
"""
|
||||
Converse with user using OpenAI's ChatGPT
|
||||
|
@ -207,4 +218,5 @@ def converse(
|
|||
api_base_url=api_base_url,
|
||||
completion_func=completion_func,
|
||||
model_kwargs={"stop": ["Notes:\n["]},
|
||||
tracer=tracer,
|
||||
)
|
||||
|
|
|
@ -12,7 +12,12 @@ from tenacity import (
|
|||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from khoj.processor.conversation.utils import ThreadedGenerator
|
||||
from khoj.processor.conversation.utils import (
|
||||
ThreadedGenerator,
|
||||
commit_conversation_trace,
|
||||
)
|
||||
from khoj.utils import state
|
||||
from khoj.utils.helpers import in_debug_mode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -33,7 +38,7 @@ openai_clients: Dict[str, openai.OpenAI] = {}
|
|||
reraise=True,
|
||||
)
|
||||
def completion_with_backoff(
|
||||
messages, model, temperature=0, openai_api_key=None, api_base_url=None, model_kwargs=None
|
||||
messages, model, temperature=0, openai_api_key=None, api_base_url=None, model_kwargs=None, tracer: dict = {}
|
||||
) -> str:
|
||||
client_key = f"{openai_api_key}--{api_base_url}"
|
||||
client: openai.OpenAI | None = openai_clients.get(client_key)
|
||||
|
@ -77,6 +82,12 @@ def completion_with_backoff(
|
|||
elif delta_chunk.content:
|
||||
aggregated_response += delta_chunk.content
|
||||
|
||||
# Save conversation trace
|
||||
tracer["chat_model"] = model
|
||||
tracer["temperature"] = temperature
|
||||
if in_debug_mode() or state.verbose > 1:
|
||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||
|
||||
return aggregated_response
|
||||
|
||||
|
||||
|
@ -103,26 +114,37 @@ def chat_completion_with_backoff(
|
|||
api_base_url=None,
|
||||
completion_func=None,
|
||||
model_kwargs=None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
|
||||
t = Thread(
|
||||
target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key, api_base_url, model_kwargs)
|
||||
target=llm_thread,
|
||||
args=(g, messages, model_name, temperature, openai_api_key, api_base_url, model_kwargs, tracer),
|
||||
)
|
||||
t.start()
|
||||
return g
|
||||
|
||||
|
||||
def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_base_url=None, model_kwargs=None):
|
||||
def llm_thread(
|
||||
g,
|
||||
messages,
|
||||
model_name,
|
||||
temperature,
|
||||
openai_api_key=None,
|
||||
api_base_url=None,
|
||||
model_kwargs=None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
try:
|
||||
client_key = f"{openai_api_key}--{api_base_url}"
|
||||
if client_key not in openai_clients:
|
||||
client: openai.OpenAI = openai.OpenAI(
|
||||
client = openai.OpenAI(
|
||||
api_key=openai_api_key,
|
||||
base_url=api_base_url,
|
||||
)
|
||||
openai_clients[client_key] = client
|
||||
else:
|
||||
client: openai.OpenAI = openai_clients[client_key]
|
||||
client = openai_clients[client_key]
|
||||
|
||||
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
|
||||
stream = True
|
||||
|
@ -144,17 +166,29 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_ba
|
|||
**(model_kwargs or dict()),
|
||||
)
|
||||
|
||||
aggregated_response = ""
|
||||
if not stream:
|
||||
g.send(chat.choices[0].message.content)
|
||||
aggregated_response = chat.choices[0].message.content
|
||||
g.send(aggregated_response)
|
||||
else:
|
||||
for chunk in chat:
|
||||
if len(chunk.choices) == 0:
|
||||
continue
|
||||
delta_chunk = chunk.choices[0].delta
|
||||
text_chunk = ""
|
||||
if isinstance(delta_chunk, str):
|
||||
g.send(delta_chunk)
|
||||
text_chunk = delta_chunk
|
||||
elif delta_chunk.content:
|
||||
g.send(delta_chunk.content)
|
||||
text_chunk = delta_chunk.content
|
||||
if text_chunk:
|
||||
aggregated_response += text_chunk
|
||||
g.send(text_chunk)
|
||||
|
||||
# Save conversation trace
|
||||
tracer["chat_model"] = model_name
|
||||
tracer["temperature"] = temperature
|
||||
if in_debug_mode() or state.verbose > 1:
|
||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in llm_thread: {e}", exc_info=True)
|
||||
finally:
|
||||
|
|
|
@ -2,6 +2,7 @@ import base64
|
|||
import logging
|
||||
import math
|
||||
import mimetypes
|
||||
import os
|
||||
import queue
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
|
@ -12,6 +13,8 @@ from typing import Any, Dict, List, Optional
|
|||
import PIL.Image
|
||||
import requests
|
||||
import tiktoken
|
||||
import yaml
|
||||
from git import Repo
|
||||
from langchain.schema import ChatMessage
|
||||
from llama_cpp.llama import Llama
|
||||
from transformers import AutoTokenizer
|
||||
|
@ -21,7 +24,7 @@ from khoj.database.models import ChatModelOptions, ClientApplication, KhojUser
|
|||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
|
||||
from khoj.utils import state
|
||||
from khoj.utils.helpers import is_none_or_empty, merge_dicts
|
||||
from khoj.utils.helpers import in_debug_mode, is_none_or_empty, merge_dicts
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
model_to_prompt_size = {
|
||||
|
@ -117,6 +120,7 @@ def save_to_conversation_log(
|
|||
conversation_id: str = None,
|
||||
automation_id: str = None,
|
||||
query_images: List[str] = None,
|
||||
tracer: Dict[str, Any] = {},
|
||||
):
|
||||
user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
updated_conversation = message_to_log(
|
||||
|
@ -142,6 +146,9 @@ def save_to_conversation_log(
|
|||
user_message=q,
|
||||
)
|
||||
|
||||
if in_debug_mode() or state.verbose > 1:
|
||||
merge_message_into_conversation_trace(q, chat_response, tracer)
|
||||
|
||||
logger.info(
|
||||
f"""
|
||||
Saved Conversation Turn
|
||||
|
@@ -354,3 +361,163 @@ def get_image_from_url(image_url: str, type="pil"):
    except requests.exceptions.RequestException as e:
        logger.error(f"Failed to get image from URL {image_url}: {e}")
        return ImageWithType(content=None, type=None)


def commit_conversation_trace(
    session: list[ChatMessage],
    response: str | list[dict],
    tracer: dict,
    system_message: str | list[dict] = "",
    repo_path: str = "/tmp/promptrace",
) -> str:
    """
    Save trace of conversation step using git. Useful to visualize, compare and debug traces.
    Returns the path to the repository.
    """
    # Serialize session, system message and response to yaml
    system_message_yaml = yaml.dump(system_message, allow_unicode=True, sort_keys=False, default_flow_style=False)
    response_yaml = yaml.dump(response, allow_unicode=True, sort_keys=False, default_flow_style=False)
    formatted_session = [{"role": message.role, "content": message.content} for message in session]
    session_yaml = yaml.dump(formatted_session, allow_unicode=True, sort_keys=False, default_flow_style=False)
    query = (
        yaml.dump(session[-1].content, allow_unicode=True, sort_keys=False, default_flow_style=False)
        .strip()
        .removeprefix("'")
        .removesuffix("'")
    )  # Extract serialized query from chat session

    # Extract chat metadata for session
    uid, cid, mid = tracer.get("uid", "main"), tracer.get("cid", "main"), tracer.get("mid")

    # Infer repository path from environment variable or provided path
    repo_path = os.getenv("PROMPTRACE_DIR", repo_path)

    try:
        # Prepare git repository
        os.makedirs(repo_path, exist_ok=True)
        repo = Repo.init(repo_path)

        # Remove post-commit hook if it exists
        hooks_dir = os.path.join(repo_path, ".git", "hooks")
        post_commit_hook = os.path.join(hooks_dir, "post-commit")
        if os.path.exists(post_commit_hook):
            os.remove(post_commit_hook)

        # Configure git user if not set
        if not repo.config_reader().has_option("user", "email"):
            repo.config_writer().set_value("user", "name", "Prompt Tracer").release()
            repo.config_writer().set_value("user", "email", "promptracer@khoj.dev").release()

        # Create an initial commit if the repository is newly created
        if not repo.head.is_valid():
            repo.index.commit("And then there was a trace")

        # Check out the initial commit
        initial_commit = repo.commit("HEAD~0")
        repo.head.reference = initial_commit
        repo.head.reset(index=True, working_tree=True)

        # Create or switch to user branch from initial commit
        user_branch = f"u_{uid}"
        if user_branch not in repo.branches:
            repo.create_head(user_branch)
        repo.heads[user_branch].checkout()

        # Create or switch to conversation branch from user branch
        conv_branch = f"c_{cid}"
        if conv_branch not in repo.branches:
            repo.create_head(conv_branch)
        repo.heads[conv_branch].checkout()

        # Create or switch to message branch from conversation branch
        msg_branch = f"m_{mid}" if mid else None
        if msg_branch and msg_branch not in repo.branches:
            repo.create_head(msg_branch)
        if msg_branch:
            repo.heads[msg_branch].checkout()

        # Include file with content to commit
        files_to_commit = {"query": session_yaml, "response": response_yaml, "system_prompt": system_message_yaml}

        # Write files and stage them
        for filename, content in files_to_commit.items():
            file_path = os.path.join(repo_path, filename)
            # Unescape special characters in content for better readability
            content = content.strip().replace("\\n", "\n").replace("\\t", "\t")
            with open(file_path, "w", encoding="utf-8") as f:
                f.write(content)
            repo.index.add([filename])

        # Create commit
        metadata_yaml = yaml.dump(tracer, allow_unicode=True, sort_keys=False, default_flow_style=False)
        commit_message = f"""
{query[:250]}

Response:
---
{response[:500]}...

Metadata
---
{metadata_yaml}
""".strip()

        repo.index.commit(commit_message)

        logger.debug(f"Saved conversation trace to repo at {repo_path}")
        return repo_path
    except Exception as e:
        logger.error(f"Failed to add conversation trace to repo: {str(e)}", exc_info=True)
        return None


def merge_message_into_conversation_trace(query: str, response: str, tracer: dict, repo_path="/tmp/promptrace") -> bool:
    """
    Merge the message branch into its parent conversation branch.

    Args:
        query: User query
        response: Assistant response
        tracer: Dictionary containing uid, cid and mid
        repo_path: Path to the git repository

    Returns:
        bool: True if merge was successful, False otherwise
    """
    try:
        # Extract branch names
        msg_branch = f"m_{tracer['mid']}"
        conv_branch = f"c_{tracer['cid']}"

        # Infer repository path from environment variable or provided path
        repo_path = os.getenv("PROMPTRACE_DIR", repo_path)
        repo = Repo(repo_path)

        # Checkout conversation branch
        repo.heads[conv_branch].checkout()

        # Create commit message
        metadata_yaml = yaml.dump(tracer, allow_unicode=True, sort_keys=False, default_flow_style=False)
        commit_message = f"""
{query[:250]}

Response:
---
{response[:500]}...

Metadata
---
{metadata_yaml}
""".strip()

        # Merge message branch into conversation branch
        repo.git.merge(msg_branch, no_ff=True, m=commit_message)

        # Delete message branch after merge
        repo.delete_head(msg_branch, force=True)

        logger.debug(f"Successfully merged {msg_branch} into {conv_branch}")
        return True
    except Exception as e:
        logger.error(f"Failed to merge message {msg_branch} into conversation {conv_branch}: {str(e)}", exc_info=True)
        return False

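To make the intended call pattern concrete, here is a small, hypothetical driver that exercises the two utilities above directly. The user, conversation and message ids are made up for illustration; in Khoj they are populated by the chat API (see the `tracer` dict initialized in the chat endpoint further below) and threaded through the chat actors.

```python
# Hypothetical, standalone exercise of the trace utilities above.
# Assumes khoj is importable and PROMPTRACE_DIR (or /tmp/promptrace) is writable.
import uuid

from langchain.schema import ChatMessage

from khoj.processor.conversation.utils import (
    commit_conversation_trace,
    merge_message_into_conversation_trace,
)

tracer = {
    "uid": "42",                 # user id -> branch u_42 (illustrative value)
    "cid": "demo-conversation",  # conversation id -> branch c_demo-conversation
    "mid": str(uuid.uuid4()),    # message id -> branch m_<uuid>
    "khoj_version": "dev",
}

session = [
    ChatMessage(role="system", content="You are Khoj."),
    ChatMessage(role="user", content="What is a prompt trace?"),
]
response = "A prompt trace is a git commit capturing one LLM interaction."

# One commit per LLM interaction, written on the message branch m_<mid> ...
repo_path = commit_conversation_trace(session, response, tracer)
print(f"Trace committed to {repo_path}")

# ... then the message branch is merged back into its conversation branch c_<cid>.
merged = merge_message_into_conversation_trace(session[-1].content, response, tracer)
print(f"Merged into conversation branch: {merged}")
```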
@ -28,6 +28,7 @@ async def text_to_image(
|
|||
send_status_func: Optional[Callable] = None,
|
||||
query_images: Optional[List[str]] = None,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
status_code = 200
|
||||
image = None
|
||||
|
@ -68,6 +69,7 @@ async def text_to_image(
|
|||
query_images=query_images,
|
||||
user=user,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
if send_status_func:
|
||||
|
|
|
@ -64,6 +64,7 @@ async def search_online(
|
|||
custom_filters: List[str] = [],
|
||||
query_images: List[str] = None,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
query += " ".join(custom_filters)
|
||||
if not is_internet_connected():
|
||||
|
@ -73,7 +74,7 @@ async def search_online(
|
|||
|
||||
# Breakdown the query into subqueries to get the correct answer
|
||||
subqueries = await generate_online_subqueries(
|
||||
query, conversation_history, location, user, query_images=query_images, agent=agent
|
||||
query, conversation_history, location, user, query_images=query_images, agent=agent, tracer=tracer
|
||||
)
|
||||
response_dict = {}
|
||||
|
||||
|
@ -111,7 +112,7 @@ async def search_online(
|
|||
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
|
||||
yield {ChatEvent.STATUS: event}
|
||||
tasks = [
|
||||
read_webpage_and_extract_content(data["queries"], link, data["content"], user=user, agent=agent)
|
||||
read_webpage_and_extract_content(data["queries"], link, data["content"], user=user, agent=agent, tracer=tracer)
|
||||
for link, data in webpages.items()
|
||||
]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
@ -153,6 +154,7 @@ async def read_webpages(
|
|||
send_status_func: Optional[Callable] = None,
|
||||
query_images: List[str] = None,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
"Infer web pages to read from the query and extract relevant information from them"
|
||||
logger.info(f"Inferring web pages to read")
|
||||
|
@ -166,7 +168,7 @@ async def read_webpages(
|
|||
webpage_links_str = "\n- " + "\n- ".join(list(urls))
|
||||
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
|
||||
yield {ChatEvent.STATUS: event}
|
||||
tasks = [read_webpage_and_extract_content({query}, url, user=user, agent=agent) for url in urls]
|
||||
tasks = [read_webpage_and_extract_content({query}, url, user=user, agent=agent, tracer=tracer) for url in urls]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
response: Dict[str, Dict] = defaultdict(dict)
|
||||
|
@ -192,7 +194,12 @@ async def read_webpage(
|
|||
|
||||
|
||||
async def read_webpage_and_extract_content(
|
||||
subqueries: set[str], url: str, content: str = None, user: KhojUser = None, agent: Agent = None
|
||||
subqueries: set[str],
|
||||
url: str,
|
||||
content: str = None,
|
||||
user: KhojUser = None,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
) -> Tuple[set[str], str, Union[None, str]]:
|
||||
# Select the web scrapers to use for reading the web page
|
||||
web_scrapers = await ConversationAdapters.aget_enabled_webscrapers()
|
||||
|
@ -214,7 +221,9 @@ async def read_webpage_and_extract_content(
|
|||
# Extract relevant information from the web page
|
||||
if is_none_or_empty(extracted_info):
|
||||
with timer(f"Extracting relevant information from web page at '{url}' took", logger):
|
||||
extracted_info = await extract_relevant_info(subqueries, content, user=user, agent=agent)
|
||||
extracted_info = await extract_relevant_info(
|
||||
subqueries, content, user=user, agent=agent, tracer=tracer
|
||||
)
|
||||
|
||||
# If we successfully extracted information, break the loop
|
||||
if not is_none_or_empty(extracted_info):
|
||||
|
|
|
@ -350,6 +350,7 @@ async def extract_references_and_questions(
|
|||
send_status_func: Optional[Callable] = None,
|
||||
query_images: Optional[List[str]] = None,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
user = request.user.object if request.user.is_authenticated else None
|
||||
|
||||
|
@ -425,6 +426,7 @@ async def extract_references_and_questions(
|
|||
user=user,
|
||||
max_prompt_size=conversation_config.max_prompt_size,
|
||||
personality_context=personality_context,
|
||||
tracer=tracer,
|
||||
)
|
||||
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
||||
openai_chat_config = conversation_config.openai_config
|
||||
|
@ -442,6 +444,7 @@ async def extract_references_and_questions(
|
|||
query_images=query_images,
|
||||
vision_enabled=vision_enabled,
|
||||
personality_context=personality_context,
|
||||
tracer=tracer,
|
||||
)
|
||||
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
|
||||
api_key = conversation_config.openai_config.api_key
|
||||
|
@ -456,6 +459,7 @@ async def extract_references_and_questions(
|
|||
user=user,
|
||||
vision_enabled=vision_enabled,
|
||||
personality_context=personality_context,
|
||||
tracer=tracer,
|
||||
)
|
||||
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
|
||||
api_key = conversation_config.openai_config.api_key
|
||||
|
@ -471,6 +475,7 @@ async def extract_references_and_questions(
|
|||
user=user,
|
||||
vision_enabled=vision_enabled,
|
||||
personality_context=personality_context,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
# Collate search results as context for GPT
|
||||
|
|
|
@ -3,6 +3,7 @@ import base64
|
|||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from typing import Dict, Optional
|
||||
|
@@ -563,6 +564,12 @@ async def chat(
        event_delimiter = "␃🔚␗"
        q = unquote(q)
        nonlocal conversation_id
        tracer: dict = {
            "mid": f"{uuid.uuid4()}",
            "cid": conversation_id,
            "uid": user.id,
            "khoj_version": state.khoj_version,
        }

        uploaded_images: list[str] = []
        if images:

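Given the ids placed in this `tracer` dict, the trace for a specific chat turn can later be located in the prompt trace repo. This is an illustrative sketch, not code from this PR; it relies on the fact that the commits written by `commit_conversation_trace` and `merge_message_into_conversation_trace` embed the tracer metadata (including `mid`) as YAML in their commit messages.

```python
# Sketch: find the trace commit for one chat turn by its message id.
import os

from git import Repo

# Stand-in values; in practice these come from the chat endpoint's tracer dict.
tracer = {"mid": "00000000-0000-0000-0000-000000000000", "cid": "123", "uid": "1"}

repo = Repo(os.getenv("PROMPTRACE_DIR", "/tmp/promptrace"))
conv_branch = f"c_{tracer['cid']}"

# The turn's message branch m_<mid> is merged back into c_<cid>, so the commit
# carrying this turn's query/response/metadata can be found by its message id.
for commit in repo.iter_commits(conv_branch):
    if tracer["mid"] in commit.message:
        print(commit.hexsha[:8], commit.message.splitlines()[0])
        break
```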
@ -682,6 +689,7 @@ async def chat(
|
|||
user=user,
|
||||
query_images=uploaded_images,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
)
|
||||
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
|
||||
async for result in send_event(
|
||||
|
@ -689,7 +697,9 @@ async def chat(
|
|||
):
|
||||
yield result
|
||||
|
||||
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, user, uploaded_images, agent)
|
||||
mode = await aget_relevant_output_modes(
|
||||
q, meta_log, is_automated_task, user, uploaded_images, agent, tracer=tracer
|
||||
)
|
||||
async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
|
||||
yield result
|
||||
if mode not in conversation_commands:
|
||||
|
@ -755,6 +765,7 @@ async def chat(
|
|||
query_images=uploaded_images,
|
||||
user=user,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
)
|
||||
response_log = str(response)
|
||||
async for result in send_llm_response(response_log):
|
||||
|
@ -774,6 +785,7 @@ async def chat(
|
|||
client_application=request.user.client_app,
|
||||
conversation_id=conversation_id,
|
||||
query_images=uploaded_images,
|
||||
tracer=tracer,
|
||||
)
|
||||
return
|
||||
|
||||
|
@ -795,7 +807,7 @@ async def chat(
|
|||
if ConversationCommand.Automation in conversation_commands:
|
||||
try:
|
||||
automation, crontime, query_to_run, subject = await create_automation(
|
||||
q, timezone, user, request.url, meta_log
|
||||
q, timezone, user, request.url, meta_log, tracer=tracer
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error scheduling task {q} for {user.email}: {e}")
|
||||
|
@ -817,6 +829,7 @@ async def chat(
|
|||
inferred_queries=[query_to_run],
|
||||
automation_id=automation.id,
|
||||
query_images=uploaded_images,
|
||||
tracer=tracer,
|
||||
)
|
||||
async for result in send_llm_response(llm_response):
|
||||
yield result
|
||||
|
@ -838,6 +851,7 @@ async def chat(
|
|||
partial(send_event, ChatEvent.STATUS),
|
||||
query_images=uploaded_images,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
yield result[ChatEvent.STATUS]
|
||||
|
@ -882,6 +896,7 @@ async def chat(
|
|||
custom_filters,
|
||||
query_images=uploaded_images,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
yield result[ChatEvent.STATUS]
|
||||
|
@ -906,6 +921,7 @@ async def chat(
|
|||
partial(send_event, ChatEvent.STATUS),
|
||||
query_images=uploaded_images,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
yield result[ChatEvent.STATUS]
|
||||
|
@ -956,6 +972,7 @@ async def chat(
|
|||
send_status_func=partial(send_event, ChatEvent.STATUS),
|
||||
query_images=uploaded_images,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
yield result[ChatEvent.STATUS]
|
||||
|
@ -986,6 +1003,7 @@ async def chat(
|
|||
compiled_references=compiled_references,
|
||||
online_results=online_results,
|
||||
query_images=uploaded_images,
|
||||
tracer=tracer,
|
||||
)
|
||||
content_obj = {
|
||||
"intentType": intent_type,
|
||||
|
@ -1014,6 +1032,7 @@ async def chat(
|
|||
user=user,
|
||||
agent=agent,
|
||||
send_status_func=partial(send_event, ChatEvent.STATUS),
|
||||
tracer=tracer,
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
yield result[ChatEvent.STATUS]
|
||||
|
@ -1041,6 +1060,7 @@ async def chat(
|
|||
compiled_references=compiled_references,
|
||||
online_results=online_results,
|
||||
query_images=uploaded_images,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
async for result in send_llm_response(json.dumps(content_obj)):
|
||||
|
@ -1064,6 +1084,7 @@ async def chat(
|
|||
location,
|
||||
user_name,
|
||||
uploaded_images,
|
||||
tracer,
|
||||
)
|
||||
|
||||
# Send Response
|
||||
|
|
|
@ -301,6 +301,7 @@ async def aget_relevant_information_sources(
|
|||
user: KhojUser,
|
||||
query_images: List[str] = None,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
"""
|
||||
Given a query, determine which of the available tools the agent should use in order to answer appropriately.
|
||||
|
@ -337,6 +338,7 @@ async def aget_relevant_information_sources(
|
|||
relevant_tools_prompt,
|
||||
response_type="json_object",
|
||||
user=user,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
try:
|
||||
|
@ -378,6 +380,7 @@ async def aget_relevant_output_modes(
|
|||
user: KhojUser = None,
|
||||
query_images: List[str] = None,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
"""
|
||||
Given a query, determine which of the available tools the agent should use in order to answer appropriately.
|
||||
|
@ -413,7 +416,9 @@ async def aget_relevant_output_modes(
|
|||
)
|
||||
|
||||
with timer("Chat actor: Infer output mode for chat response", logger):
|
||||
response = await send_message_to_model_wrapper(relevant_mode_prompt, response_type="json_object", user=user)
|
||||
response = await send_message_to_model_wrapper(
|
||||
relevant_mode_prompt, response_type="json_object", user=user, tracer=tracer
|
||||
)
|
||||
|
||||
try:
|
||||
response = response.strip()
|
||||
|
@ -444,6 +449,7 @@ async def infer_webpage_urls(
|
|||
user: KhojUser,
|
||||
query_images: List[str] = None,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
) -> List[str]:
|
||||
"""
|
||||
Infer webpage links from the given query
|
||||
|
@ -468,7 +474,11 @@ async def infer_webpage_urls(
|
|||
|
||||
with timer("Chat actor: Infer webpage urls to read", logger):
|
||||
response = await send_message_to_model_wrapper(
|
||||
online_queries_prompt, query_images=query_images, response_type="json_object", user=user
|
||||
online_queries_prompt,
|
||||
query_images=query_images,
|
||||
response_type="json_object",
|
||||
user=user,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
# Validate that the response is a non-empty, JSON-serializable list of URLs
|
||||
|
@ -490,6 +500,7 @@ async def generate_online_subqueries(
|
|||
user: KhojUser,
|
||||
query_images: List[str] = None,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
) -> List[str]:
|
||||
"""
|
||||
Generate subqueries from the given query
|
||||
|
@ -514,7 +525,11 @@ async def generate_online_subqueries(
|
|||
|
||||
with timer("Chat actor: Generate online search subqueries", logger):
|
||||
response = await send_message_to_model_wrapper(
|
||||
online_queries_prompt, query_images=query_images, response_type="json_object", user=user
|
||||
online_queries_prompt,
|
||||
query_images=query_images,
|
||||
response_type="json_object",
|
||||
user=user,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
# Validate that the response is a non-empty, JSON-serializable list
|
||||
|
@ -533,7 +548,7 @@ async def generate_online_subqueries(
|
|||
|
||||
|
||||
async def schedule_query(
|
||||
q: str, conversation_history: dict, user: KhojUser, query_images: List[str] = None
|
||||
q: str, conversation_history: dict, user: KhojUser, query_images: List[str] = None, tracer: dict = {}
|
||||
) -> Tuple[str, ...]:
|
||||
"""
|
||||
Schedule the date, time to run the query. Assume the server timezone is UTC.
|
||||
|
@ -546,7 +561,7 @@ async def schedule_query(
|
|||
)
|
||||
|
||||
raw_response = await send_message_to_model_wrapper(
|
||||
crontime_prompt, query_images=query_images, response_type="json_object", user=user
|
||||
crontime_prompt, query_images=query_images, response_type="json_object", user=user, tracer=tracer
|
||||
)
|
||||
|
||||
# Validate that the response is a non-empty, JSON-serializable list
|
||||
|
@ -561,7 +576,7 @@ async def schedule_query(
|
|||
|
||||
|
||||
async def extract_relevant_info(
|
||||
qs: set[str], corpus: str, user: KhojUser = None, agent: Agent = None
|
||||
qs: set[str], corpus: str, user: KhojUser = None, agent: Agent = None, tracer: dict = {}
|
||||
) -> Union[str, None]:
|
||||
"""
|
||||
Extract relevant information for a given query from the target corpus
|
||||
|
@ -584,6 +599,7 @@ async def extract_relevant_info(
|
|||
extract_relevant_information,
|
||||
prompts.system_prompt_extract_relevant_information,
|
||||
user=user,
|
||||
tracer=tracer,
|
||||
)
|
||||
return response.strip()
|
||||
|
||||
|
@ -595,6 +611,7 @@ async def extract_relevant_summary(
|
|||
query_images: List[str] = None,
|
||||
user: KhojUser = None,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
) -> Union[str, None]:
|
||||
"""
|
||||
Extract relevant information for a given query from the target corpus
|
||||
|
@ -622,6 +639,7 @@ async def extract_relevant_summary(
|
|||
prompts.system_prompt_extract_relevant_summary,
|
||||
user=user,
|
||||
query_images=query_images,
|
||||
tracer=tracer,
|
||||
)
|
||||
return response.strip()
|
||||
|
||||
|
@ -636,6 +654,7 @@ async def generate_excalidraw_diagram(
|
|||
user: KhojUser = None,
|
||||
agent: Agent = None,
|
||||
send_status_func: Optional[Callable] = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
if send_status_func:
|
||||
async for event in send_status_func("**Enhancing the Diagramming Prompt**"):
|
||||
|
@ -650,6 +669,7 @@ async def generate_excalidraw_diagram(
|
|||
query_images=query_images,
|
||||
user=user,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
if send_status_func:
|
||||
|
@ -660,6 +680,7 @@ async def generate_excalidraw_diagram(
|
|||
q=better_diagram_description_prompt,
|
||||
user=user,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
yield better_diagram_description_prompt, excalidraw_diagram_description
|
||||
|
@ -674,6 +695,7 @@ async def generate_better_diagram_description(
|
|||
query_images: List[str] = None,
|
||||
user: KhojUser = None,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
) -> str:
|
||||
"""
|
||||
Generate a diagram description from the given query and context
|
||||
|
@ -711,7 +733,7 @@ async def generate_better_diagram_description(
|
|||
|
||||
with timer("Chat actor: Generate better diagram description", logger):
|
||||
response = await send_message_to_model_wrapper(
|
||||
improve_diagram_description_prompt, query_images=query_images, user=user
|
||||
improve_diagram_description_prompt, query_images=query_images, user=user, tracer=tracer
|
||||
)
|
||||
response = response.strip()
|
||||
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
|
||||
|
@ -724,6 +746,7 @@ async def generate_excalidraw_diagram_from_description(
|
|||
q: str,
|
||||
user: KhojUser = None,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
) -> str:
|
||||
personality_context = (
|
||||
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
|
||||
|
@ -735,7 +758,9 @@ async def generate_excalidraw_diagram_from_description(
|
|||
)
|
||||
|
||||
with timer("Chat actor: Generate excalidraw diagram", logger):
|
||||
raw_response = await send_message_to_model_wrapper(message=excalidraw_diagram_generation, user=user)
|
||||
raw_response = await send_message_to_model_wrapper(
|
||||
message=excalidraw_diagram_generation, user=user, tracer=tracer
|
||||
)
|
||||
raw_response = raw_response.strip()
|
||||
raw_response = remove_json_codeblock(raw_response)
|
||||
response: Dict[str, str] = json.loads(raw_response)
|
||||
|
@ -756,6 +781,7 @@ async def generate_better_image_prompt(
|
|||
query_images: Optional[List[str]] = None,
|
||||
user: KhojUser = None,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
) -> str:
|
||||
"""
|
||||
Generate a better image prompt from the given query
|
||||
|
@ -802,7 +828,9 @@ async def generate_better_image_prompt(
|
|||
)
|
||||
|
||||
with timer("Chat actor: Generate contextual image prompt", logger):
|
||||
response = await send_message_to_model_wrapper(image_prompt, query_images=query_images, user=user)
|
||||
response = await send_message_to_model_wrapper(
|
||||
image_prompt, query_images=query_images, user=user, tracer=tracer
|
||||
)
|
||||
response = response.strip()
|
||||
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
|
||||
response = response[1:-1]
|
||||
|
@ -816,6 +844,7 @@ async def send_message_to_model_wrapper(
|
|||
response_type: str = "text",
|
||||
user: KhojUser = None,
|
||||
query_images: List[str] = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user)
|
||||
vision_available = conversation_config.vision_enabled
|
||||
|
@ -862,6 +891,7 @@ async def send_message_to_model_wrapper(
|
|||
max_prompt_size=max_tokens,
|
||||
streaming=False,
|
||||
response_type=response_type,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
elif model_type == ChatModelOptions.ModelType.OPENAI:
|
||||
|
@ -885,6 +915,7 @@ async def send_message_to_model_wrapper(
|
|||
model=chat_model,
|
||||
response_type=response_type,
|
||||
api_base_url=api_base_url,
|
||||
tracer=tracer,
|
||||
)
|
||||
elif model_type == ChatModelOptions.ModelType.ANTHROPIC:
|
||||
api_key = conversation_config.openai_config.api_key
|
||||
|
@ -903,6 +934,7 @@ async def send_message_to_model_wrapper(
|
|||
messages=truncated_messages,
|
||||
api_key=api_key,
|
||||
model=chat_model,
|
||||
tracer=tracer,
|
||||
)
|
||||
elif model_type == ChatModelOptions.ModelType.GOOGLE:
|
||||
api_key = conversation_config.openai_config.api_key
|
||||
|
@ -918,7 +950,7 @@ async def send_message_to_model_wrapper(
|
|||
)
|
||||
|
||||
return gemini_send_message_to_model(
|
||||
messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type
|
||||
messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type, tracer=tracer
|
||||
)
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="Invalid conversation config")
|
||||
|
@ -929,6 +961,7 @@ def send_message_to_model_wrapper_sync(
|
|||
system_message: str = "",
|
||||
response_type: str = "text",
|
||||
user: KhojUser = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
conversation_config: ChatModelOptions = ConversationAdapters.get_default_conversation_config(user)
|
||||
|
||||
|
@ -961,6 +994,7 @@ def send_message_to_model_wrapper_sync(
|
|||
max_prompt_size=max_tokens,
|
||||
streaming=False,
|
||||
response_type=response_type,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
||||
|
@ -975,7 +1009,11 @@ def send_message_to_model_wrapper_sync(
|
|||
)
|
||||
|
||||
openai_response = send_message_to_model(
|
||||
messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type
|
||||
messages=truncated_messages,
|
||||
api_key=api_key,
|
||||
model=chat_model,
|
||||
response_type=response_type,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
return openai_response
|
||||
|
@ -995,6 +1033,7 @@ def send_message_to_model_wrapper_sync(
|
|||
messages=truncated_messages,
|
||||
api_key=api_key,
|
||||
model=chat_model,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
|
||||
|
@ -1013,6 +1052,7 @@ def send_message_to_model_wrapper_sync(
|
|||
api_key=api_key,
|
||||
model=chat_model,
|
||||
response_type=response_type,
|
||||
tracer=tracer,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="Invalid conversation config")
|
||||
|
@ -1032,6 +1072,7 @@ def generate_chat_response(
|
|||
location_data: LocationData = None,
|
||||
user_name: Optional[str] = None,
|
||||
query_images: Optional[List[str]] = None,
|
||||
tracer: dict = {},
|
||||
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
|
||||
# Initialize Variables
|
||||
chat_response = None
|
||||
|
@ -1051,6 +1092,7 @@ def generate_chat_response(
|
|||
client_application=client_application,
|
||||
conversation_id=conversation_id,
|
||||
query_images=query_images,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
|
||||
|
@ -1077,6 +1119,7 @@ def generate_chat_response(
|
|||
location_data=location_data,
|
||||
user_name=user_name,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
||||
|
@ -1100,6 +1143,7 @@ def generate_chat_response(
|
|||
user_name=user_name,
|
||||
agent=agent,
|
||||
vision_available=vision_available,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
|
||||
|
@ -1120,6 +1164,7 @@ def generate_chat_response(
|
|||
user_name=user_name,
|
||||
agent=agent,
|
||||
vision_available=vision_available,
|
||||
tracer=tracer,
|
||||
)
|
||||
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
|
||||
api_key = conversation_config.openai_config.api_key
|
||||
|
@ -1139,6 +1184,7 @@ def generate_chat_response(
|
|||
user_name=user_name,
|
||||
agent=agent,
|
||||
vision_available=vision_available,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
metadata.update({"chat_model": conversation_config.chat_model})
|
||||
|
@ -1495,9 +1541,15 @@ def scheduled_chat(
|
|||
|
||||
|
||||
async def create_automation(
|
||||
q: str, timezone: str, user: KhojUser, calling_url: URL, meta_log: dict = {}, conversation_id: str = None
|
||||
q: str,
|
||||
timezone: str,
|
||||
user: KhojUser,
|
||||
calling_url: URL,
|
||||
meta_log: dict = {},
|
||||
conversation_id: str = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
crontime, query_to_run, subject = await schedule_query(q, meta_log, user)
|
||||
crontime, query_to_run, subject = await schedule_query(q, meta_log, user, tracer=tracer)
|
||||
job = await schedule_automation(query_to_run, subject, crontime, timezone, q, user, calling_url, conversation_id)
|
||||
return job, crontime, query_to_run, subject
|
||||
|
||||
|
|