Mirror of https://github.com/khoj-ai/khoj.git, synced 2024-11-23 23:48:56 +01:00
Merge branch 'add-prompt-tracer-for-observability' of github.com:khoj-ai/khoj into features/advanced-reasoning
- Start from this branch's src/khoj/routers/api_chat.py. Add the tracer to all old and new chat actors that don't have it set when they are called.
- Update the new chat actors, like apick_next_tool, to use the tracer too.
Commit f04f871a72: 16 changed files with 458 additions and 57 deletions
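The diff below threads a single `tracer` dict from the chat API handler down through every chat actor and into the provider-specific model helpers, which stamp model metadata on it and commit each exchange to a local git repo. A minimal sketch of that flow, with heavily simplified, hypothetical signatures (the real functions take many more parameters):

import uuid


def handle_chat_message(q: str, user_id: str, conversation_id: str, khoj_version: str) -> str:
    # api_chat.py: one tracer dict is created per incoming user message ...
    tracer = {
        "mid": f"{uuid.uuid4()}",      # message id
        "cid": conversation_id,        # conversation id
        "uid": user_id,                # user id
        "khoj_version": khoj_version,
    }
    # ... and handed to every chat actor invoked while answering this turn.
    questions = extract_questions(q, tracer=tracer)
    return converse(q, questions, tracer=tracer)


def extract_questions(q: str, tracer: dict = {}) -> list[str]:
    # Chat actors just forward the same dict to the model-call helper.
    return [send_message_to_model([q], tracer=tracer)]


def converse(q: str, questions: list[str], tracer: dict = {}) -> str:
    return send_message_to_model([q] + questions, tracer=tracer)


def send_message_to_model(messages: list[str], tracer: dict = {}) -> str:
    response = "stub response"  # the real helpers call the configured LLM here
    # The provider helpers stamp model metadata on the tracer and, in debug or
    # verbose mode, commit the exchange to git via commit_conversation_trace.
    tracer["chat_model"] = "example-model"
    tracer["temperature"] = 0
    return response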
@@ -119,6 +119,7 @@ dev = [
    "mypy >= 1.0.1",
    "black >= 23.1.0",
    "pre-commit >= 3.0.4",
    "gitpython ~= 3.1.43",
]

[tool.hatch.version]
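The new gitpython dev dependency backs the commit_conversation_trace helper added in khoj/processor/conversation/utils.py further down. A minimal, illustrative sketch of the gitpython surface that helper relies on (paths, branch and file names here are placeholders):

import os

from git import Repo  # provided by the new gitpython dependency

repo_path = "/tmp/khoj_promptrace"  # illustrative trace repo location
os.makedirs(repo_path, exist_ok=True)
repo = Repo.init(repo_path)  # create or open the trace repository

if not repo.head.is_valid():
    repo.index.commit("initial commit")  # seed the repo so branches can be created

if "u_example" not in repo.branches:  # one branch per user in the real helper
    repo.create_head("u_example")
repo.heads["u_example"].checkout()

with open(os.path.join(repo_path, "query"), "w", encoding="utf-8") as f:
    f.write("example query\n")
repo.index.add(["query"])           # stage the serialized trace files
repo.index.commit("example trace")  # one commit per traced LLM call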
@@ -34,6 +34,7 @@ def extract_questions_anthropic(
    query_images: Optional[list[str]] = None,
    vision_enabled: bool = False,
    personality_context: Optional[str] = None,
    tracer: dict = {},
):
    """
    Infer search queries to retrieve relevant notes to answer user query

@@ -89,6 +90,7 @@ def extract_questions_anthropic(
        model_name=model,
        temperature=temperature,
        api_key=api_key,
        tracer=tracer,
    )

    # Extract, Clean Message from Claude's Response

@@ -110,7 +112,7 @@ def extract_questions_anthropic(
    return questions


def anthropic_send_message_to_model(messages, api_key, model):
def anthropic_send_message_to_model(messages, api_key, model, tracer={}):
    """
    Send message to model
    """

@@ -122,6 +124,7 @@ def anthropic_send_message_to_model(messages, api_key, model):
        system_prompt=system_prompt,
        model_name=model,
        api_key=api_key,
        tracer=tracer,
    )

@@ -142,6 +145,7 @@ def converse_anthropic(
    agent: Agent = None,
    query_images: Optional[list[str]] = None,
    vision_available: bool = False,
    tracer: dict = {},
):
    """
    Converse with user using Anthropic's Claude

@@ -220,4 +224,5 @@ def converse_anthropic(
        system_prompt=system_prompt,
        completion_func=completion_func,
        max_prompt_size=max_prompt_size,
        tracer=tracer,
    )
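Every provider module below (Anthropic here, then Gemini, the offline llama.cpp path and OpenAI) gains the same trailing snippet in its completion helpers: once the full response has been accumulated, the helper records the model and temperature on the tracer and, only in debug or verbose mode, commits the exchange. A trimmed, illustrative wrapper around that repeated snippet, using the khoj helpers imported in the hunks that follow (record_trace itself is a hypothetical name; each provider module inlines this logic instead):

from khoj.processor.conversation.utils import commit_conversation_trace
from khoj.utils import state
from khoj.utils.helpers import in_debug_mode


def record_trace(messages, aggregated_response: str, model_name: str, temperature: float, tracer: dict) -> None:
    # Stamp model metadata on the shared tracer dict ...
    tracer["chat_model"] = model_name
    tracer["temperature"] = temperature
    # ... and persist the exchange only when tracing is enabled.
    if in_debug_mode() or state.verbose > 1:
        commit_conversation_trace(messages, aggregated_response, tracer)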
@ -12,8 +12,13 @@ from tenacity import (
|
|||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from khoj.processor.conversation.utils import ThreadedGenerator, get_image_from_url
|
||||
from khoj.utils.helpers import is_none_or_empty
|
||||
from khoj.processor.conversation.utils import (
|
||||
ThreadedGenerator,
|
||||
commit_conversation_trace,
|
||||
get_image_from_url,
|
||||
)
|
||||
from khoj.utils import state
|
||||
from khoj.utils.helpers import in_debug_mode, is_none_or_empty
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -30,7 +35,7 @@ DEFAULT_MAX_TOKENS_ANTHROPIC = 3000
|
|||
reraise=True,
|
||||
)
|
||||
def anthropic_completion_with_backoff(
|
||||
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, max_tokens=None
|
||||
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, max_tokens=None, tracer={}
|
||||
) -> str:
|
||||
if api_key not in anthropic_clients:
|
||||
client: anthropic.Anthropic = anthropic.Anthropic(api_key=api_key)
|
||||
|
@ -58,6 +63,12 @@ def anthropic_completion_with_backoff(
|
|||
for text in stream.text_stream:
|
||||
aggregated_response += text
|
||||
|
||||
# Save conversation trace
|
||||
tracer["chat_model"] = model_name
|
||||
tracer["temperature"] = temperature
|
||||
if in_debug_mode() or state.verbose > 1:
|
||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||
|
||||
return aggregated_response
|
||||
|
||||
|
||||
|
@ -78,18 +89,19 @@ def anthropic_chat_completion_with_backoff(
|
|||
max_prompt_size=None,
|
||||
completion_func=None,
|
||||
model_kwargs=None,
|
||||
tracer={},
|
||||
):
|
||||
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
|
||||
t = Thread(
|
||||
target=anthropic_llm_thread,
|
||||
args=(g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size, model_kwargs),
|
||||
args=(g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size, model_kwargs, tracer),
|
||||
)
|
||||
t.start()
|
||||
return g
|
||||
|
||||
|
||||
def anthropic_llm_thread(
|
||||
g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size=None, model_kwargs=None
|
||||
g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size=None, model_kwargs=None, tracer={}
|
||||
):
|
||||
try:
|
||||
if api_key not in anthropic_clients:
|
||||
|
@ -102,6 +114,7 @@ def anthropic_llm_thread(
|
|||
anthropic.types.MessageParam(role=message.role, content=message.content) for message in messages
|
||||
]
|
||||
|
||||
aggregated_response = ""
|
||||
with client.messages.stream(
|
||||
messages=formatted_messages,
|
||||
model=model_name, # type: ignore
|
||||
|
@ -112,7 +125,14 @@ def anthropic_llm_thread(
|
|||
**(model_kwargs or dict()),
|
||||
) as stream:
|
||||
for text in stream.text_stream:
|
||||
aggregated_response += text
|
||||
g.send(text)
|
||||
|
||||
# Save conversation trace
|
||||
tracer["chat_model"] = model_name
|
||||
tracer["temperature"] = temperature
|
||||
if in_debug_mode() or state.verbose > 1:
|
||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in anthropic_llm_thread: {e}", exc_info=True)
|
||||
finally:
|
||||
|
|
|
@ -35,6 +35,7 @@ def extract_questions_gemini(
|
|||
query_images: Optional[list[str]] = None,
|
||||
vision_enabled: bool = False,
|
||||
personality_context: Optional[str] = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
"""
|
||||
Infer search queries to retrieve relevant notes to answer user query
|
||||
|
@ -85,7 +86,7 @@ def extract_questions_gemini(
|
|||
messages = [ChatMessage(content=prompt, role="user"), ChatMessage(content=system_prompt, role="system")]
|
||||
|
||||
response = gemini_send_message_to_model(
|
||||
messages, api_key, model, response_type="json_object", temperature=temperature
|
||||
messages, api_key, model, response_type="json_object", temperature=temperature, tracer=tracer
|
||||
)
|
||||
|
||||
# Extract, Clean Message from Gemini's Response
|
||||
|
@ -107,7 +108,9 @@ def extract_questions_gemini(
|
|||
return questions
|
||||
|
||||
|
||||
def gemini_send_message_to_model(messages, api_key, model, response_type="text", temperature=0, model_kwargs=None):
|
||||
def gemini_send_message_to_model(
|
||||
messages, api_key, model, response_type="text", temperature=0, model_kwargs=None, tracer={}
|
||||
):
|
||||
"""
|
||||
Send message to model
|
||||
"""
|
||||
|
@ -125,6 +128,7 @@ def gemini_send_message_to_model(messages, api_key, model, response_type="text",
|
|||
api_key=api_key,
|
||||
temperature=temperature,
|
||||
model_kwargs=model_kwargs,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
|
||||
|
@ -146,6 +150,7 @@ def converse_gemini(
|
|||
agent: Agent = None,
|
||||
query_images: Optional[list[str]] = None,
|
||||
vision_available: bool = False,
|
||||
tracer={},
|
||||
):
|
||||
"""
|
||||
Converse with user using Google's Gemini
|
||||
|
@ -224,4 +229,5 @@ def converse_gemini(
|
|||
api_key=api_key,
|
||||
system_prompt=system_prompt,
|
||||
completion_func=completion_func,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
|
|
@ -19,8 +19,13 @@ from tenacity import (
|
|||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from khoj.processor.conversation.utils import ThreadedGenerator, get_image_from_url
|
||||
from khoj.utils.helpers import is_none_or_empty
|
||||
from khoj.processor.conversation.utils import (
|
||||
ThreadedGenerator,
|
||||
commit_conversation_trace,
|
||||
get_image_from_url,
|
||||
)
|
||||
from khoj.utils import state
|
||||
from khoj.utils.helpers import in_debug_mode, is_none_or_empty
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -35,7 +40,7 @@ MAX_OUTPUT_TOKENS_GEMINI = 8192
|
|||
reraise=True,
|
||||
)
|
||||
def gemini_completion_with_backoff(
|
||||
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None
|
||||
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, tracer={}
|
||||
) -> str:
|
||||
genai.configure(api_key=api_key)
|
||||
model_kwargs = model_kwargs or dict()
|
||||
|
@ -60,16 +65,23 @@ def gemini_completion_with_backoff(
|
|||
|
||||
try:
|
||||
# Generate the response. The last message is considered to be the current prompt
|
||||
aggregated_response = chat_session.send_message(formatted_messages[-1]["parts"])
|
||||
return aggregated_response.text
|
||||
response = chat_session.send_message(formatted_messages[-1]["parts"])
|
||||
response_text = response.text
|
||||
except StopCandidateException as e:
|
||||
response_message, _ = handle_gemini_response(e.args)
|
||||
response_text, _ = handle_gemini_response(e.args)
|
||||
# Respond with reason for stopping
|
||||
logger.warning(
|
||||
f"LLM Response Prevented for {model_name}: {response_message}.\n"
|
||||
f"LLM Response Prevented for {model_name}: {response_text}.\n"
|
||||
+ f"Last Message by {messages[-1].role}: {messages[-1].content}"
|
||||
)
|
||||
return response_message
|
||||
|
||||
# Save conversation trace
|
||||
tracer["chat_model"] = model_name
|
||||
tracer["temperature"] = temperature
|
||||
if in_debug_mode() or state.verbose > 1:
|
||||
commit_conversation_trace(messages, response_text, tracer)
|
||||
|
||||
return response_text
|
||||
|
||||
|
||||
@retry(
|
||||
|
@ -88,17 +100,20 @@ def gemini_chat_completion_with_backoff(
|
|||
system_prompt,
|
||||
completion_func=None,
|
||||
model_kwargs=None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
|
||||
t = Thread(
|
||||
target=gemini_llm_thread,
|
||||
args=(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs),
|
||||
args=(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs, tracer),
|
||||
)
|
||||
t.start()
|
||||
return g
|
||||
|
||||
|
||||
def gemini_llm_thread(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs=None):
|
||||
def gemini_llm_thread(
|
||||
g, messages, system_prompt, model_name, temperature, api_key, model_kwargs=None, tracer: dict = {}
|
||||
):
|
||||
try:
|
||||
genai.configure(api_key=api_key)
|
||||
model_kwargs = model_kwargs or dict()
|
||||
|
@ -117,16 +132,25 @@ def gemini_llm_thread(g, messages, system_prompt, model_name, temperature, api_k
|
|||
},
|
||||
)
|
||||
|
||||
aggregated_response = ""
|
||||
formatted_messages = [{"role": message.role, "parts": message.content} for message in messages]
|
||||
|
||||
# all messages up to the last are considered to be part of the chat history
|
||||
chat_session = model.start_chat(history=formatted_messages[0:-1])
|
||||
# the last message is considered to be the current prompt
|
||||
for chunk in chat_session.send_message(formatted_messages[-1]["parts"], stream=True):
|
||||
message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
|
||||
message = message or chunk.text
|
||||
aggregated_response += message
|
||||
g.send(message)
|
||||
if stopped:
|
||||
raise StopCandidateException(message)
|
||||
|
||||
# Save conversation trace
|
||||
tracer["chat_model"] = model_name
|
||||
tracer["temperature"] = temperature
|
||||
if in_debug_mode() or state.verbose > 1:
|
||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||
except StopCandidateException as e:
|
||||
logger.warning(
|
||||
f"LLM Response Prevented for {model_name}: {e.args[0]}.\n"
|
||||
|
|
|
@ -12,11 +12,12 @@ from khoj.processor.conversation import prompts
|
|||
from khoj.processor.conversation.offline.utils import download_model
|
||||
from khoj.processor.conversation.utils import (
|
||||
ThreadedGenerator,
|
||||
commit_conversation_trace,
|
||||
generate_chatml_messages_with_context,
|
||||
)
|
||||
from khoj.utils import state
|
||||
from khoj.utils.constants import empty_escape_sequences
|
||||
from khoj.utils.helpers import ConversationCommand, is_none_or_empty
|
||||
from khoj.utils.helpers import ConversationCommand, in_debug_mode, is_none_or_empty
|
||||
from khoj.utils.rawconfig import LocationData
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -34,6 +35,7 @@ def extract_questions_offline(
|
|||
max_prompt_size: int = None,
|
||||
temperature: float = 0.7,
|
||||
personality_context: Optional[str] = None,
|
||||
tracer: dict = {},
|
||||
) -> List[str]:
|
||||
"""
|
||||
Infer search queries to retrieve relevant notes to answer user query
|
||||
|
@ -94,6 +96,7 @@ def extract_questions_offline(
|
|||
max_prompt_size=max_prompt_size,
|
||||
temperature=temperature,
|
||||
response_type="json_object",
|
||||
tracer=tracer,
|
||||
)
|
||||
finally:
|
||||
state.chat_lock.release()
|
||||
|
@ -147,6 +150,7 @@ def converse_offline(
|
|||
location_data: LocationData = None,
|
||||
user_name: str = None,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
) -> Union[ThreadedGenerator, Iterator[str]]:
|
||||
"""
|
||||
Converse with user using Llama
|
||||
|
@ -155,6 +159,7 @@ def converse_offline(
|
|||
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
|
||||
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
|
||||
compiled_references_message = "\n\n".join({f"{item['compiled']}" for item in references})
|
||||
tracer["chat_model"] = model
|
||||
|
||||
current_date = datetime.now()
|
||||
|
||||
|
@ -218,13 +223,14 @@ def converse_offline(
|
|||
logger.debug(f"Conversation Context for {model}: {truncated_messages}")
|
||||
|
||||
g = ThreadedGenerator(references, online_results, completion_func=completion_func)
|
||||
t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size))
|
||||
t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size, tracer))
|
||||
t.start()
|
||||
return g
|
||||
|
||||
|
||||
def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None):
|
||||
def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None, tracer: dict = {}):
|
||||
stop_phrases = ["<s>", "INST]", "Notes:"]
|
||||
aggregated_response = ""
|
||||
|
||||
state.chat_lock.acquire()
|
||||
try:
|
||||
|
@ -232,7 +238,14 @@ def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int
|
|||
messages, loaded_model=model, stop=stop_phrases, max_prompt_size=max_prompt_size, streaming=True
|
||||
)
|
||||
for response in response_iterator:
|
||||
g.send(response["choices"][0]["delta"].get("content", ""))
|
||||
response_delta = response["choices"][0]["delta"].get("content", "")
|
||||
aggregated_response += response_delta
|
||||
g.send(response_delta)
|
||||
|
||||
# Save conversation trace
|
||||
if in_debug_mode() or state.verbose > 1:
|
||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||
|
||||
finally:
|
||||
state.chat_lock.release()
|
||||
g.close()
|
||||
|
@ -247,6 +260,7 @@ def send_message_to_model_offline(
|
|||
stop=[],
|
||||
max_prompt_size: int = None,
|
||||
response_type: str = "text",
|
||||
tracer: dict = {},
|
||||
):
|
||||
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
|
||||
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
|
||||
|
@ -254,7 +268,17 @@ def send_message_to_model_offline(
|
|||
response = offline_chat_model.create_chat_completion(
|
||||
messages_dict, stop=stop, stream=streaming, temperature=temperature, response_format={"type": response_type}
|
||||
)
|
||||
|
||||
if streaming:
|
||||
return response
|
||||
else:
|
||||
return response["choices"][0]["message"].get("content", "")
|
||||
|
||||
response_text = response["choices"][0]["message"].get("content", "")
|
||||
|
||||
# Save conversation trace for non-streaming responses
|
||||
# Streamed responses need to be saved by the calling function
|
||||
tracer["chat_model"] = model
|
||||
tracer["temperature"] = temperature
|
||||
if in_debug_mode() or state.verbose > 1:
|
||||
commit_conversation_trace(messages, response_text, tracer)
|
||||
|
||||
return response_text
|
||||
|
|
|
@ -33,6 +33,7 @@ def extract_questions(
|
|||
query_images: Optional[list[str]] = None,
|
||||
vision_enabled: bool = False,
|
||||
personality_context: Optional[str] = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
"""
|
||||
Infer search queries to retrieve relevant notes to answer user query
|
||||
|
@ -82,7 +83,13 @@ def extract_questions(
|
|||
messages = [ChatMessage(content=prompt, role="user")]
|
||||
|
||||
response = send_message_to_model(
|
||||
messages, api_key, model, response_type="json_object", api_base_url=api_base_url, temperature=temperature
|
||||
messages,
|
||||
api_key,
|
||||
model,
|
||||
response_type="json_object",
|
||||
api_base_url=api_base_url,
|
||||
temperature=temperature,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
# Extract, Clean Message from GPT's Response
|
||||
|
@ -103,7 +110,9 @@ def extract_questions(
|
|||
return questions
|
||||
|
||||
|
||||
def send_message_to_model(messages, api_key, model, response_type="text", api_base_url=None, temperature=0):
|
||||
def send_message_to_model(
|
||||
messages, api_key, model, response_type="text", api_base_url=None, temperature=0, tracer: dict = {}
|
||||
):
|
||||
"""
|
||||
Send message to model
|
||||
"""
|
||||
|
@ -116,6 +125,7 @@ def send_message_to_model(messages, api_key, model, response_type="text", api_ba
|
|||
temperature=temperature,
|
||||
api_base_url=api_base_url,
|
||||
model_kwargs={"response_format": {"type": response_type}},
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
|
||||
|
@ -138,6 +148,7 @@ def converse(
|
|||
agent: Agent = None,
|
||||
query_images: Optional[list[str]] = None,
|
||||
vision_available: bool = False,
|
||||
tracer: dict = {},
|
||||
):
|
||||
"""
|
||||
Converse with user using OpenAI's ChatGPT
|
||||
|
@ -214,4 +225,5 @@ def converse(
|
|||
api_base_url=api_base_url,
|
||||
completion_func=completion_func,
|
||||
model_kwargs={"stop": ["Notes:\n["]},
|
||||
tracer=tracer,
|
||||
)
|
||||
|
|
|
@ -12,7 +12,12 @@ from tenacity import (
|
|||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from khoj.processor.conversation.utils import ThreadedGenerator
|
||||
from khoj.processor.conversation.utils import (
|
||||
ThreadedGenerator,
|
||||
commit_conversation_trace,
|
||||
)
|
||||
from khoj.utils import state
|
||||
from khoj.utils.helpers import in_debug_mode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -33,7 +38,7 @@ openai_clients: Dict[str, openai.OpenAI] = {}
|
|||
reraise=True,
|
||||
)
|
||||
def completion_with_backoff(
|
||||
messages, model, temperature=0, openai_api_key=None, api_base_url=None, model_kwargs=None
|
||||
messages, model, temperature=0, openai_api_key=None, api_base_url=None, model_kwargs=None, tracer: dict = {}
|
||||
) -> str:
|
||||
client_key = f"{openai_api_key}--{api_base_url}"
|
||||
client: openai.OpenAI | None = openai_clients.get(client_key)
|
||||
|
@ -77,6 +82,12 @@ def completion_with_backoff(
|
|||
elif delta_chunk.content:
|
||||
aggregated_response += delta_chunk.content
|
||||
|
||||
# Save conversation trace
|
||||
tracer["chat_model"] = model
|
||||
tracer["temperature"] = temperature
|
||||
if in_debug_mode() or state.verbose > 1:
|
||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||
|
||||
return aggregated_response
|
||||
|
||||
|
||||
|
@ -103,26 +114,37 @@ def chat_completion_with_backoff(
|
|||
api_base_url=None,
|
||||
completion_func=None,
|
||||
model_kwargs=None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
|
||||
t = Thread(
|
||||
target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key, api_base_url, model_kwargs)
|
||||
target=llm_thread,
|
||||
args=(g, messages, model_name, temperature, openai_api_key, api_base_url, model_kwargs, tracer),
|
||||
)
|
||||
t.start()
|
||||
return g
|
||||
|
||||
|
||||
def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_base_url=None, model_kwargs=None):
|
||||
def llm_thread(
|
||||
g,
|
||||
messages,
|
||||
model_name,
|
||||
temperature,
|
||||
openai_api_key=None,
|
||||
api_base_url=None,
|
||||
model_kwargs=None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
try:
|
||||
client_key = f"{openai_api_key}--{api_base_url}"
|
||||
if client_key not in openai_clients:
|
||||
client: openai.OpenAI = openai.OpenAI(
|
||||
client = openai.OpenAI(
|
||||
api_key=openai_api_key,
|
||||
base_url=api_base_url,
|
||||
)
|
||||
openai_clients[client_key] = client
|
||||
else:
|
||||
client: openai.OpenAI = openai_clients[client_key]
|
||||
client = openai_clients[client_key]
|
||||
|
||||
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
|
||||
stream = True
|
||||
|
@ -144,17 +166,29 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_ba
|
|||
**(model_kwargs or dict()),
|
||||
)
|
||||
|
||||
aggregated_response = ""
|
||||
if not stream:
|
||||
g.send(chat.choices[0].message.content)
|
||||
aggregated_response = chat.choices[0].message.content
|
||||
g.send(aggregated_response)
|
||||
else:
|
||||
for chunk in chat:
|
||||
if len(chunk.choices) == 0:
|
||||
continue
|
||||
delta_chunk = chunk.choices[0].delta
|
||||
text_chunk = ""
|
||||
if isinstance(delta_chunk, str):
|
||||
g.send(delta_chunk)
|
||||
text_chunk = delta_chunk
|
||||
elif delta_chunk.content:
|
||||
g.send(delta_chunk.content)
|
||||
text_chunk = delta_chunk.content
|
||||
if text_chunk:
|
||||
aggregated_response += text_chunk
|
||||
g.send(text_chunk)
|
||||
|
||||
# Save conversation trace
|
||||
tracer["chat_model"] = model_name
|
||||
tracer["temperature"] = temperature
|
||||
if in_debug_mode() or state.verbose > 1:
|
||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in llm_thread: {e}", exc_info=True)
|
||||
finally:
|
||||
|
|
|
@@ -2,6 +2,7 @@ import base64
import logging
import math
import mimetypes
import os
import queue
from dataclasses import dataclass
from datetime import datetime

@@ -13,6 +14,8 @@ from typing import Any, Dict, List, Optional
import PIL.Image
import requests
import tiktoken
import yaml
from git import Repo
from langchain.schema import ChatMessage
from llama_cpp.llama import Llama
from transformers import AutoTokenizer

@@ -24,7 +27,7 @@ from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.file_filter import FileFilter
from khoj.search_filter.word_filter import WordFilter
from khoj.utils import state
from khoj.utils.helpers import is_none_or_empty, merge_dicts
from khoj.utils.helpers import in_debug_mode, is_none_or_empty, merge_dicts

logger = logging.getLogger(__name__)
model_to_prompt_size = {

@@ -178,6 +181,7 @@ def save_to_conversation_log(
    conversation_id: str = None,
    automation_id: str = None,
    query_images: List[str] = None,
    tracer: Dict[str, Any] = {},
):
    user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    updated_conversation = message_to_log(

@@ -204,6 +208,9 @@ def save_to_conversation_log(
        user_message=q,
    )

    if in_debug_mode() or state.verbose > 1:
        merge_message_into_conversation_trace(q, chat_response, tracer)

    logger.info(
        f"""
Saved Conversation Turn

@@ -415,3 +422,160 @@ def get_image_from_url(image_url: str, type="pil"):
    except requests.exceptions.RequestException as e:
        logger.error(f"Failed to get image from URL {image_url}: {e}")
        return ImageWithType(content=None, type=None)


def commit_conversation_trace(
    session: list[ChatMessage],
    response: str | list[dict],
    tracer: dict,
    system_message: str | list[dict] = "",
    repo_path: str = "/tmp/khoj_promptrace",
) -> str:
    """
    Save trace of conversation step using git. Useful to visualize, compare and debug traces.
    Returns the path to the repository.
    """
    # Serialize session, system message and response to yaml
    system_message_yaml = yaml.dump(system_message, allow_unicode=True, sort_keys=False, default_flow_style=False)
    response_yaml = yaml.dump(response, allow_unicode=True, sort_keys=False, default_flow_style=False)
    formatted_session = [{"role": message.role, "content": message.content} for message in session]
    session_yaml = yaml.dump(formatted_session, allow_unicode=True, sort_keys=False, default_flow_style=False)
    query = (
        yaml.dump(session[-1].content, allow_unicode=True, sort_keys=False, default_flow_style=False)
        .strip()
        .removeprefix("'")
        .removesuffix("'")
    )  # Extract serialized query from chat session

    # Extract chat metadata for session
    uid, cid, mid = tracer.get("uid", "main"), tracer.get("cid", "main"), tracer.get("mid")

    # Infer repository path from environment variable or provided path
    repo_path = os.getenv("PROMPTRACE_DIR", repo_path) or "/tmp/promptrace"

    try:
        # Prepare git repository
        os.makedirs(repo_path, exist_ok=True)
        repo = Repo.init(repo_path)

        # Remove post-commit hook if it exists
        hooks_dir = os.path.join(repo_path, ".git", "hooks")
        post_commit_hook = os.path.join(hooks_dir, "post-commit")
        if os.path.exists(post_commit_hook):
            os.remove(post_commit_hook)

        # Configure git user if not set
        if not repo.config_reader().has_option("user", "email"):
            repo.config_writer().set_value("user", "name", "Prompt Tracer").release()
            repo.config_writer().set_value("user", "email", "promptracer@khoj.dev").release()

        # Create an initial commit if the repository is newly created
        if not repo.head.is_valid():
            repo.index.commit("And then there was a trace")

        # Check out the initial commit
        initial_commit = repo.commit("HEAD~0")
        repo.head.reference = initial_commit
        repo.head.reset(index=True, working_tree=True)

        # Create or switch to user branch from initial commit
        user_branch = f"u_{uid}"
        if user_branch not in repo.branches:
            repo.create_head(user_branch)
        repo.heads[user_branch].checkout()

        # Create or switch to conversation branch from user branch
        conv_branch = f"c_{cid}"
        if conv_branch not in repo.branches:
            repo.create_head(conv_branch)
        repo.heads[conv_branch].checkout()

        # Create or switch to message branch from conversation branch
        msg_branch = f"m_{mid}" if mid else None
        if msg_branch and msg_branch not in repo.branches:
            repo.create_head(msg_branch)
            repo.heads[msg_branch].checkout()

        # Include file with content to commit
        files_to_commit = {"query": session_yaml, "response": response_yaml, "system_prompt": system_message_yaml}

        # Write files and stage them
        for filename, content in files_to_commit.items():
            file_path = os.path.join(repo_path, filename)
            with open(file_path, "w", encoding="utf-8") as f:
                f.write(content)
            repo.index.add([filename])

        # Create commit
        metadata_yaml = yaml.dump(tracer, allow_unicode=True, sort_keys=False, default_flow_style=False)
        commit_message = f"""
{query[:250]}

Response:
---
{response[:500]}...

Metadata
---
{metadata_yaml}
""".strip()

        repo.index.commit(commit_message)

        logger.debug(f"Saved conversation trace to repo at {repo_path}")
        return repo_path
    except Exception as e:
        logger.error(f"Failed to add conversation trace to repo: {str(e)}")
        return None


def merge_message_into_conversation_trace(query: str, response: str, tracer: dict, repo_path=None) -> bool:
    """
    Merge the message branch into its parent conversation branch.

    Args:
        query: User query
        response: Assistant response
        tracer: Dictionary containing uid, cid and mid
        repo_path: Path to the git repository

    Returns:
        bool: True if merge was successful, False otherwise
    """
    try:
        # Infer repository path from environment variable or provided path
        repo_path = os.getenv("PROMPTRACE_DIR", repo_path) or "/tmp/promptrace"
        repo = Repo(repo_path)

        # Extract branch names
        msg_branch = f"m_{tracer['mid']}"
        conv_branch = f"c_{tracer['cid']}"

        # Checkout conversation branch
        repo.heads[conv_branch].checkout()

        # Create commit message
        metadata_yaml = yaml.dump(tracer, allow_unicode=True, sort_keys=False, default_flow_style=False)
        commit_message = f"""
{query[:250]}

Response:
---
{response[:500]}...

Metadata
---
{metadata_yaml}
""".strip()

        # Merge message branch into conversation branch
        repo.git.merge(msg_branch, no_ff=True, m=commit_message)

        # Delete message branch after merge
        repo.delete_head(msg_branch, force=True)

        logger.debug(f"Successfully merged {msg_branch} into {conv_branch}")
        return True
    except Exception as e:
        logger.error(f"Failed to merge message {msg_branch} into conversation {conv_branch}: {str(e)}")
        return False
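Traces therefore accumulate in a plain git repository (by default /tmp/khoj_promptrace, or wherever PROMPTRACE_DIR points), with one branch per user (u_<uid>), per conversation (c_<cid>) and per message (m_<mid>); the message branch is merged back into its conversation branch once the turn is saved. An illustrative way to browse the committed traces with gitpython, assuming tracing was enabled and at least one conversation was traced:

import os

from git import Repo

repo = Repo(os.getenv("PROMPTRACE_DIR", "/tmp/khoj_promptrace"))

# List the user / conversation / message branches created by the tracer
for branch in repo.branches:
    print(branch.name)

# Walk the most recent traced turns on one conversation branch
conv_branch = next(b for b in repo.branches if b.name.startswith("c_"))
for commit in repo.iter_commits(conv_branch.name, max_count=5):
    print(commit.hexsha[:8], commit.message.splitlines()[0])
    # Each trace commit carries the serialized prompt and response as files
    print(commit.tree["response"].data_stream.read().decode())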
@@ -28,6 +28,7 @@ async def text_to_image(
    send_status_func: Optional[Callable] = None,
    query_images: Optional[List[str]] = None,
    agent: Agent = None,
    tracer: dict = {},
):
    status_code = 200
    image = None

@@ -68,6 +69,7 @@ async def text_to_image(
        query_images=query_images,
        user=user,
        agent=agent,
        tracer=tracer,
    )

    if send_status_func:
@ -66,6 +66,7 @@ async def search_online(
|
|||
max_webpages_to_read: int = DEFAULT_MAX_WEBPAGES_TO_READ,
|
||||
query_images: List[str] = None,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
query += " ".join(custom_filters)
|
||||
if not is_internet_connected():
|
||||
|
@ -75,7 +76,7 @@ async def search_online(
|
|||
|
||||
# Breakdown the query into subqueries to get the correct answer
|
||||
subqueries = await generate_online_subqueries(
|
||||
query, conversation_history, location, user, query_images=query_images, agent=agent
|
||||
query, conversation_history, location, user, query_images=query_images, agent=agent, tracer=tracer
|
||||
)
|
||||
response_dict = {}
|
||||
|
||||
|
@ -113,7 +114,7 @@ async def search_online(
|
|||
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
|
||||
yield {ChatEvent.STATUS: event}
|
||||
tasks = [
|
||||
read_webpage_and_extract_content(data["queries"], link, data["content"], user=user, agent=agent)
|
||||
read_webpage_and_extract_content(data["queries"], link, data["content"], user=user, agent=agent, tracer=tracer)
|
||||
for link, data in webpages.items()
|
||||
]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
@ -155,6 +156,7 @@ async def read_webpages(
|
|||
send_status_func: Optional[Callable] = None,
|
||||
query_images: List[str] = None,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
"Infer web pages to read from the query and extract relevant information from them"
|
||||
logger.info(f"Inferring web pages to read")
|
||||
|
@ -168,7 +170,7 @@ async def read_webpages(
|
|||
webpage_links_str = "\n- " + "\n- ".join(list(urls))
|
||||
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
|
||||
yield {ChatEvent.STATUS: event}
|
||||
tasks = [read_webpage_and_extract_content({query}, url, user=user, agent=agent) for url in urls]
|
||||
tasks = [read_webpage_and_extract_content({query}, url, user=user, agent=agent, tracer=tracer) for url in urls]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
response: Dict[str, Dict] = defaultdict(dict)
|
||||
|
@ -194,7 +196,12 @@ async def read_webpage(
|
|||
|
||||
|
||||
async def read_webpage_and_extract_content(
|
||||
subqueries: set[str], url: str, content: str = None, user: KhojUser = None, agent: Agent = None
|
||||
subqueries: set[str],
|
||||
url: str,
|
||||
content: str = None,
|
||||
user: KhojUser = None,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
) -> Tuple[set[str], str, Union[None, str]]:
|
||||
# Select the web scrapers to use for reading the web page
|
||||
web_scrapers = await ConversationAdapters.aget_enabled_webscrapers()
|
||||
|
@ -216,7 +223,9 @@ async def read_webpage_and_extract_content(
|
|||
# Extract relevant information from the web page
|
||||
if is_none_or_empty(extracted_info):
|
||||
with timer(f"Extracting relevant information from web page at '{url}' took", logger):
|
||||
extracted_info = await extract_relevant_info(subqueries, content, user=user, agent=agent)
|
||||
extracted_info = await extract_relevant_info(
|
||||
subqueries, content, user=user, agent=agent, tracer=tracer
|
||||
)
|
||||
|
||||
# If we successfully extracted information, break the loop
|
||||
if not is_none_or_empty(extracted_info):
|
||||
|
|
|
@ -35,6 +35,7 @@ async def run_code(
|
|||
query_images: List[str] = None,
|
||||
agent: Agent = None,
|
||||
sandbox_url: str = SANDBOX_URL,
|
||||
tracer: dict = {},
|
||||
):
|
||||
# Generate Code
|
||||
if send_status_func:
|
||||
|
@ -43,7 +44,14 @@ async def run_code(
|
|||
try:
|
||||
with timer("Chat actor: Generate programs to execute", logger):
|
||||
codes = await generate_python_code(
|
||||
query, conversation_history, previous_iterations_history, location_data, user, query_images, agent
|
||||
query,
|
||||
conversation_history,
|
||||
previous_iterations_history,
|
||||
location_data,
|
||||
user,
|
||||
query_images,
|
||||
agent,
|
||||
tracer,
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to generate code for {query} with error: {e}")
|
||||
|
@ -72,6 +80,7 @@ async def generate_python_code(
|
|||
user: KhojUser,
|
||||
query_images: List[str] = None,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
) -> List[str]:
|
||||
location = f"{location_data}" if location_data else "Unknown"
|
||||
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
|
||||
|
@ -98,6 +107,7 @@ async def generate_python_code(
|
|||
query_images=query_images,
|
||||
response_type="json_object",
|
||||
user=user,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
# Validate that the response is a non-empty, JSON-serializable list
|
||||
|
|
|
@ -351,6 +351,7 @@ async def extract_references_and_questions(
|
|||
send_status_func: Optional[Callable] = None,
|
||||
query_images: Optional[List[str]] = None,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
user = request.user.object if request.user.is_authenticated else None
|
||||
|
||||
|
@ -424,6 +425,7 @@ async def extract_references_and_questions(
|
|||
user=user,
|
||||
max_prompt_size=conversation_config.max_prompt_size,
|
||||
personality_context=personality_context,
|
||||
tracer=tracer,
|
||||
)
|
||||
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
||||
openai_chat_config = conversation_config.openai_config
|
||||
|
@ -441,6 +443,7 @@ async def extract_references_and_questions(
|
|||
query_images=query_images,
|
||||
vision_enabled=vision_enabled,
|
||||
personality_context=personality_context,
|
||||
tracer=tracer,
|
||||
)
|
||||
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
|
||||
api_key = conversation_config.openai_config.api_key
|
||||
|
@ -455,6 +458,7 @@ async def extract_references_and_questions(
|
|||
user=user,
|
||||
vision_enabled=vision_enabled,
|
||||
personality_context=personality_context,
|
||||
tracer=tracer,
|
||||
)
|
||||
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
|
||||
api_key = conversation_config.openai_config.api_key
|
||||
|
@ -470,6 +474,7 @@ async def extract_references_and_questions(
|
|||
user=user,
|
||||
vision_enabled=vision_enabled,
|
||||
personality_context=personality_context,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
# Collate search results as context for GPT
|
||||
|
|
|
@@ -3,6 +3,7 @@ import base64
import json
import logging
import time
import uuid
from datetime import datetime
from functools import partial
from typing import Any, Dict, List, Optional

@@ -570,6 +571,12 @@ async def chat(
    event_delimiter = "␃🔚␗"
    q = unquote(q)
    nonlocal conversation_id
    tracer: dict = {
        "mid": f"{uuid.uuid4()}",
        "cid": conversation_id,
        "uid": user.id,
        "khoj_version": state.khoj_version,
    }

    uploaded_images: list[str] = []
    if images:
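Whether this per-message tracer actually produces commits is still gated inside the helpers: traces are only written when khoj runs in debug mode or with verbosity above 1, and the repo location falls back to /tmp/khoj_promptrace unless PROMPTRACE_DIR is set. A small, illustrative setup sketch; the KHOJ_DEBUG variable name is an assumption about what in_debug_mode() checks:

import os

# Assumption: in_debug_mode() keys off the KHOJ_DEBUG environment variable.
os.environ["KHOJ_DEBUG"] = "true"
# Optional override read by commit_conversation_trace / merge_message_into_conversation_trace.
os.environ["PROMPTRACE_DIR"] = "/tmp/khoj_promptrace"

# With the tracer above, a single chat turn roughly maps to:
#   branch u_<user.id>           - created on the first trace for this user
#   branch c_<conversation_id>   - created from the user branch
#   branch m_<uuid4 message id>  - one commit per LLM call during the turn, merged
#                                  back into c_<conversation_id> when the turn is
#                                  saved by save_to_conversation_log.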
|
@ -703,6 +710,7 @@ async def chat(
|
|||
user_name=user_name,
|
||||
location=location,
|
||||
file_filters=conversation.file_filters if conversation else [],
|
||||
tracer=tracer,
|
||||
):
|
||||
if isinstance(research_result, InformationCollectionIteration):
|
||||
if research_result.summarizedResult:
|
||||
|
@ -732,9 +740,12 @@ async def chat(
|
|||
user=user,
|
||||
query_images=uploaded_images,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, user, uploaded_images, agent)
|
||||
mode = await aget_relevant_output_modes(
|
||||
q, meta_log, is_automated_task, user, uploaded_images, agent, tracer=tracer
|
||||
)
|
||||
async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
|
||||
yield result
|
||||
if mode not in conversation_commands:
|
||||
|
@ -778,6 +789,7 @@ async def chat(
|
|||
query_images=uploaded_images,
|
||||
agent=agent,
|
||||
send_status_func=partial(send_event, ChatEvent.STATUS),
|
||||
tracer=tracer,
|
||||
):
|
||||
if isinstance(response, dict) and ChatEvent.STATUS in response:
|
||||
yield result[ChatEvent.STATUS]
|
||||
|
@ -796,6 +808,7 @@ async def chat(
|
|||
client_application=request.user.client_app,
|
||||
conversation_id=conversation_id,
|
||||
query_images=uploaded_images,
|
||||
tracer=tracer,
|
||||
)
|
||||
return
|
||||
|
||||
|
@ -817,7 +830,7 @@ async def chat(
|
|||
if ConversationCommand.Automation in conversation_commands:
|
||||
try:
|
||||
automation, crontime, query_to_run, subject = await create_automation(
|
||||
q, timezone, user, request.url, meta_log
|
||||
q, timezone, user, request.url, meta_log, tracer=tracer
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error scheduling task {q} for {user.email}: {e}")
|
||||
|
@ -839,6 +852,7 @@ async def chat(
|
|||
inferred_queries=[query_to_run],
|
||||
automation_id=automation.id,
|
||||
query_images=uploaded_images,
|
||||
tracer=tracer,
|
||||
)
|
||||
async for result in send_llm_response(llm_response):
|
||||
yield result
|
||||
|
@ -860,6 +874,7 @@ async def chat(
|
|||
partial(send_event, ChatEvent.STATUS),
|
||||
query_images=uploaded_images,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
yield result[ChatEvent.STATUS]
|
||||
|
@ -905,6 +920,7 @@ async def chat(
|
|||
custom_filters,
|
||||
query_images=uploaded_images,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
yield result[ChatEvent.STATUS]
|
||||
|
@ -930,6 +946,7 @@ async def chat(
|
|||
partial(send_event, ChatEvent.STATUS),
|
||||
query_images=uploaded_images,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
yield result[ChatEvent.STATUS]
|
||||
|
@ -984,6 +1001,7 @@ async def chat(
|
|||
partial(send_event, ChatEvent.STATUS),
|
||||
query_images=uploaded_images,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
yield result[ChatEvent.STATUS]
|
||||
|
@ -1010,6 +1028,7 @@ async def chat(
|
|||
send_status_func=partial(send_event, ChatEvent.STATUS),
|
||||
query_images=uploaded_images,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
yield result[ChatEvent.STATUS]
|
||||
|
@ -1040,6 +1059,7 @@ async def chat(
|
|||
compiled_references=compiled_references,
|
||||
online_results=online_results,
|
||||
query_images=uploaded_images,
|
||||
tracer=tracer,
|
||||
)
|
||||
content_obj = {
|
||||
"intentType": intent_type,
|
||||
|
@ -1068,6 +1088,7 @@ async def chat(
|
|||
user=user,
|
||||
agent=agent,
|
||||
send_status_func=partial(send_event, ChatEvent.STATUS),
|
||||
tracer=tracer,
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
yield result[ChatEvent.STATUS]
|
||||
|
@ -1095,6 +1116,7 @@ async def chat(
|
|||
compiled_references=compiled_references,
|
||||
online_results=online_results,
|
||||
query_images=uploaded_images,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
async for result in send_llm_response(json.dumps(content_obj)):
|
||||
|
@ -1120,6 +1142,7 @@ async def chat(
|
|||
user_name,
|
||||
researched_results,
|
||||
uploaded_images,
|
||||
tracer,
|
||||
)
|
||||
|
||||
# Send Response
|
||||
|
|
|
@ -291,6 +291,7 @@ async def aget_relevant_information_sources(
|
|||
user: KhojUser,
|
||||
query_images: List[str] = None,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
"""
|
||||
Given a query, determine which of the available tools the agent should use in order to answer appropriately.
|
||||
|
@ -327,6 +328,7 @@ async def aget_relevant_information_sources(
|
|||
relevant_tools_prompt,
|
||||
response_type="json_object",
|
||||
user=user,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
try:
|
||||
|
@ -368,6 +370,7 @@ async def aget_relevant_output_modes(
|
|||
user: KhojUser = None,
|
||||
query_images: List[str] = None,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
"""
|
||||
Given a query, determine which of the available tools the agent should use in order to answer appropriately.
|
||||
|
@ -403,7 +406,9 @@ async def aget_relevant_output_modes(
|
|||
)
|
||||
|
||||
with timer("Chat actor: Infer output mode for chat response", logger):
|
||||
response = await send_message_to_model_wrapper(relevant_mode_prompt, response_type="json_object", user=user)
|
||||
response = await send_message_to_model_wrapper(
|
||||
relevant_mode_prompt, response_type="json_object", user=user, tracer=tracer
|
||||
)
|
||||
|
||||
try:
|
||||
response = response.strip()
|
||||
|
@ -434,6 +439,7 @@ async def infer_webpage_urls(
|
|||
user: KhojUser,
|
||||
query_images: List[str] = None,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
) -> List[str]:
|
||||
"""
|
||||
Infer webpage links from the given query
|
||||
|
@ -458,7 +464,11 @@ async def infer_webpage_urls(
|
|||
|
||||
with timer("Chat actor: Infer webpage urls to read", logger):
|
||||
response = await send_message_to_model_wrapper(
|
||||
online_queries_prompt, query_images=query_images, response_type="json_object", user=user
|
||||
online_queries_prompt,
|
||||
query_images=query_images,
|
||||
response_type="json_object",
|
||||
user=user,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
# Validate that the response is a non-empty, JSON-serializable list of URLs
|
||||
|
@ -481,6 +491,7 @@ async def generate_online_subqueries(
|
|||
user: KhojUser,
|
||||
query_images: List[str] = None,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
) -> List[str]:
|
||||
"""
|
||||
Generate subqueries from the given query
|
||||
|
@ -505,7 +516,11 @@ async def generate_online_subqueries(
|
|||
|
||||
with timer("Chat actor: Generate online search subqueries", logger):
|
||||
response = await send_message_to_model_wrapper(
|
||||
online_queries_prompt, query_images=query_images, response_type="json_object", user=user
|
||||
online_queries_prompt,
|
||||
query_images=query_images,
|
||||
response_type="json_object",
|
||||
user=user,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
# Validate that the response is a non-empty, JSON-serializable list
|
||||
|
@ -524,7 +539,7 @@ async def generate_online_subqueries(
|
|||
|
||||
|
||||
async def schedule_query(
|
||||
q: str, conversation_history: dict, user: KhojUser, query_images: List[str] = None
|
||||
q: str, conversation_history: dict, user: KhojUser, query_images: List[str] = None, tracer: dict = {}
|
||||
) -> Tuple[str, ...]:
|
||||
"""
|
||||
Schedule the date, time to run the query. Assume the server timezone is UTC.
|
||||
|
@ -537,7 +552,7 @@ async def schedule_query(
|
|||
)
|
||||
|
||||
raw_response = await send_message_to_model_wrapper(
|
||||
crontime_prompt, query_images=query_images, response_type="json_object", user=user
|
||||
crontime_prompt, query_images=query_images, response_type="json_object", user=user, tracer=tracer
|
||||
)
|
||||
|
||||
# Validate that the response is a non-empty, JSON-serializable list
|
||||
|
@ -552,7 +567,7 @@ async def schedule_query(
|
|||
|
||||
|
||||
async def extract_relevant_info(
|
||||
qs: set[str], corpus: str, user: KhojUser = None, agent: Agent = None
|
||||
qs: set[str], corpus: str, user: KhojUser = None, agent: Agent = None, tracer: dict = {}
|
||||
) -> Union[str, None]:
|
||||
"""
|
||||
Extract relevant information for a given query from the target corpus
|
||||
|
@ -575,6 +590,7 @@ async def extract_relevant_info(
|
|||
extract_relevant_information,
|
||||
prompts.system_prompt_extract_relevant_information,
|
||||
user=user,
|
||||
tracer=tracer,
|
||||
)
|
||||
return response.strip()
|
||||
|
||||
|
@ -586,6 +602,7 @@ async def extract_relevant_summary(
|
|||
query_images: List[str] = None,
|
||||
user: KhojUser = None,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
) -> Union[str, None]:
|
||||
"""
|
||||
Extract relevant information for a given query from the target corpus
|
||||
|
@ -613,6 +630,7 @@ async def extract_relevant_summary(
|
|||
prompts.system_prompt_extract_relevant_summary,
|
||||
user=user,
|
||||
query_images=query_images,
|
||||
tracer=tracer,
|
||||
)
|
||||
return response.strip()
|
||||
|
||||
|
@ -625,6 +643,7 @@ async def generate_summary_from_files(
|
|||
query_images: List[str] = None,
|
||||
agent: Agent = None,
|
||||
send_status_func: Optional[Callable] = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
try:
|
||||
file_object = None
|
||||
|
@ -653,6 +672,7 @@ async def generate_summary_from_files(
|
|||
query_images=query_images,
|
||||
user=user,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
)
|
||||
response_log = str(response)
|
||||
|
||||
|
@ -673,6 +693,7 @@ async def generate_excalidraw_diagram(
|
|||
user: KhojUser = None,
|
||||
agent: Agent = None,
|
||||
send_status_func: Optional[Callable] = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
if send_status_func:
|
||||
async for event in send_status_func("**Enhancing the Diagramming Prompt**"):
|
||||
|
@ -687,6 +708,7 @@ async def generate_excalidraw_diagram(
|
|||
query_images=query_images,
|
||||
user=user,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
if send_status_func:
|
||||
|
@ -697,6 +719,7 @@ async def generate_excalidraw_diagram(
|
|||
q=better_diagram_description_prompt,
|
||||
user=user,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
yield better_diagram_description_prompt, excalidraw_diagram_description
|
||||
|
@ -711,6 +734,7 @@ async def generate_better_diagram_description(
|
|||
query_images: List[str] = None,
|
||||
user: KhojUser = None,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
) -> str:
|
||||
"""
|
||||
Generate a diagram description from the given query and context
|
||||
|
@ -748,7 +772,7 @@ async def generate_better_diagram_description(
|
|||
|
||||
with timer("Chat actor: Generate better diagram description", logger):
|
||||
response = await send_message_to_model_wrapper(
|
||||
improve_diagram_description_prompt, query_images=query_images, user=user
|
||||
improve_diagram_description_prompt, query_images=query_images, user=user, tracer=tracer
|
||||
)
|
||||
response = response.strip()
|
||||
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
|
||||
|
@ -761,6 +785,7 @@ async def generate_excalidraw_diagram_from_description(
|
|||
q: str,
|
||||
user: KhojUser = None,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
) -> str:
|
||||
personality_context = (
|
||||
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
|
||||
|
@ -772,7 +797,9 @@ async def generate_excalidraw_diagram_from_description(
|
|||
)
|
||||
|
||||
with timer("Chat actor: Generate excalidraw diagram", logger):
|
||||
raw_response = await send_message_to_model_wrapper(message=excalidraw_diagram_generation, user=user)
|
||||
raw_response = await send_message_to_model_wrapper(
|
||||
message=excalidraw_diagram_generation, user=user, tracer=tracer
|
||||
)
|
||||
raw_response = raw_response.strip()
|
||||
raw_response = remove_json_codeblock(raw_response)
|
||||
response: Dict[str, str] = json.loads(raw_response)
|
||||
|
@ -793,6 +820,7 @@ async def generate_better_image_prompt(
|
|||
query_images: Optional[List[str]] = None,
|
||||
user: KhojUser = None,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
) -> str:
|
||||
"""
|
||||
Generate a better image prompt from the given query
|
||||
|
@ -839,7 +867,9 @@ async def generate_better_image_prompt(
|
|||
)
|
||||
|
||||
with timer("Chat actor: Generate contextual image prompt", logger):
|
||||
response = await send_message_to_model_wrapper(image_prompt, query_images=query_images, user=user)
|
||||
response = await send_message_to_model_wrapper(
|
||||
image_prompt, query_images=query_images, user=user, tracer=tracer
|
||||
)
|
||||
response = response.strip()
|
||||
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
|
||||
response = response[1:-1]
|
||||
|
@ -853,6 +883,7 @@ async def send_message_to_model_wrapper(
|
|||
response_type: str = "text",
|
||||
user: KhojUser = None,
|
||||
query_images: List[str] = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user)
|
||||
vision_available = conversation_config.vision_enabled
|
||||
|
@ -899,6 +930,7 @@ async def send_message_to_model_wrapper(
|
|||
max_prompt_size=max_tokens,
|
||||
streaming=False,
|
||||
response_type=response_type,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
elif model_type == ChatModelOptions.ModelType.OPENAI:
|
||||
|
@ -922,6 +954,7 @@ async def send_message_to_model_wrapper(
|
|||
model=chat_model,
|
||||
response_type=response_type,
|
||||
api_base_url=api_base_url,
|
||||
tracer=tracer,
|
||||
)
|
||||
elif model_type == ChatModelOptions.ModelType.ANTHROPIC:
|
||||
api_key = conversation_config.openai_config.api_key
|
||||
|
@ -940,6 +973,7 @@ async def send_message_to_model_wrapper(
|
|||
messages=truncated_messages,
|
||||
api_key=api_key,
|
||||
model=chat_model,
|
||||
tracer=tracer,
|
||||
)
|
||||
elif model_type == ChatModelOptions.ModelType.GOOGLE:
|
||||
api_key = conversation_config.openai_config.api_key
|
||||
|
@ -955,7 +989,7 @@ async def send_message_to_model_wrapper(
|
|||
)
|
||||
|
||||
return gemini_send_message_to_model(
|
||||
messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type
|
||||
messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type, tracer=tracer
|
||||
)
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="Invalid conversation config")
|
||||
|
@ -966,6 +1000,7 @@ def send_message_to_model_wrapper_sync(
|
|||
system_message: str = "",
|
||||
response_type: str = "text",
|
||||
user: KhojUser = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
conversation_config: ChatModelOptions = ConversationAdapters.get_default_conversation_config(user)
|
||||
|
||||
|
@ -998,6 +1033,7 @@ def send_message_to_model_wrapper_sync(
|
|||
max_prompt_size=max_tokens,
|
||||
streaming=False,
|
||||
response_type=response_type,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
||||
|
@ -1012,7 +1048,11 @@ def send_message_to_model_wrapper_sync(
|
|||
)
|
||||
|
||||
openai_response = send_message_to_model(
|
||||
messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type
|
||||
messages=truncated_messages,
|
||||
api_key=api_key,
|
||||
model=chat_model,
|
||||
response_type=response_type,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
return openai_response
|
||||
|
@ -1032,6 +1072,7 @@ def send_message_to_model_wrapper_sync(
|
|||
messages=truncated_messages,
|
||||
api_key=api_key,
|
||||
model=chat_model,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
|
||||
|
@ -1050,6 +1091,7 @@ def send_message_to_model_wrapper_sync(
|
|||
api_key=api_key,
|
||||
model=chat_model,
|
||||
response_type=response_type,
|
||||
tracer=tracer,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="Invalid conversation config")
|
||||
|
@ -1071,6 +1113,7 @@ def generate_chat_response(
|
|||
user_name: Optional[str] = None,
|
||||
meta_research: str = "",
|
||||
query_images: Optional[List[str]] = None,
|
||||
tracer: dict = {},
|
||||
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
|
||||
# Initialize Variables
|
||||
chat_response = None
|
||||
|
@ -1094,6 +1137,7 @@ def generate_chat_response(
|
|||
client_application=client_application,
|
||||
conversation_id=conversation_id,
|
||||
query_images=query_images,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
|
||||
|
@ -1120,6 +1164,7 @@ def generate_chat_response(
|
|||
location_data=location_data,
|
||||
user_name=user_name,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
||||
|
@ -1144,6 +1189,7 @@ def generate_chat_response(
|
|||
user_name=user_name,
|
||||
agent=agent,
|
||||
vision_available=vision_available,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
|
||||
|
@ -1165,6 +1211,7 @@ def generate_chat_response(
|
|||
user_name=user_name,
|
||||
agent=agent,
|
||||
vision_available=vision_available,
|
||||
tracer=tracer,
|
||||
)
|
||||
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
|
||||
api_key = conversation_config.openai_config.api_key
|
||||
|
@ -1184,6 +1231,7 @@ def generate_chat_response(
|
|||
user_name=user_name,
|
||||
agent=agent,
|
||||
vision_available=vision_available,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
metadata.update({"chat_model": conversation_config.chat_model})
|
||||
|
@ -1540,9 +1588,15 @@ def scheduled_chat(
|
|||
|
||||
|
||||
async def create_automation(
|
||||
q: str, timezone: str, user: KhojUser, calling_url: URL, meta_log: dict = {}, conversation_id: str = None
|
||||
q: str,
|
||||
timezone: str,
|
||||
user: KhojUser,
|
||||
calling_url: URL,
|
||||
meta_log: dict = {},
|
||||
conversation_id: str = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
crontime, query_to_run, subject = await schedule_query(q, meta_log, user)
|
||||
crontime, query_to_run, subject = await schedule_query(q, meta_log, user, tracer=tracer)
|
||||
job = await schedule_automation(query_to_run, subject, crontime, timezone, q, user, calling_url, conversation_id)
|
||||
return job, crontime, query_to_run, subject
|
||||
|
||||
|
|
|
@@ -45,6 +45,7 @@ async def apick_next_tool(
    previous_iterations_history: str = None,
    max_iterations: int = 5,
    send_status_func: Optional[Callable] = None,
    tracer: dict = {},
):
    """
    Given a query, determine which of the available tools the agent should use in order to answer appropriately. One at a time, and it's able to use subsequent iterations to refine the answer.

@@ -93,6 +94,7 @@ async def apick_next_tool(
        response_type="json_object",
        user=user,
        query_images=query_images,
        tracer=tracer,
    )

    try:

@@ -135,6 +137,7 @@ async def execute_information_collection(
    user_name: str = None,
    location: LocationData = None,
    file_filters: List[str] = [],
    tracer: dict = {},
):
    current_iteration = 0
    MAX_ITERATIONS = 5

@@ -159,6 +162,7 @@ async def execute_information_collection(
        previous_iterations_history,
        MAX_ITERATIONS,
        send_status_func,
        tracer=tracer,
    ):
        if isinstance(result, dict) and ChatEvent.STATUS in result:
            yield result[ChatEvent.STATUS]

@@ -180,6 +184,7 @@ async def execute_information_collection(
        send_status_func,
        query_images,
        agent=agent,
        tracer=tracer,
    ):
        if isinstance(result, dict) and ChatEvent.STATUS in result:
            yield result[ChatEvent.STATUS]

@@ -211,6 +216,7 @@ async def execute_information_collection(
        max_webpages_to_read=0,
        query_images=query_images,
        agent=agent,
        tracer=tracer,
    ):
        if isinstance(result, dict) and ChatEvent.STATUS in result:
            yield result[ChatEvent.STATUS]

@@ -228,6 +234,7 @@ async def execute_information_collection(
        send_status_func,
        query_images=query_images,
        agent=agent,
        tracer=tracer,
    ):
        if isinstance(result, dict) and ChatEvent.STATUS in result:
            yield result[ChatEvent.STATUS]

@@ -258,6 +265,7 @@ async def execute_information_collection(
        send_status_func,
        query_images=query_images,
        agent=agent,
        tracer=tracer,
    ):
        if isinstance(result, dict) and ChatEvent.STATUS in result:
            yield result[ChatEvent.STATUS]