Merge branch 'add-prompt-tracer-for-observability' of github.com:khoj-ai/khoj into features/advanced-reasoning

- Start from this branch's src/khoj/routers/api_chat.py.
  Add tracer to all old and new chat actors that don't have it set
  when they are called.
- Update the new chat actors, like pick next tool etc., to use the tracer too.
Debanjum Singh Solanky 2024-10-26 05:30:24 -07:00
commit f04f871a72
16 changed files with 458 additions and 57 deletions
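
The gist of the change: a plain `tracer` dict is created once per chat message in api_chat.py and passed through every chat actor down to the model call, where the exchange can be committed to a local git repository for later inspection. A minimal sketch of that convention, using the keys set in this commit (the actor function below is illustrative, not part of the diff):

```python
import uuid

def build_tracer(user, conversation_id, khoj_version):
    # Same keys api_chat.py sets below: message, conversation and user ids plus app version
    return {
        "mid": f"{uuid.uuid4()}",   # message id -> trace branch m_<mid>
        "cid": conversation_id,      # conversation id -> trace branch c_<cid>
        "uid": user.id,              # user id -> trace branch u_<uid>
        "khoj_version": khoj_version,
    }

def some_chat_actor(messages, tracer={}):
    # Hypothetical actor: every actor in this commit just accepts the dict
    # and forwards it unchanged to whatever it calls next.
    return send_message_to_model(messages, tracer=tracer)  # illustrative call
```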

View file

@@ -119,6 +119,7 @@ dev = [
"mypy >= 1.0.1",
"black >= 23.1.0",
"pre-commit >= 3.0.4",
"gitpython ~= 3.1.43",
]
[tool.hatch.version]

View file

@@ -34,6 +34,7 @@ def extract_questions_anthropic(
query_images: Optional[list[str]] = None,
vision_enabled: bool = False,
personality_context: Optional[str] = None,
tracer: dict = {},
):
"""
Infer search queries to retrieve relevant notes to answer user query
@@ -89,6 +90,7 @@ def extract_questions_anthropic(
model_name=model,
temperature=temperature,
api_key=api_key,
tracer=tracer,
)
# Extract, Clean Message from Claude's Response
@@ -110,7 +112,7 @@ def extract_questions_anthropic(
return questions
-def anthropic_send_message_to_model(messages, api_key, model):
+def anthropic_send_message_to_model(messages, api_key, model, tracer={}):
"""
Send message to model
"""
@@ -122,6 +124,7 @@ def anthropic_send_message_to_model(messages, api_key, model):
system_prompt=system_prompt,
model_name=model,
api_key=api_key,
tracer=tracer,
)
@@ -142,6 +145,7 @@ def converse_anthropic(
agent: Agent = None,
query_images: Optional[list[str]] = None,
vision_available: bool = False,
tracer: dict = {},
):
"""
Converse with user using Anthropic's Claude
@@ -220,4 +224,5 @@ def converse_anthropic(
system_prompt=system_prompt,
completion_func=completion_func,
max_prompt_size=max_prompt_size,
tracer=tracer,
)

View file

@@ -12,8 +12,13 @@ from tenacity import (
wait_random_exponential,
)
-from khoj.processor.conversation.utils import ThreadedGenerator, get_image_from_url
-from khoj.utils.helpers import is_none_or_empty
+from khoj.processor.conversation.utils import (
+ThreadedGenerator,
+commit_conversation_trace,
+get_image_from_url,
+)
+from khoj.utils import state
+from khoj.utils.helpers import in_debug_mode, is_none_or_empty
logger = logging.getLogger(__name__)
@@ -30,7 +35,7 @@ DEFAULT_MAX_TOKENS_ANTHROPIC = 3000
reraise=True,
)
def anthropic_completion_with_backoff(
-messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, max_tokens=None
+messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, max_tokens=None, tracer={}
) -> str:
if api_key not in anthropic_clients:
client: anthropic.Anthropic = anthropic.Anthropic(api_key=api_key)
@@ -58,6 +63,12 @@ def anthropic_completion_with_backoff(
for text in stream.text_stream:
aggregated_response += text
# Save conversation trace
tracer["chat_model"] = model_name
tracer["temperature"] = temperature
if in_debug_mode() or state.verbose > 1:
commit_conversation_trace(messages, aggregated_response, tracer)
return aggregated_response
@@ -78,18 +89,19 @@ def anthropic_chat_completion_with_backoff(
max_prompt_size=None,
completion_func=None,
model_kwargs=None,
tracer={},
):
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
t = Thread(
target=anthropic_llm_thread,
-args=(g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size, model_kwargs),
+args=(g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size, model_kwargs, tracer),
)
t.start()
return g
def anthropic_llm_thread(
-g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size=None, model_kwargs=None
+g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size=None, model_kwargs=None, tracer={}
):
try:
if api_key not in anthropic_clients:
@@ -102,6 +114,7 @@ def anthropic_llm_thread(
anthropic.types.MessageParam(role=message.role, content=message.content) for message in messages
]
aggregated_response = ""
with client.messages.stream(
messages=formatted_messages,
model=model_name, # type: ignore
@@ -112,7 +125,14 @@ def anthropic_llm_thread(
**(model_kwargs or dict()),
) as stream:
for text in stream.text_stream:
aggregated_response += text
g.send(text)
# Save conversation trace
tracer["chat_model"] = model_name
tracer["temperature"] = temperature
if in_debug_mode() or state.verbose > 1:
commit_conversation_trace(messages, aggregated_response, tracer)
except Exception as e:
logger.error(f"Error in anthropic_llm_thread: {e}", exc_info=True)
finally:
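
Each provider's streaming thread follows the same pattern introduced here: accumulate the chunks that are forwarded to the client, stamp the model metadata onto the tracer, and commit the trace only when tracing is enabled. A condensed, provider-agnostic sketch (stream_fn stands in for the provider SDK's stream):

```python
from khoj.processor.conversation.utils import commit_conversation_trace
from khoj.utils import state
from khoj.utils.helpers import in_debug_mode

def llm_thread_sketch(g, messages, model_name, temperature, tracer={}, stream_fn=None):
    aggregated_response = ""
    for text in stream_fn(messages):   # yields text chunks, like stream.text_stream above
        aggregated_response += text
        g.send(text)                   # forward each chunk to the ThreadedGenerator

    # Save conversation trace only in debug or verbose mode
    tracer["chat_model"] = model_name
    tracer["temperature"] = temperature
    if in_debug_mode() or state.verbose > 1:
        commit_conversation_trace(messages, aggregated_response, tracer)
```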

View file

@@ -35,6 +35,7 @@ def extract_questions_gemini(
query_images: Optional[list[str]] = None,
vision_enabled: bool = False,
personality_context: Optional[str] = None,
tracer: dict = {},
):
"""
Infer search queries to retrieve relevant notes to answer user query
@@ -85,7 +86,7 @@ def extract_questions_gemini(
messages = [ChatMessage(content=prompt, role="user"), ChatMessage(content=system_prompt, role="system")]
response = gemini_send_message_to_model(
-messages, api_key, model, response_type="json_object", temperature=temperature
+messages, api_key, model, response_type="json_object", temperature=temperature, tracer=tracer
)
# Extract, Clean Message from Gemini's Response
@@ -107,7 +108,9 @@ def extract_questions_gemini(
return questions
-def gemini_send_message_to_model(messages, api_key, model, response_type="text", temperature=0, model_kwargs=None):
+def gemini_send_message_to_model(
+messages, api_key, model, response_type="text", temperature=0, model_kwargs=None, tracer={}
+):
"""
Send message to model
"""
@@ -125,6 +128,7 @@ def gemini_send_message_to_model(messages, api_key, model, response_type="text",
api_key=api_key,
temperature=temperature,
model_kwargs=model_kwargs,
tracer=tracer,
)
@@ -146,6 +150,7 @@ def converse_gemini(
agent: Agent = None,
query_images: Optional[list[str]] = None,
vision_available: bool = False,
tracer={},
):
"""
Converse with user using Google's Gemini
@@ -224,4 +229,5 @@ def converse_gemini(
api_key=api_key,
system_prompt=system_prompt,
completion_func=completion_func,
tracer=tracer,
)

View file

@@ -19,8 +19,13 @@ from tenacity import (
wait_random_exponential,
)
-from khoj.processor.conversation.utils import ThreadedGenerator, get_image_from_url
-from khoj.utils.helpers import is_none_or_empty
+from khoj.processor.conversation.utils import (
+ThreadedGenerator,
+commit_conversation_trace,
+get_image_from_url,
+)
+from khoj.utils import state
+from khoj.utils.helpers import in_debug_mode, is_none_or_empty
logger = logging.getLogger(__name__)
@@ -35,7 +40,7 @@ MAX_OUTPUT_TOKENS_GEMINI = 8192
reraise=True,
)
def gemini_completion_with_backoff(
-messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None
+messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, tracer={}
) -> str:
genai.configure(api_key=api_key)
model_kwargs = model_kwargs or dict()
@@ -60,16 +65,23 @@ def gemini_completion_with_backoff(
try:
# Generate the response. The last message is considered to be the current prompt
-aggregated_response = chat_session.send_message(formatted_messages[-1]["parts"])
+response = chat_session.send_message(formatted_messages[-1]["parts"])
-return aggregated_response.text
+response_text = response.text
except StopCandidateException as e:
-response_message, _ = handle_gemini_response(e.args)
+response_text, _ = handle_gemini_response(e.args)
# Respond with reason for stopping
logger.warning(
-f"LLM Response Prevented for {model_name}: {response_message}.\n"
+f"LLM Response Prevented for {model_name}: {response_text}.\n"
+ f"Last Message by {messages[-1].role}: {messages[-1].content}"
)
-return response_message
# Save conversation trace
tracer["chat_model"] = model_name
tracer["temperature"] = temperature
if in_debug_mode() or state.verbose > 1:
commit_conversation_trace(messages, response_text, tracer)
return response_text
@retry(
@@ -88,17 +100,20 @@ def gemini_chat_completion_with_backoff(
system_prompt,
completion_func=None,
model_kwargs=None,
tracer: dict = {},
):
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
t = Thread(
target=gemini_llm_thread,
-args=(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs),
+args=(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs, tracer),
)
t.start()
return g
-def gemini_llm_thread(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs=None):
+def gemini_llm_thread(
+g, messages, system_prompt, model_name, temperature, api_key, model_kwargs=None, tracer: dict = {}
+):
try:
genai.configure(api_key=api_key)
model_kwargs = model_kwargs or dict()
@@ -117,16 +132,25 @@ def gemini_llm_thread(g, messages, system_prompt, model_name, temperature, api_k
},
)
aggregated_response = ""
formatted_messages = [{"role": message.role, "parts": message.content} for message in messages]
# all messages up to the last are considered to be part of the chat history
chat_session = model.start_chat(history=formatted_messages[0:-1])
# the last message is considered to be the current prompt
for chunk in chat_session.send_message(formatted_messages[-1]["parts"], stream=True):
message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
message = message or chunk.text
aggregated_response += message
g.send(message)
if stopped:
raise StopCandidateException(message)
# Save conversation trace
tracer["chat_model"] = model_name
tracer["temperature"] = temperature
if in_debug_mode() or state.verbose > 1:
commit_conversation_trace(messages, aggregated_response, tracer)
except StopCandidateException as e:
logger.warning(
f"LLM Response Prevented for {model_name}: {e.args[0]}.\n"

View file

@@ -12,11 +12,12 @@ from khoj.processor.conversation import prompts
from khoj.processor.conversation.offline.utils import download_model
from khoj.processor.conversation.utils import (
ThreadedGenerator,
commit_conversation_trace,
generate_chatml_messages_with_context,
)
from khoj.utils import state
from khoj.utils.constants import empty_escape_sequences
-from khoj.utils.helpers import ConversationCommand, is_none_or_empty
+from khoj.utils.helpers import ConversationCommand, in_debug_mode, is_none_or_empty
from khoj.utils.rawconfig import LocationData
logger = logging.getLogger(__name__)
@@ -34,6 +35,7 @@ def extract_questions_offline(
max_prompt_size: int = None,
temperature: float = 0.7,
personality_context: Optional[str] = None,
tracer: dict = {},
) -> List[str]:
"""
Infer search queries to retrieve relevant notes to answer user query
@@ -94,6 +96,7 @@ def extract_questions_offline(
max_prompt_size=max_prompt_size,
temperature=temperature,
response_type="json_object",
tracer=tracer,
)
finally:
state.chat_lock.release()
@@ -147,6 +150,7 @@ def converse_offline(
location_data: LocationData = None,
user_name: str = None,
agent: Agent = None,
tracer: dict = {},
) -> Union[ThreadedGenerator, Iterator[str]]:
"""
Converse with user using Llama
@@ -155,6 +159,7 @@ def converse_offline(
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
compiled_references_message = "\n\n".join({f"{item['compiled']}" for item in references})
tracer["chat_model"] = model
current_date = datetime.now()
@@ -218,13 +223,14 @@ def converse_offline(
logger.debug(f"Conversation Context for {model}: {truncated_messages}")
g = ThreadedGenerator(references, online_results, completion_func=completion_func)
-t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size))
+t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size, tracer))
t.start()
return g
-def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None):
+def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None, tracer: dict = {}):
stop_phrases = ["<s>", "INST]", "Notes:"]
aggregated_response = ""
state.chat_lock.acquire()
try:
@@ -232,7 +238,14 @@ def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int
messages, loaded_model=model, stop=stop_phrases, max_prompt_size=max_prompt_size, streaming=True
)
for response in response_iterator:
-g.send(response["choices"][0]["delta"].get("content", ""))
+response_delta = response["choices"][0]["delta"].get("content", "")
aggregated_response += response_delta
g.send(response_delta)
# Save conversation trace
if in_debug_mode() or state.verbose > 1:
commit_conversation_trace(messages, aggregated_response, tracer)
finally:
state.chat_lock.release()
g.close()
@@ -247,6 +260,7 @@ def send_message_to_model_offline(
stop=[],
max_prompt_size: int = None,
response_type: str = "text",
tracer: dict = {},
):
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
@@ -254,7 +268,17 @@ def send_message_to_model_offline(
response = offline_chat_model.create_chat_completion(
messages_dict, stop=stop, stream=streaming, temperature=temperature, response_format={"type": response_type}
)
if streaming:
return response
else:
-return response["choices"][0]["message"].get("content", "")
+response_text = response["choices"][0]["message"].get("content", "")
# Save conversation trace for non-streaming responses
# Streamed responses need to be saved by the calling function
tracer["chat_model"] = model
tracer["temperature"] = temperature
if in_debug_mode() or state.verbose > 1:
commit_conversation_trace(messages, response_text, tracer)
return response_text

View file

@@ -33,6 +33,7 @@ def extract_questions(
query_images: Optional[list[str]] = None,
vision_enabled: bool = False,
personality_context: Optional[str] = None,
tracer: dict = {},
):
"""
Infer search queries to retrieve relevant notes to answer user query
@@ -82,7 +83,13 @@ def extract_questions(
messages = [ChatMessage(content=prompt, role="user")]
response = send_message_to_model(
-messages, api_key, model, response_type="json_object", api_base_url=api_base_url, temperature=temperature
+messages,
+api_key,
+model,
+response_type="json_object",
+api_base_url=api_base_url,
+temperature=temperature,
+tracer=tracer,
)
# Extract, Clean Message from GPT's Response
@@ -103,7 +110,9 @@ def extract_questions(
return questions
-def send_message_to_model(messages, api_key, model, response_type="text", api_base_url=None, temperature=0):
+def send_message_to_model(
+messages, api_key, model, response_type="text", api_base_url=None, temperature=0, tracer: dict = {}
+):
"""
Send message to model
"""
@@ -116,6 +125,7 @@ def send_message_to_model(messages, api_key, model, response_type="text", api_ba
temperature=temperature,
api_base_url=api_base_url,
model_kwargs={"response_format": {"type": response_type}},
tracer=tracer,
)
@@ -138,6 +148,7 @@ def converse(
agent: Agent = None,
query_images: Optional[list[str]] = None,
vision_available: bool = False,
tracer: dict = {},
):
"""
Converse with user using OpenAI's ChatGPT
@@ -214,4 +225,5 @@ def converse(
api_base_url=api_base_url,
completion_func=completion_func,
model_kwargs={"stop": ["Notes:\n["]},
tracer=tracer,
)

View file

@@ -12,7 +12,12 @@ from tenacity import (
wait_random_exponential,
)
-from khoj.processor.conversation.utils import ThreadedGenerator
+from khoj.processor.conversation.utils import (
+ThreadedGenerator,
+commit_conversation_trace,
+)
+from khoj.utils import state
+from khoj.utils.helpers import in_debug_mode
logger = logging.getLogger(__name__)
@@ -33,7 +38,7 @@ openai_clients: Dict[str, openai.OpenAI] = {}
reraise=True,
)
def completion_with_backoff(
-messages, model, temperature=0, openai_api_key=None, api_base_url=None, model_kwargs=None
+messages, model, temperature=0, openai_api_key=None, api_base_url=None, model_kwargs=None, tracer: dict = {}
) -> str:
client_key = f"{openai_api_key}--{api_base_url}"
client: openai.OpenAI | None = openai_clients.get(client_key)
@@ -77,6 +82,12 @@ def completion_with_backoff(
elif delta_chunk.content:
aggregated_response += delta_chunk.content
# Save conversation trace
tracer["chat_model"] = model
tracer["temperature"] = temperature
if in_debug_mode() or state.verbose > 1:
commit_conversation_trace(messages, aggregated_response, tracer)
return aggregated_response
@@ -103,26 +114,37 @@ def chat_completion_with_backoff(
api_base_url=None,
completion_func=None,
model_kwargs=None,
tracer: dict = {},
):
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
t = Thread(
-target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key, api_base_url, model_kwargs)
+target=llm_thread,
+args=(g, messages, model_name, temperature, openai_api_key, api_base_url, model_kwargs, tracer),
)
t.start()
return g
-def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_base_url=None, model_kwargs=None):
+def llm_thread(
+g,
+messages,
+model_name,
+temperature,
+openai_api_key=None,
+api_base_url=None,
+model_kwargs=None,
+tracer: dict = {},
+):
try:
client_key = f"{openai_api_key}--{api_base_url}"
if client_key not in openai_clients:
-client: openai.OpenAI = openai.OpenAI(
+client = openai.OpenAI(
api_key=openai_api_key,
base_url=api_base_url,
)
openai_clients[client_key] = client
else:
-client: openai.OpenAI = openai_clients[client_key]
+client = openai_clients[client_key]
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
stream = True
@@ -144,17 +166,29 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_ba
**(model_kwargs or dict()),
)
aggregated_response = ""
if not stream:
-g.send(chat.choices[0].message.content)
+aggregated_response = chat.choices[0].message.content
+g.send(aggregated_response)
else:
for chunk in chat:
if len(chunk.choices) == 0:
continue
delta_chunk = chunk.choices[0].delta
text_chunk = ""
if isinstance(delta_chunk, str):
-g.send(delta_chunk)
+text_chunk = delta_chunk
elif delta_chunk.content:
-g.send(delta_chunk.content)
+text_chunk = delta_chunk.content
if text_chunk:
aggregated_response += text_chunk
g.send(text_chunk)
# Save conversation trace
tracer["chat_model"] = model_name
tracer["temperature"] = temperature
if in_debug_mode() or state.verbose > 1:
commit_conversation_trace(messages, aggregated_response, tracer)
except Exception as e:
logger.error(f"Error in llm_thread: {e}", exc_info=True)
finally:

View file

@@ -2,6 +2,7 @@ import base64
import logging
import math
import mimetypes
import os
import queue
from dataclasses import dataclass
from datetime import datetime
@@ -13,6 +14,8 @@ from typing import Any, Dict, List, Optional
import PIL.Image
import requests
import tiktoken
import yaml
from git import Repo
from langchain.schema import ChatMessage
from llama_cpp.llama import Llama
from transformers import AutoTokenizer
@@ -24,7 +27,7 @@ from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.file_filter import FileFilter
from khoj.search_filter.word_filter import WordFilter
from khoj.utils import state
-from khoj.utils.helpers import is_none_or_empty, merge_dicts
+from khoj.utils.helpers import in_debug_mode, is_none_or_empty, merge_dicts
logger = logging.getLogger(__name__)
model_to_prompt_size = {
@@ -178,6 +181,7 @@ def save_to_conversation_log(
conversation_id: str = None,
automation_id: str = None,
query_images: List[str] = None,
tracer: Dict[str, Any] = {},
):
user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
updated_conversation = message_to_log(
@@ -204,6 +208,9 @@ def save_to_conversation_log(
user_message=q,
)
if in_debug_mode() or state.verbose > 1:
merge_message_into_conversation_trace(q, chat_response, tracer)
logger.info(
f"""
Saved Conversation Turn
@@ -415,3 +422,160 @@ def get_image_from_url(image_url: str, type="pil"):
except requests.exceptions.RequestException as e:
logger.error(f"Failed to get image from URL {image_url}: {e}")
return ImageWithType(content=None, type=None)
def commit_conversation_trace(
session: list[ChatMessage],
response: str | list[dict],
tracer: dict,
system_message: str | list[dict] = "",
repo_path: str = "/tmp/khoj_promptrace",
) -> str:
"""
Save trace of conversation step using git. Useful to visualize, compare and debug traces.
Returns the path to the repository.
"""
# Serialize session, system message and response to yaml
system_message_yaml = yaml.dump(system_message, allow_unicode=True, sort_keys=False, default_flow_style=False)
response_yaml = yaml.dump(response, allow_unicode=True, sort_keys=False, default_flow_style=False)
formatted_session = [{"role": message.role, "content": message.content} for message in session]
session_yaml = yaml.dump(formatted_session, allow_unicode=True, sort_keys=False, default_flow_style=False)
query = (
yaml.dump(session[-1].content, allow_unicode=True, sort_keys=False, default_flow_style=False)
.strip()
.removeprefix("'")
.removesuffix("'")
) # Extract serialized query from chat session
# Extract chat metadata for session
uid, cid, mid = tracer.get("uid", "main"), tracer.get("cid", "main"), tracer.get("mid")
# Infer repository path from environment variable or provided path
repo_path = os.getenv("PROMPTRACE_DIR", repo_path) or "/tmp/promptrace"
try:
# Prepare git repository
os.makedirs(repo_path, exist_ok=True)
repo = Repo.init(repo_path)
# Remove post-commit hook if it exists
hooks_dir = os.path.join(repo_path, ".git", "hooks")
post_commit_hook = os.path.join(hooks_dir, "post-commit")
if os.path.exists(post_commit_hook):
os.remove(post_commit_hook)
# Configure git user if not set
if not repo.config_reader().has_option("user", "email"):
repo.config_writer().set_value("user", "name", "Prompt Tracer").release()
repo.config_writer().set_value("user", "email", "promptracer@khoj.dev").release()
# Create an initial commit if the repository is newly created
if not repo.head.is_valid():
repo.index.commit("And then there was a trace")
# Check out the initial commit
initial_commit = repo.commit("HEAD~0")
repo.head.reference = initial_commit
repo.head.reset(index=True, working_tree=True)
# Create or switch to user branch from initial commit
user_branch = f"u_{uid}"
if user_branch not in repo.branches:
repo.create_head(user_branch)
repo.heads[user_branch].checkout()
# Create or switch to conversation branch from user branch
conv_branch = f"c_{cid}"
if conv_branch not in repo.branches:
repo.create_head(conv_branch)
repo.heads[conv_branch].checkout()
# Create or switch to message branch from conversation branch
msg_branch = f"m_{mid}" if mid else None
if msg_branch and msg_branch not in repo.branches:
repo.create_head(msg_branch)
repo.heads[msg_branch].checkout()
# Include file with content to commit
files_to_commit = {"query": session_yaml, "response": response_yaml, "system_prompt": system_message_yaml}
# Write files and stage them
for filename, content in files_to_commit.items():
file_path = os.path.join(repo_path, filename)
with open(file_path, "w", encoding="utf-8") as f:
f.write(content)
repo.index.add([filename])
# Create commit
metadata_yaml = yaml.dump(tracer, allow_unicode=True, sort_keys=False, default_flow_style=False)
commit_message = f"""
{query[:250]}
Response:
---
{response[:500]}...
Metadata
---
{metadata_yaml}
""".strip()
repo.index.commit(commit_message)
logger.debug(f"Saved conversation trace to repo at {repo_path}")
return repo_path
except Exception as e:
logger.error(f"Failed to add conversation trace to repo: {str(e)}")
return None
def merge_message_into_conversation_trace(query: str, response: str, tracer: dict, repo_path=None) -> bool:
"""
Merge the message branch into its parent conversation branch.
Args:
query: User query
response: Assistant response
tracer: Dictionary containing uid, cid and mid
repo_path: Path to the git repository
Returns:
bool: True if merge was successful, False otherwise
"""
try:
# Infer repository path from environment variable or provided path
repo_path = os.getenv("PROMPTRACE_DIR", repo_path) or "/tmp/promptrace"
repo = Repo(repo_path)
# Extract branch names
msg_branch = f"m_{tracer['mid']}"
conv_branch = f"c_{tracer['cid']}"
# Checkout conversation branch
repo.heads[conv_branch].checkout()
# Create commit message
metadata_yaml = yaml.dump(tracer, allow_unicode=True, sort_keys=False, default_flow_style=False)
commit_message = f"""
{query[:250]}
Response:
---
{response[:500]}...
Metadata
---
{metadata_yaml}
""".strip()
# Merge message branch into conversation branch
repo.git.merge(msg_branch, no_ff=True, m=commit_message)
# Delete message branch after merge
repo.delete_head(msg_branch, force=True)
logger.debug(f"Successfully merged {msg_branch} into {conv_branch}")
return True
except Exception as e:
logger.error(f"Failed to merge message {msg_branch} into conversation {conv_branch}: {str(e)}")
return False

View file

@@ -28,6 +28,7 @@ async def text_to_image(
send_status_func: Optional[Callable] = None,
query_images: Optional[List[str]] = None,
agent: Agent = None,
tracer: dict = {},
):
status_code = 200
image = None
@@ -68,6 +69,7 @@ async def text_to_image(
query_images=query_images,
user=user,
agent=agent,
tracer=tracer,
)
if send_status_func:

View file

@@ -66,6 +66,7 @@ async def search_online(
max_webpages_to_read: int = DEFAULT_MAX_WEBPAGES_TO_READ,
query_images: List[str] = None,
agent: Agent = None,
tracer: dict = {},
):
query += " ".join(custom_filters)
if not is_internet_connected():
@@ -75,7 +76,7 @@ async def search_online(
# Breakdown the query into subqueries to get the correct answer
subqueries = await generate_online_subqueries(
-query, conversation_history, location, user, query_images=query_images, agent=agent
+query, conversation_history, location, user, query_images=query_images, agent=agent, tracer=tracer
)
response_dict = {}
@@ -113,7 +114,7 @@ async def search_online(
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
yield {ChatEvent.STATUS: event}
tasks = [
-read_webpage_and_extract_content(data["queries"], link, data["content"], user=user, agent=agent)
+read_webpage_and_extract_content(data["queries"], link, data["content"], user=user, agent=agent, tracer=tracer)
for link, data in webpages.items()
]
results = await asyncio.gather(*tasks)
@@ -155,6 +156,7 @@ async def read_webpages(
send_status_func: Optional[Callable] = None,
query_images: List[str] = None,
agent: Agent = None,
tracer: dict = {},
):
"Infer web pages to read from the query and extract relevant information from them"
logger.info(f"Inferring web pages to read")
@@ -168,7 +170,7 @@ async def read_webpages(
webpage_links_str = "\n- " + "\n- ".join(list(urls))
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
yield {ChatEvent.STATUS: event}
-tasks = [read_webpage_and_extract_content({query}, url, user=user, agent=agent) for url in urls]
+tasks = [read_webpage_and_extract_content({query}, url, user=user, agent=agent, tracer=tracer) for url in urls]
results = await asyncio.gather(*tasks)
response: Dict[str, Dict] = defaultdict(dict)
@@ -194,7 +196,12 @@ async def read_webpage(
async def read_webpage_and_extract_content(
-subqueries: set[str], url: str, content: str = None, user: KhojUser = None, agent: Agent = None
+subqueries: set[str],
+url: str,
+content: str = None,
+user: KhojUser = None,
+agent: Agent = None,
+tracer: dict = {},
) -> Tuple[set[str], str, Union[None, str]]:
# Select the web scrapers to use for reading the web page
web_scrapers = await ConversationAdapters.aget_enabled_webscrapers()
@@ -216,7 +223,9 @@ async def read_webpage_and_extract_content(
# Extract relevant information from the web page
if is_none_or_empty(extracted_info):
with timer(f"Extracting relevant information from web page at '{url}' took", logger):
-extracted_info = await extract_relevant_info(subqueries, content, user=user, agent=agent)
+extracted_info = await extract_relevant_info(
+subqueries, content, user=user, agent=agent, tracer=tracer
+)
# If we successfully extracted information, break the loop
if not is_none_or_empty(extracted_info):

View file

@@ -35,6 +35,7 @@ async def run_code(
query_images: List[str] = None,
agent: Agent = None,
sandbox_url: str = SANDBOX_URL,
tracer: dict = {},
):
# Generate Code
if send_status_func:
@@ -43,7 +44,14 @@ async def run_code(
try:
with timer("Chat actor: Generate programs to execute", logger):
codes = await generate_python_code(
-query, conversation_history, previous_iterations_history, location_data, user, query_images, agent
+query,
+conversation_history,
+previous_iterations_history,
+location_data,
+user,
+query_images,
+agent,
+tracer,
)
except Exception as e:
raise ValueError(f"Failed to generate code for {query} with error: {e}")
@@ -72,6 +80,7 @@ async def generate_python_code(
user: KhojUser,
query_images: List[str] = None,
agent: Agent = None,
tracer: dict = {},
) -> List[str]:
location = f"{location_data}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
@@ -98,6 +107,7 @@ async def generate_python_code(
query_images=query_images,
response_type="json_object",
user=user,
tracer=tracer,
)
# Validate that the response is a non-empty, JSON-serializable list

View file

@@ -351,6 +351,7 @@ async def extract_references_and_questions(
send_status_func: Optional[Callable] = None,
query_images: Optional[List[str]] = None,
agent: Agent = None,
tracer: dict = {},
):
user = request.user.object if request.user.is_authenticated else None
@@ -424,6 +425,7 @@ async def extract_references_and_questions(
user=user,
max_prompt_size=conversation_config.max_prompt_size,
personality_context=personality_context,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
openai_chat_config = conversation_config.openai_config
@@ -441,6 +443,7 @@ async def extract_references_and_questions(
query_images=query_images,
vision_enabled=vision_enabled,
personality_context=personality_context,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
api_key = conversation_config.openai_config.api_key
@@ -455,6 +458,7 @@ async def extract_references_and_questions(
user=user,
vision_enabled=vision_enabled,
personality_context=personality_context,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
api_key = conversation_config.openai_config.api_key
@@ -470,6 +474,7 @@ async def extract_references_and_questions(
user=user,
vision_enabled=vision_enabled,
personality_context=personality_context,
tracer=tracer,
)
# Collate search results as context for GPT

View file

@@ -3,6 +3,7 @@ import base64
import json
import logging
import time
import uuid
from datetime import datetime
from functools import partial
from typing import Any, Dict, List, Optional
@@ -570,6 +571,12 @@ async def chat(
event_delimiter = "␃🔚␗"
q = unquote(q)
nonlocal conversation_id
tracer: dict = {
"mid": f"{uuid.uuid4()}",
"cid": conversation_id,
"uid": user.id,
"khoj_version": state.khoj_version,
}
uploaded_images: list[str] = []
if images:
@@ -703,6 +710,7 @@ async def chat(
user_name=user_name,
location=location,
file_filters=conversation.file_filters if conversation else [],
tracer=tracer,
):
if isinstance(research_result, InformationCollectionIteration):
if research_result.summarizedResult:
@@ -732,9 +740,12 @@ async def chat(
user=user,
query_images=uploaded_images,
agent=agent,
tracer=tracer,
)
-mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, user, uploaded_images, agent)
+mode = await aget_relevant_output_modes(
+q, meta_log, is_automated_task, user, uploaded_images, agent, tracer=tracer
+)
async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
yield result
if mode not in conversation_commands:
@@ -778,6 +789,7 @@ async def chat(
query_images=uploaded_images,
agent=agent,
send_status_func=partial(send_event, ChatEvent.STATUS),
tracer=tracer,
):
if isinstance(response, dict) and ChatEvent.STATUS in response:
yield result[ChatEvent.STATUS]
@@ -796,6 +808,7 @@ async def chat(
client_application=request.user.client_app,
conversation_id=conversation_id,
query_images=uploaded_images,
tracer=tracer,
)
return
@@ -817,7 +830,7 @@ async def chat(
if ConversationCommand.Automation in conversation_commands:
try:
automation, crontime, query_to_run, subject = await create_automation(
-q, timezone, user, request.url, meta_log
+q, timezone, user, request.url, meta_log, tracer=tracer
)
except Exception as e:
logger.error(f"Error scheduling task {q} for {user.email}: {e}")
@@ -839,6 +852,7 @@ async def chat(
inferred_queries=[query_to_run],
automation_id=automation.id,
query_images=uploaded_images,
tracer=tracer,
)
async for result in send_llm_response(llm_response):
yield result
@@ -860,6 +874,7 @@ async def chat(
partial(send_event, ChatEvent.STATUS),
query_images=uploaded_images,
agent=agent,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
@@ -905,6 +920,7 @@ async def chat(
custom_filters,
query_images=uploaded_images,
agent=agent,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
@@ -930,6 +946,7 @@ async def chat(
partial(send_event, ChatEvent.STATUS),
query_images=uploaded_images,
agent=agent,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
@@ -984,6 +1001,7 @@ async def chat(
partial(send_event, ChatEvent.STATUS),
query_images=uploaded_images,
agent=agent,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
@@ -1010,6 +1028,7 @@ async def chat(
send_status_func=partial(send_event, ChatEvent.STATUS),
query_images=uploaded_images,
agent=agent,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
@@ -1040,6 +1059,7 @@ async def chat(
compiled_references=compiled_references,
online_results=online_results,
query_images=uploaded_images,
tracer=tracer,
)
content_obj = {
"intentType": intent_type,
@@ -1068,6 +1088,7 @@ async def chat(
user=user,
agent=agent,
send_status_func=partial(send_event, ChatEvent.STATUS),
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
@@ -1095,6 +1116,7 @@ async def chat(
compiled_references=compiled_references,
online_results=online_results,
query_images=uploaded_images,
tracer=tracer,
)
async for result in send_llm_response(json.dumps(content_obj)):
@@ -1120,6 +1142,7 @@ async def chat(
user_name,
researched_results,
uploaded_images,
tracer,
)
# Send Response
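
Tying the routing code back to the utilities above: the tracer built at the top of chat() rides along with every actor call in this file, each model call commits onto the per-message branch, and save_to_conversation_log finally merges that branch into the conversation branch. A compressed sketch of one turn (helper bodies elided; the response value is a stand-in):

```python
import uuid
from khoj.processor.conversation.utils import (
    commit_conversation_trace,
    merge_message_into_conversation_trace,
)

def chat_turn_sketch(user, conversation_id, q, messages):
    tracer = {"mid": f"{uuid.uuid4()}", "cid": conversation_id, "uid": user.id}

    # 1. Actors invoked during the turn commit their prompt/response pair on branch m_<mid>
    response = "..."  # stand-in for the model output produced by the chat actors
    tracer["chat_model"], tracer["temperature"] = "example-model", 0
    commit_conversation_trace(messages, response, tracer)

    # 2. Saving the turn folds the message branch into the conversation branch c_<cid>
    merge_message_into_conversation_trace(q, response, tracer)
```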

View file

@@ -291,6 +291,7 @@ async def aget_relevant_information_sources(
user: KhojUser,
query_images: List[str] = None,
agent: Agent = None,
tracer: dict = {},
):
"""
Given a query, determine which of the available tools the agent should use in order to answer appropriately.
@@ -327,6 +328,7 @@ async def aget_relevant_information_sources(
relevant_tools_prompt,
response_type="json_object",
user=user,
tracer=tracer,
)
try:
@@ -368,6 +370,7 @@ async def aget_relevant_output_modes(
user: KhojUser = None,
query_images: List[str] = None,
agent: Agent = None,
tracer: dict = {},
):
"""
Given a query, determine which of the available tools the agent should use in order to answer appropriately.
@@ -403,7 +406,9 @@ async def aget_relevant_output_modes(
)
with timer("Chat actor: Infer output mode for chat response", logger):
-response = await send_message_to_model_wrapper(relevant_mode_prompt, response_type="json_object", user=user)
+response = await send_message_to_model_wrapper(
+relevant_mode_prompt, response_type="json_object", user=user, tracer=tracer
+)
try:
response = response.strip()
@@ -434,6 +439,7 @@ async def infer_webpage_urls(
user: KhojUser,
query_images: List[str] = None,
agent: Agent = None,
tracer: dict = {},
) -> List[str]:
"""
Infer webpage links from the given query
@@ -458,7 +464,11 @@ async def infer_webpage_urls(
with timer("Chat actor: Infer webpage urls to read", logger):
response = await send_message_to_model_wrapper(
-online_queries_prompt, query_images=query_images, response_type="json_object", user=user
+online_queries_prompt,
+query_images=query_images,
+response_type="json_object",
+user=user,
+tracer=tracer,
)
# Validate that the response is a non-empty, JSON-serializable list of URLs
@@ -481,6 +491,7 @@ async def generate_online_subqueries(
user: KhojUser,
query_images: List[str] = None,
agent: Agent = None,
tracer: dict = {},
) -> List[str]:
"""
Generate subqueries from the given query
@@ -505,7 +516,11 @@ async def generate_online_subqueries(
with timer("Chat actor: Generate online search subqueries", logger):
response = await send_message_to_model_wrapper(
-online_queries_prompt, query_images=query_images, response_type="json_object", user=user
+online_queries_prompt,
+query_images=query_images,
+response_type="json_object",
+user=user,
+tracer=tracer,
)
# Validate that the response is a non-empty, JSON-serializable list
@@ -524,7 +539,7 @@ async def generate_online_subqueries(
async def schedule_query(
-q: str, conversation_history: dict, user: KhojUser, query_images: List[str] = None
+q: str, conversation_history: dict, user: KhojUser, query_images: List[str] = None, tracer: dict = {}
) -> Tuple[str, ...]:
"""
Schedule the date, time to run the query. Assume the server timezone is UTC.
@@ -537,7 +552,7 @@ async def schedule_query(
)
raw_response = await send_message_to_model_wrapper(
-crontime_prompt, query_images=query_images, response_type="json_object", user=user
+crontime_prompt, query_images=query_images, response_type="json_object", user=user, tracer=tracer
)
# Validate that the response is a non-empty, JSON-serializable list
@@ -552,7 +567,7 @@ async def schedule_query(
async def extract_relevant_info(
-qs: set[str], corpus: str, user: KhojUser = None, agent: Agent = None
+qs: set[str], corpus: str, user: KhojUser = None, agent: Agent = None, tracer: dict = {}
) -> Union[str, None]:
"""
Extract relevant information for a given query from the target corpus
@@ -575,6 +590,7 @@ async def extract_relevant_info(
extract_relevant_information,
prompts.system_prompt_extract_relevant_information,
user=user,
tracer=tracer,
)
return response.strip()
@@ -586,6 +602,7 @@ async def extract_relevant_summary(
query_images: List[str] = None,
user: KhojUser = None,
agent: Agent = None,
tracer: dict = {},
) -> Union[str, None]:
"""
Extract relevant information for a given query from the target corpus
@@ -613,6 +630,7 @@ async def extract_relevant_summary(
prompts.system_prompt_extract_relevant_summary,
user=user,
query_images=query_images,
tracer=tracer,
)
return response.strip()
@@ -625,6 +643,7 @@ async def generate_summary_from_files(
query_images: List[str] = None,
agent: Agent = None,
send_status_func: Optional[Callable] = None,
tracer: dict = {},
):
try:
file_object = None
@@ -653,6 +672,7 @@ async def generate_summary_from_files(
query_images=query_images,
user=user,
agent=agent,
tracer=tracer,
)
response_log = str(response)
@@ -673,6 +693,7 @@ async def generate_excalidraw_diagram(
user: KhojUser = None,
agent: Agent = None,
send_status_func: Optional[Callable] = None,
tracer: dict = {},
):
if send_status_func:
async for event in send_status_func("**Enhancing the Diagramming Prompt**"):
@@ -687,6 +708,7 @@ async def generate_excalidraw_diagram(
query_images=query_images,
user=user,
agent=agent,
tracer=tracer,
)
if send_status_func:
@@ -697,6 +719,7 @@ async def generate_excalidraw_diagram(
q=better_diagram_description_prompt,
user=user,
agent=agent,
tracer=tracer,
)
yield better_diagram_description_prompt, excalidraw_diagram_description
@@ -711,6 +734,7 @@ async def generate_better_diagram_description(
query_images: List[str] = None,
user: KhojUser = None,
agent: Agent = None,
tracer: dict = {},
) -> str:
"""
Generate a diagram description from the given query and context
@@ -748,7 +772,7 @@ async def generate_better_diagram_description(
with timer("Chat actor: Generate better diagram description", logger):
response = await send_message_to_model_wrapper(
-improve_diagram_description_prompt, query_images=query_images, user=user
+improve_diagram_description_prompt, query_images=query_images, user=user, tracer=tracer
)
response = response.strip()
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
@@ -761,6 +785,7 @@ async def generate_excalidraw_diagram_from_description(
q: str,
user: KhojUser = None,
agent: Agent = None,
tracer: dict = {},
) -> str:
personality_context = (
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
@@ -772,7 +797,9 @@ async def generate_excalidraw_diagram_from_description(
)
with timer("Chat actor: Generate excalidraw diagram", logger):
-raw_response = await send_message_to_model_wrapper(message=excalidraw_diagram_generation, user=user)
+raw_response = await send_message_to_model_wrapper(
+message=excalidraw_diagram_generation, user=user, tracer=tracer
+)
raw_response = raw_response.strip()
raw_response = remove_json_codeblock(raw_response)
response: Dict[str, str] = json.loads(raw_response)
@@ -793,6 +820,7 @@ async def generate_better_image_prompt(
query_images: Optional[List[str]] = None,
user: KhojUser = None,
agent: Agent = None,
tracer: dict = {},
) -> str:
"""
Generate a better image prompt from the given query
@@ -839,7 +867,9 @@ async def generate_better_image_prompt(
)
with timer("Chat actor: Generate contextual image prompt", logger):
-response = await send_message_to_model_wrapper(image_prompt, query_images=query_images, user=user)
+response = await send_message_to_model_wrapper(
+image_prompt, query_images=query_images, user=user, tracer=tracer
+)
response = response.strip()
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
response = response[1:-1]
@@ -853,6 +883,7 @@ async def send_message_to_model_wrapper(
response_type: str = "text",
user: KhojUser = None, user: KhojUser = None,
query_images: List[str] = None, query_images: List[str] = None,
tracer: dict = {},
): ):
conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user) conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user)
vision_available = conversation_config.vision_enabled vision_available = conversation_config.vision_enabled
@ -899,6 +930,7 @@ async def send_message_to_model_wrapper(
max_prompt_size=max_tokens, max_prompt_size=max_tokens,
streaming=False, streaming=False,
response_type=response_type, response_type=response_type,
tracer=tracer,
) )
elif model_type == ChatModelOptions.ModelType.OPENAI: elif model_type == ChatModelOptions.ModelType.OPENAI:
@ -922,6 +954,7 @@ async def send_message_to_model_wrapper(
model=chat_model, model=chat_model,
response_type=response_type, response_type=response_type,
api_base_url=api_base_url, api_base_url=api_base_url,
tracer=tracer,
) )
elif model_type == ChatModelOptions.ModelType.ANTHROPIC: elif model_type == ChatModelOptions.ModelType.ANTHROPIC:
api_key = conversation_config.openai_config.api_key api_key = conversation_config.openai_config.api_key
@ -940,6 +973,7 @@ async def send_message_to_model_wrapper(
messages=truncated_messages, messages=truncated_messages,
api_key=api_key, api_key=api_key,
model=chat_model, model=chat_model,
tracer=tracer,
) )
elif model_type == ChatModelOptions.ModelType.GOOGLE: elif model_type == ChatModelOptions.ModelType.GOOGLE:
api_key = conversation_config.openai_config.api_key api_key = conversation_config.openai_config.api_key
@ -955,7 +989,7 @@ async def send_message_to_model_wrapper(
) )
return gemini_send_message_to_model( return gemini_send_message_to_model(
messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type, tracer=tracer
) )
else: else:
raise HTTPException(status_code=500, detail="Invalid conversation config") raise HTTPException(status_code=500, detail="Invalid conversation config")
@ -966,6 +1000,7 @@ def send_message_to_model_wrapper_sync(
system_message: str = "", system_message: str = "",
response_type: str = "text", response_type: str = "text",
user: KhojUser = None, user: KhojUser = None,
tracer: dict = {},
): ):
conversation_config: ChatModelOptions = ConversationAdapters.get_default_conversation_config(user) conversation_config: ChatModelOptions = ConversationAdapters.get_default_conversation_config(user)
@ -998,6 +1033,7 @@ def send_message_to_model_wrapper_sync(
max_prompt_size=max_tokens, max_prompt_size=max_tokens,
streaming=False, streaming=False,
response_type=response_type, response_type=response_type,
tracer=tracer,
) )
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI: elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
@ -1012,7 +1048,11 @@ def send_message_to_model_wrapper_sync(
) )
openai_response = send_message_to_model( openai_response = send_message_to_model(
messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type messages=truncated_messages,
api_key=api_key,
model=chat_model,
response_type=response_type,
tracer=tracer,
) )
return openai_response return openai_response
@ -1032,6 +1072,7 @@ def send_message_to_model_wrapper_sync(
messages=truncated_messages, messages=truncated_messages,
api_key=api_key, api_key=api_key,
model=chat_model, model=chat_model,
tracer=tracer,
) )
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE: elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
@ -1050,6 +1091,7 @@ def send_message_to_model_wrapper_sync(
api_key=api_key, api_key=api_key,
model=chat_model, model=chat_model,
response_type=response_type, response_type=response_type,
tracer=tracer,
) )
else: else:
raise HTTPException(status_code=500, detail="Invalid conversation config") raise HTTPException(status_code=500, detail="Invalid conversation config")
@ -1071,6 +1113,7 @@ def generate_chat_response(
user_name: Optional[str] = None, user_name: Optional[str] = None,
meta_research: str = "", meta_research: str = "",
query_images: Optional[List[str]] = None, query_images: Optional[List[str]] = None,
tracer: dict = {},
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]: ) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
# Initialize Variables # Initialize Variables
chat_response = None chat_response = None
@ -1094,6 +1137,7 @@ def generate_chat_response(
client_application=client_application, client_application=client_application,
conversation_id=conversation_id, conversation_id=conversation_id,
query_images=query_images, query_images=query_images,
tracer=tracer,
) )
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation) conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
@ -1120,6 +1164,7 @@ def generate_chat_response(
location_data=location_data, location_data=location_data,
user_name=user_name, user_name=user_name,
agent=agent, agent=agent,
tracer=tracer,
) )
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI: elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
@ -1144,6 +1189,7 @@ def generate_chat_response(
user_name=user_name, user_name=user_name,
agent=agent, agent=agent,
vision_available=vision_available, vision_available=vision_available,
tracer=tracer,
) )
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC: elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
@ -1165,6 +1211,7 @@ def generate_chat_response(
user_name=user_name, user_name=user_name,
agent=agent, agent=agent,
vision_available=vision_available, vision_available=vision_available,
tracer=tracer,
) )
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE: elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
api_key = conversation_config.openai_config.api_key api_key = conversation_config.openai_config.api_key
@ -1184,6 +1231,7 @@ def generate_chat_response(
user_name=user_name, user_name=user_name,
agent=agent, agent=agent,
vision_available=vision_available, vision_available=vision_available,
tracer=tracer,
) )
metadata.update({"chat_model": conversation_config.chat_model}) metadata.update({"chat_model": conversation_config.chat_model})
@ -1540,9 +1588,15 @@ def scheduled_chat(
async def create_automation( async def create_automation(
q: str, timezone: str, user: KhojUser, calling_url: URL, meta_log: dict = {}, conversation_id: str = None q: str,
timezone: str,
user: KhojUser,
calling_url: URL,
meta_log: dict = {},
conversation_id: str = None,
tracer: dict = {},
): ):
crontime, query_to_run, subject = await schedule_query(q, meta_log, user) crontime, query_to_run, subject = await schedule_query(q, meta_log, user, tracer=tracer)
job = await schedule_automation(query_to_run, subject, crontime, timezone, q, user, calling_url, conversation_id) job = await schedule_automation(query_to_run, subject, crontime, timezone, q, user, calling_url, conversation_id)
return job, crontime, query_to_run, subject return job, crontime, query_to_run, subject
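The pattern across the helpers above is uniform: each chat actor now accepts an optional tracer dict and forwards it to send_message_to_model_wrapper (or to the model-specific send/converse call), so one dict can follow a request through every model invocation it triggers. Below is a minimal sketch of how a caller might build and thread that dict; the make_tracer helper and its key names (uid, cid, mid) are illustrative assumptions, not names taken from this diff.

import uuid
from typing import Any

def make_tracer(user_id: Any, conversation_id: Any) -> dict:
    # Hypothetical helper, assumed for illustration: one tracer dict per chat turn.
    # The key names below are placeholders; the real tracer schema is not shown here.
    return {
        "uid": user_id,            # requesting user (assumed key name)
        "cid": conversation_id,    # conversation being traced (assumed key name)
        "mid": str(uuid.uuid4()),  # unique id for this message/turn (assumed key name)
    }

# Usage sketch: pass the same dict into any actor in this file, e.g.
#   tracer = make_tracer(user.id, conversation_id)
#   job, crontime, query_to_run, subject = await create_automation(
#       q, timezone, user, calling_url, meta_log, conversation_id, tracer=tracer
#   )
# Because dicts are passed by reference, whatever downstream actors record into
# the tracer is visible to the caller once the request completes.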

View file

@ -45,6 +45,7 @@ async def apick_next_tool(
previous_iterations_history: str = None, previous_iterations_history: str = None,
max_iterations: int = 5, max_iterations: int = 5,
send_status_func: Optional[Callable] = None, send_status_func: Optional[Callable] = None,
tracer: dict = {},
): ):
""" """
    Given a query, determine which of the available tools the agent should use to answer it. The agent picks one tool at a time and can use subsequent iterations to refine its answer.     Given a query, determine which of the available tools the agent should use to answer it. The agent picks one tool at a time and can use subsequent iterations to refine its answer.
@ -93,6 +94,7 @@ async def apick_next_tool(
response_type="json_object", response_type="json_object",
user=user, user=user,
query_images=query_images, query_images=query_images,
tracer=tracer,
) )
try: try:
@ -135,6 +137,7 @@ async def execute_information_collection(
user_name: str = None, user_name: str = None,
location: LocationData = None, location: LocationData = None,
file_filters: List[str] = [], file_filters: List[str] = [],
tracer: dict = {},
): ):
current_iteration = 0 current_iteration = 0
MAX_ITERATIONS = 5 MAX_ITERATIONS = 5
@ -159,6 +162,7 @@ async def execute_information_collection(
previous_iterations_history, previous_iterations_history,
MAX_ITERATIONS, MAX_ITERATIONS,
send_status_func, send_status_func,
tracer=tracer,
): ):
if isinstance(result, dict) and ChatEvent.STATUS in result: if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS] yield result[ChatEvent.STATUS]
@ -180,6 +184,7 @@ async def execute_information_collection(
send_status_func, send_status_func,
query_images, query_images,
agent=agent, agent=agent,
tracer=tracer,
): ):
if isinstance(result, dict) and ChatEvent.STATUS in result: if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS] yield result[ChatEvent.STATUS]
@ -211,6 +216,7 @@ async def execute_information_collection(
max_webpages_to_read=0, max_webpages_to_read=0,
query_images=query_images, query_images=query_images,
agent=agent, agent=agent,
tracer=tracer,
): ):
if isinstance(result, dict) and ChatEvent.STATUS in result: if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS] yield result[ChatEvent.STATUS]
@ -228,6 +234,7 @@ async def execute_information_collection(
send_status_func, send_status_func,
query_images=query_images, query_images=query_images,
agent=agent, agent=agent,
tracer=tracer,
): ):
if isinstance(result, dict) and ChatEvent.STATUS in result: if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS] yield result[ChatEvent.STATUS]
@ -258,6 +265,7 @@ async def execute_information_collection(
send_status_func, send_status_func,
query_images=query_images, query_images=query_images,
agent=agent, agent=agent,
tracer=tracer,
): ):
if isinstance(result, dict) and ChatEvent.STATUS in result: if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS] yield result[ChatEvent.STATUS]
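Within the research loop the same object is shared even more widely: execute_information_collection forwards the tracer it receives into apick_next_tool and into the downstream actors it dispatches on each iteration, so a multi-iteration research run accumulates a single, combined trace. The toy example below uses stand-in functions rather than khoj's actual actors; it only demonstrates the mechanism this relies on, namely a mutable dict threaded by keyword argument that collects entries from every call that receives it.

import asyncio

async def mock_chat_actor(prompt: str, tracer: dict = {}) -> str:
    # Stand-in for any chat actor above: record what was asked into the shared
    # tracer dict, then return a fake response.
    tracer.setdefault("calls", []).append({"prompt": prompt})
    return "ok"

async def main():
    tracer: dict = {}
    await mock_chat_actor("pick next tool", tracer=tracer)
    await mock_chat_actor("summarize retrieved notes", tracer=tracer)
    # Both calls wrote into the same dict, so the caller now holds the full trace.
    assert [c["prompt"] for c in tracer["calls"]] == [
        "pick next tool",
        "summarize retrieved notes",
    ]

asyncio.run(main())

One detail worth noting about this pattern: a "tracer: dict = {}" default is a single shared object in Python, so callers that want per-request traces should always pass their own dict, as the call sites in this diff do.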