mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 23:48:56 +01:00
Merge branch 'add-prompt-tracer-for-observability' of github.com:khoj-ai/khoj into features/advanced-reasoning
- Start from this branches src/khoj/routers/api_chat.py Add tracer to all old and new chat actors that don't have it set when they are called. - Update the new chat actors like apick next tool etc to use tracer too
This commit is contained in:
commit
f04f871a72
16 changed files with 458 additions and 57 deletions
|
@ -119,6 +119,7 @@ dev = [
|
||||||
"mypy >= 1.0.1",
|
"mypy >= 1.0.1",
|
||||||
"black >= 23.1.0",
|
"black >= 23.1.0",
|
||||||
"pre-commit >= 3.0.4",
|
"pre-commit >= 3.0.4",
|
||||||
|
"gitpython ~= 3.1.43",
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.hatch.version]
|
[tool.hatch.version]
|
||||||
|
|
|
@ -34,6 +34,7 @@ def extract_questions_anthropic(
|
||||||
query_images: Optional[list[str]] = None,
|
query_images: Optional[list[str]] = None,
|
||||||
vision_enabled: bool = False,
|
vision_enabled: bool = False,
|
||||||
personality_context: Optional[str] = None,
|
personality_context: Optional[str] = None,
|
||||||
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Infer search queries to retrieve relevant notes to answer user query
|
Infer search queries to retrieve relevant notes to answer user query
|
||||||
|
@ -89,6 +90,7 @@ def extract_questions_anthropic(
|
||||||
model_name=model,
|
model_name=model,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Extract, Clean Message from Claude's Response
|
# Extract, Clean Message from Claude's Response
|
||||||
|
@ -110,7 +112,7 @@ def extract_questions_anthropic(
|
||||||
return questions
|
return questions
|
||||||
|
|
||||||
|
|
||||||
def anthropic_send_message_to_model(messages, api_key, model):
|
def anthropic_send_message_to_model(messages, api_key, model, tracer={}):
|
||||||
"""
|
"""
|
||||||
Send message to model
|
Send message to model
|
||||||
"""
|
"""
|
||||||
|
@ -122,6 +124,7 @@ def anthropic_send_message_to_model(messages, api_key, model):
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
model_name=model,
|
model_name=model,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -142,6 +145,7 @@ def converse_anthropic(
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
query_images: Optional[list[str]] = None,
|
query_images: Optional[list[str]] = None,
|
||||||
vision_available: bool = False,
|
vision_available: bool = False,
|
||||||
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Converse with user using Anthropic's Claude
|
Converse with user using Anthropic's Claude
|
||||||
|
@ -220,4 +224,5 @@ def converse_anthropic(
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
completion_func=completion_func,
|
completion_func=completion_func,
|
||||||
max_prompt_size=max_prompt_size,
|
max_prompt_size=max_prompt_size,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
|
@ -12,8 +12,13 @@ from tenacity import (
|
||||||
wait_random_exponential,
|
wait_random_exponential,
|
||||||
)
|
)
|
||||||
|
|
||||||
from khoj.processor.conversation.utils import ThreadedGenerator, get_image_from_url
|
from khoj.processor.conversation.utils import (
|
||||||
from khoj.utils.helpers import is_none_or_empty
|
ThreadedGenerator,
|
||||||
|
commit_conversation_trace,
|
||||||
|
get_image_from_url,
|
||||||
|
)
|
||||||
|
from khoj.utils import state
|
||||||
|
from khoj.utils.helpers import in_debug_mode, is_none_or_empty
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -30,7 +35,7 @@ DEFAULT_MAX_TOKENS_ANTHROPIC = 3000
|
||||||
reraise=True,
|
reraise=True,
|
||||||
)
|
)
|
||||||
def anthropic_completion_with_backoff(
|
def anthropic_completion_with_backoff(
|
||||||
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, max_tokens=None
|
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, max_tokens=None, tracer={}
|
||||||
) -> str:
|
) -> str:
|
||||||
if api_key not in anthropic_clients:
|
if api_key not in anthropic_clients:
|
||||||
client: anthropic.Anthropic = anthropic.Anthropic(api_key=api_key)
|
client: anthropic.Anthropic = anthropic.Anthropic(api_key=api_key)
|
||||||
|
@ -58,6 +63,12 @@ def anthropic_completion_with_backoff(
|
||||||
for text in stream.text_stream:
|
for text in stream.text_stream:
|
||||||
aggregated_response += text
|
aggregated_response += text
|
||||||
|
|
||||||
|
# Save conversation trace
|
||||||
|
tracer["chat_model"] = model_name
|
||||||
|
tracer["temperature"] = temperature
|
||||||
|
if in_debug_mode() or state.verbose > 1:
|
||||||
|
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||||
|
|
||||||
return aggregated_response
|
return aggregated_response
|
||||||
|
|
||||||
|
|
||||||
|
@ -78,18 +89,19 @@ def anthropic_chat_completion_with_backoff(
|
||||||
max_prompt_size=None,
|
max_prompt_size=None,
|
||||||
completion_func=None,
|
completion_func=None,
|
||||||
model_kwargs=None,
|
model_kwargs=None,
|
||||||
|
tracer={},
|
||||||
):
|
):
|
||||||
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
|
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
|
||||||
t = Thread(
|
t = Thread(
|
||||||
target=anthropic_llm_thread,
|
target=anthropic_llm_thread,
|
||||||
args=(g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size, model_kwargs),
|
args=(g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size, model_kwargs, tracer),
|
||||||
)
|
)
|
||||||
t.start()
|
t.start()
|
||||||
return g
|
return g
|
||||||
|
|
||||||
|
|
||||||
def anthropic_llm_thread(
|
def anthropic_llm_thread(
|
||||||
g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size=None, model_kwargs=None
|
g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size=None, model_kwargs=None, tracer={}
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
if api_key not in anthropic_clients:
|
if api_key not in anthropic_clients:
|
||||||
|
@ -102,6 +114,7 @@ def anthropic_llm_thread(
|
||||||
anthropic.types.MessageParam(role=message.role, content=message.content) for message in messages
|
anthropic.types.MessageParam(role=message.role, content=message.content) for message in messages
|
||||||
]
|
]
|
||||||
|
|
||||||
|
aggregated_response = ""
|
||||||
with client.messages.stream(
|
with client.messages.stream(
|
||||||
messages=formatted_messages,
|
messages=formatted_messages,
|
||||||
model=model_name, # type: ignore
|
model=model_name, # type: ignore
|
||||||
|
@ -112,7 +125,14 @@ def anthropic_llm_thread(
|
||||||
**(model_kwargs or dict()),
|
**(model_kwargs or dict()),
|
||||||
) as stream:
|
) as stream:
|
||||||
for text in stream.text_stream:
|
for text in stream.text_stream:
|
||||||
|
aggregated_response += text
|
||||||
g.send(text)
|
g.send(text)
|
||||||
|
|
||||||
|
# Save conversation trace
|
||||||
|
tracer["chat_model"] = model_name
|
||||||
|
tracer["temperature"] = temperature
|
||||||
|
if in_debug_mode() or state.verbose > 1:
|
||||||
|
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in anthropic_llm_thread: {e}", exc_info=True)
|
logger.error(f"Error in anthropic_llm_thread: {e}", exc_info=True)
|
||||||
finally:
|
finally:
|
||||||
|
|
|
@ -35,6 +35,7 @@ def extract_questions_gemini(
|
||||||
query_images: Optional[list[str]] = None,
|
query_images: Optional[list[str]] = None,
|
||||||
vision_enabled: bool = False,
|
vision_enabled: bool = False,
|
||||||
personality_context: Optional[str] = None,
|
personality_context: Optional[str] = None,
|
||||||
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Infer search queries to retrieve relevant notes to answer user query
|
Infer search queries to retrieve relevant notes to answer user query
|
||||||
|
@ -85,7 +86,7 @@ def extract_questions_gemini(
|
||||||
messages = [ChatMessage(content=prompt, role="user"), ChatMessage(content=system_prompt, role="system")]
|
messages = [ChatMessage(content=prompt, role="user"), ChatMessage(content=system_prompt, role="system")]
|
||||||
|
|
||||||
response = gemini_send_message_to_model(
|
response = gemini_send_message_to_model(
|
||||||
messages, api_key, model, response_type="json_object", temperature=temperature
|
messages, api_key, model, response_type="json_object", temperature=temperature, tracer=tracer
|
||||||
)
|
)
|
||||||
|
|
||||||
# Extract, Clean Message from Gemini's Response
|
# Extract, Clean Message from Gemini's Response
|
||||||
|
@ -107,7 +108,9 @@ def extract_questions_gemini(
|
||||||
return questions
|
return questions
|
||||||
|
|
||||||
|
|
||||||
def gemini_send_message_to_model(messages, api_key, model, response_type="text", temperature=0, model_kwargs=None):
|
def gemini_send_message_to_model(
|
||||||
|
messages, api_key, model, response_type="text", temperature=0, model_kwargs=None, tracer={}
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Send message to model
|
Send message to model
|
||||||
"""
|
"""
|
||||||
|
@ -125,6 +128,7 @@ def gemini_send_message_to_model(messages, api_key, model, response_type="text",
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
model_kwargs=model_kwargs,
|
model_kwargs=model_kwargs,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -146,6 +150,7 @@ def converse_gemini(
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
query_images: Optional[list[str]] = None,
|
query_images: Optional[list[str]] = None,
|
||||||
vision_available: bool = False,
|
vision_available: bool = False,
|
||||||
|
tracer={},
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Converse with user using Google's Gemini
|
Converse with user using Google's Gemini
|
||||||
|
@ -224,4 +229,5 @@ def converse_gemini(
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
completion_func=completion_func,
|
completion_func=completion_func,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
|
@ -19,8 +19,13 @@ from tenacity import (
|
||||||
wait_random_exponential,
|
wait_random_exponential,
|
||||||
)
|
)
|
||||||
|
|
||||||
from khoj.processor.conversation.utils import ThreadedGenerator, get_image_from_url
|
from khoj.processor.conversation.utils import (
|
||||||
from khoj.utils.helpers import is_none_or_empty
|
ThreadedGenerator,
|
||||||
|
commit_conversation_trace,
|
||||||
|
get_image_from_url,
|
||||||
|
)
|
||||||
|
from khoj.utils import state
|
||||||
|
from khoj.utils.helpers import in_debug_mode, is_none_or_empty
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -35,7 +40,7 @@ MAX_OUTPUT_TOKENS_GEMINI = 8192
|
||||||
reraise=True,
|
reraise=True,
|
||||||
)
|
)
|
||||||
def gemini_completion_with_backoff(
|
def gemini_completion_with_backoff(
|
||||||
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None
|
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, tracer={}
|
||||||
) -> str:
|
) -> str:
|
||||||
genai.configure(api_key=api_key)
|
genai.configure(api_key=api_key)
|
||||||
model_kwargs = model_kwargs or dict()
|
model_kwargs = model_kwargs or dict()
|
||||||
|
@ -60,16 +65,23 @@ def gemini_completion_with_backoff(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Generate the response. The last message is considered to be the current prompt
|
# Generate the response. The last message is considered to be the current prompt
|
||||||
aggregated_response = chat_session.send_message(formatted_messages[-1]["parts"])
|
response = chat_session.send_message(formatted_messages[-1]["parts"])
|
||||||
return aggregated_response.text
|
response_text = response.text
|
||||||
except StopCandidateException as e:
|
except StopCandidateException as e:
|
||||||
response_message, _ = handle_gemini_response(e.args)
|
response_text, _ = handle_gemini_response(e.args)
|
||||||
# Respond with reason for stopping
|
# Respond with reason for stopping
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"LLM Response Prevented for {model_name}: {response_message}.\n"
|
f"LLM Response Prevented for {model_name}: {response_text}.\n"
|
||||||
+ f"Last Message by {messages[-1].role}: {messages[-1].content}"
|
+ f"Last Message by {messages[-1].role}: {messages[-1].content}"
|
||||||
)
|
)
|
||||||
return response_message
|
|
||||||
|
# Save conversation trace
|
||||||
|
tracer["chat_model"] = model_name
|
||||||
|
tracer["temperature"] = temperature
|
||||||
|
if in_debug_mode() or state.verbose > 1:
|
||||||
|
commit_conversation_trace(messages, response_text, tracer)
|
||||||
|
|
||||||
|
return response_text
|
||||||
|
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
|
@ -88,17 +100,20 @@ def gemini_chat_completion_with_backoff(
|
||||||
system_prompt,
|
system_prompt,
|
||||||
completion_func=None,
|
completion_func=None,
|
||||||
model_kwargs=None,
|
model_kwargs=None,
|
||||||
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
|
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
|
||||||
t = Thread(
|
t = Thread(
|
||||||
target=gemini_llm_thread,
|
target=gemini_llm_thread,
|
||||||
args=(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs),
|
args=(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs, tracer),
|
||||||
)
|
)
|
||||||
t.start()
|
t.start()
|
||||||
return g
|
return g
|
||||||
|
|
||||||
|
|
||||||
def gemini_llm_thread(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs=None):
|
def gemini_llm_thread(
|
||||||
|
g, messages, system_prompt, model_name, temperature, api_key, model_kwargs=None, tracer: dict = {}
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
genai.configure(api_key=api_key)
|
genai.configure(api_key=api_key)
|
||||||
model_kwargs = model_kwargs or dict()
|
model_kwargs = model_kwargs or dict()
|
||||||
|
@ -117,16 +132,25 @@ def gemini_llm_thread(g, messages, system_prompt, model_name, temperature, api_k
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
aggregated_response = ""
|
||||||
formatted_messages = [{"role": message.role, "parts": message.content} for message in messages]
|
formatted_messages = [{"role": message.role, "parts": message.content} for message in messages]
|
||||||
|
|
||||||
# all messages up to the last are considered to be part of the chat history
|
# all messages up to the last are considered to be part of the chat history
|
||||||
chat_session = model.start_chat(history=formatted_messages[0:-1])
|
chat_session = model.start_chat(history=formatted_messages[0:-1])
|
||||||
# the last message is considered to be the current prompt
|
# the last message is considered to be the current prompt
|
||||||
for chunk in chat_session.send_message(formatted_messages[-1]["parts"], stream=True):
|
for chunk in chat_session.send_message(formatted_messages[-1]["parts"], stream=True):
|
||||||
message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
|
message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
|
||||||
message = message or chunk.text
|
message = message or chunk.text
|
||||||
|
aggregated_response += message
|
||||||
g.send(message)
|
g.send(message)
|
||||||
if stopped:
|
if stopped:
|
||||||
raise StopCandidateException(message)
|
raise StopCandidateException(message)
|
||||||
|
|
||||||
|
# Save conversation trace
|
||||||
|
tracer["chat_model"] = model_name
|
||||||
|
tracer["temperature"] = temperature
|
||||||
|
if in_debug_mode() or state.verbose > 1:
|
||||||
|
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||||
except StopCandidateException as e:
|
except StopCandidateException as e:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"LLM Response Prevented for {model_name}: {e.args[0]}.\n"
|
f"LLM Response Prevented for {model_name}: {e.args[0]}.\n"
|
||||||
|
|
|
@ -12,11 +12,12 @@ from khoj.processor.conversation import prompts
|
||||||
from khoj.processor.conversation.offline.utils import download_model
|
from khoj.processor.conversation.offline.utils import download_model
|
||||||
from khoj.processor.conversation.utils import (
|
from khoj.processor.conversation.utils import (
|
||||||
ThreadedGenerator,
|
ThreadedGenerator,
|
||||||
|
commit_conversation_trace,
|
||||||
generate_chatml_messages_with_context,
|
generate_chatml_messages_with_context,
|
||||||
)
|
)
|
||||||
from khoj.utils import state
|
from khoj.utils import state
|
||||||
from khoj.utils.constants import empty_escape_sequences
|
from khoj.utils.constants import empty_escape_sequences
|
||||||
from khoj.utils.helpers import ConversationCommand, is_none_or_empty
|
from khoj.utils.helpers import ConversationCommand, in_debug_mode, is_none_or_empty
|
||||||
from khoj.utils.rawconfig import LocationData
|
from khoj.utils.rawconfig import LocationData
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -34,6 +35,7 @@ def extract_questions_offline(
|
||||||
max_prompt_size: int = None,
|
max_prompt_size: int = None,
|
||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
personality_context: Optional[str] = None,
|
personality_context: Optional[str] = None,
|
||||||
|
tracer: dict = {},
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Infer search queries to retrieve relevant notes to answer user query
|
Infer search queries to retrieve relevant notes to answer user query
|
||||||
|
@ -94,6 +96,7 @@ def extract_questions_offline(
|
||||||
max_prompt_size=max_prompt_size,
|
max_prompt_size=max_prompt_size,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
response_type="json_object",
|
response_type="json_object",
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
state.chat_lock.release()
|
state.chat_lock.release()
|
||||||
|
@ -147,6 +150,7 @@ def converse_offline(
|
||||||
location_data: LocationData = None,
|
location_data: LocationData = None,
|
||||||
user_name: str = None,
|
user_name: str = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
|
tracer: dict = {},
|
||||||
) -> Union[ThreadedGenerator, Iterator[str]]:
|
) -> Union[ThreadedGenerator, Iterator[str]]:
|
||||||
"""
|
"""
|
||||||
Converse with user using Llama
|
Converse with user using Llama
|
||||||
|
@ -155,6 +159,7 @@ def converse_offline(
|
||||||
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
|
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
|
||||||
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
|
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
|
||||||
compiled_references_message = "\n\n".join({f"{item['compiled']}" for item in references})
|
compiled_references_message = "\n\n".join({f"{item['compiled']}" for item in references})
|
||||||
|
tracer["chat_model"] = model
|
||||||
|
|
||||||
current_date = datetime.now()
|
current_date = datetime.now()
|
||||||
|
|
||||||
|
@ -218,13 +223,14 @@ def converse_offline(
|
||||||
logger.debug(f"Conversation Context for {model}: {truncated_messages}")
|
logger.debug(f"Conversation Context for {model}: {truncated_messages}")
|
||||||
|
|
||||||
g = ThreadedGenerator(references, online_results, completion_func=completion_func)
|
g = ThreadedGenerator(references, online_results, completion_func=completion_func)
|
||||||
t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size))
|
t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size, tracer))
|
||||||
t.start()
|
t.start()
|
||||||
return g
|
return g
|
||||||
|
|
||||||
|
|
||||||
def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None):
|
def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None, tracer: dict = {}):
|
||||||
stop_phrases = ["<s>", "INST]", "Notes:"]
|
stop_phrases = ["<s>", "INST]", "Notes:"]
|
||||||
|
aggregated_response = ""
|
||||||
|
|
||||||
state.chat_lock.acquire()
|
state.chat_lock.acquire()
|
||||||
try:
|
try:
|
||||||
|
@ -232,7 +238,14 @@ def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int
|
||||||
messages, loaded_model=model, stop=stop_phrases, max_prompt_size=max_prompt_size, streaming=True
|
messages, loaded_model=model, stop=stop_phrases, max_prompt_size=max_prompt_size, streaming=True
|
||||||
)
|
)
|
||||||
for response in response_iterator:
|
for response in response_iterator:
|
||||||
g.send(response["choices"][0]["delta"].get("content", ""))
|
response_delta = response["choices"][0]["delta"].get("content", "")
|
||||||
|
aggregated_response += response_delta
|
||||||
|
g.send(response_delta)
|
||||||
|
|
||||||
|
# Save conversation trace
|
||||||
|
if in_debug_mode() or state.verbose > 1:
|
||||||
|
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
state.chat_lock.release()
|
state.chat_lock.release()
|
||||||
g.close()
|
g.close()
|
||||||
|
@ -247,6 +260,7 @@ def send_message_to_model_offline(
|
||||||
stop=[],
|
stop=[],
|
||||||
max_prompt_size: int = None,
|
max_prompt_size: int = None,
|
||||||
response_type: str = "text",
|
response_type: str = "text",
|
||||||
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
|
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
|
||||||
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
|
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
|
||||||
|
@ -254,7 +268,17 @@ def send_message_to_model_offline(
|
||||||
response = offline_chat_model.create_chat_completion(
|
response = offline_chat_model.create_chat_completion(
|
||||||
messages_dict, stop=stop, stream=streaming, temperature=temperature, response_format={"type": response_type}
|
messages_dict, stop=stop, stream=streaming, temperature=temperature, response_format={"type": response_type}
|
||||||
)
|
)
|
||||||
|
|
||||||
if streaming:
|
if streaming:
|
||||||
return response
|
return response
|
||||||
else:
|
|
||||||
return response["choices"][0]["message"].get("content", "")
|
response_text = response["choices"][0]["message"].get("content", "")
|
||||||
|
|
||||||
|
# Save conversation trace for non-streaming responses
|
||||||
|
# Streamed responses need to be saved by the calling function
|
||||||
|
tracer["chat_model"] = model
|
||||||
|
tracer["temperature"] = temperature
|
||||||
|
if in_debug_mode() or state.verbose > 1:
|
||||||
|
commit_conversation_trace(messages, response_text, tracer)
|
||||||
|
|
||||||
|
return response_text
|
||||||
|
|
|
@ -33,6 +33,7 @@ def extract_questions(
|
||||||
query_images: Optional[list[str]] = None,
|
query_images: Optional[list[str]] = None,
|
||||||
vision_enabled: bool = False,
|
vision_enabled: bool = False,
|
||||||
personality_context: Optional[str] = None,
|
personality_context: Optional[str] = None,
|
||||||
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Infer search queries to retrieve relevant notes to answer user query
|
Infer search queries to retrieve relevant notes to answer user query
|
||||||
|
@ -82,7 +83,13 @@ def extract_questions(
|
||||||
messages = [ChatMessage(content=prompt, role="user")]
|
messages = [ChatMessage(content=prompt, role="user")]
|
||||||
|
|
||||||
response = send_message_to_model(
|
response = send_message_to_model(
|
||||||
messages, api_key, model, response_type="json_object", api_base_url=api_base_url, temperature=temperature
|
messages,
|
||||||
|
api_key,
|
||||||
|
model,
|
||||||
|
response_type="json_object",
|
||||||
|
api_base_url=api_base_url,
|
||||||
|
temperature=temperature,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Extract, Clean Message from GPT's Response
|
# Extract, Clean Message from GPT's Response
|
||||||
|
@ -103,7 +110,9 @@ def extract_questions(
|
||||||
return questions
|
return questions
|
||||||
|
|
||||||
|
|
||||||
def send_message_to_model(messages, api_key, model, response_type="text", api_base_url=None, temperature=0):
|
def send_message_to_model(
|
||||||
|
messages, api_key, model, response_type="text", api_base_url=None, temperature=0, tracer: dict = {}
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Send message to model
|
Send message to model
|
||||||
"""
|
"""
|
||||||
|
@ -116,6 +125,7 @@ def send_message_to_model(messages, api_key, model, response_type="text", api_ba
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
api_base_url=api_base_url,
|
api_base_url=api_base_url,
|
||||||
model_kwargs={"response_format": {"type": response_type}},
|
model_kwargs={"response_format": {"type": response_type}},
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -138,6 +148,7 @@ def converse(
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
query_images: Optional[list[str]] = None,
|
query_images: Optional[list[str]] = None,
|
||||||
vision_available: bool = False,
|
vision_available: bool = False,
|
||||||
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Converse with user using OpenAI's ChatGPT
|
Converse with user using OpenAI's ChatGPT
|
||||||
|
@ -214,4 +225,5 @@ def converse(
|
||||||
api_base_url=api_base_url,
|
api_base_url=api_base_url,
|
||||||
completion_func=completion_func,
|
completion_func=completion_func,
|
||||||
model_kwargs={"stop": ["Notes:\n["]},
|
model_kwargs={"stop": ["Notes:\n["]},
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
|
@ -12,7 +12,12 @@ from tenacity import (
|
||||||
wait_random_exponential,
|
wait_random_exponential,
|
||||||
)
|
)
|
||||||
|
|
||||||
from khoj.processor.conversation.utils import ThreadedGenerator
|
from khoj.processor.conversation.utils import (
|
||||||
|
ThreadedGenerator,
|
||||||
|
commit_conversation_trace,
|
||||||
|
)
|
||||||
|
from khoj.utils import state
|
||||||
|
from khoj.utils.helpers import in_debug_mode
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -33,7 +38,7 @@ openai_clients: Dict[str, openai.OpenAI] = {}
|
||||||
reraise=True,
|
reraise=True,
|
||||||
)
|
)
|
||||||
def completion_with_backoff(
|
def completion_with_backoff(
|
||||||
messages, model, temperature=0, openai_api_key=None, api_base_url=None, model_kwargs=None
|
messages, model, temperature=0, openai_api_key=None, api_base_url=None, model_kwargs=None, tracer: dict = {}
|
||||||
) -> str:
|
) -> str:
|
||||||
client_key = f"{openai_api_key}--{api_base_url}"
|
client_key = f"{openai_api_key}--{api_base_url}"
|
||||||
client: openai.OpenAI | None = openai_clients.get(client_key)
|
client: openai.OpenAI | None = openai_clients.get(client_key)
|
||||||
|
@ -77,6 +82,12 @@ def completion_with_backoff(
|
||||||
elif delta_chunk.content:
|
elif delta_chunk.content:
|
||||||
aggregated_response += delta_chunk.content
|
aggregated_response += delta_chunk.content
|
||||||
|
|
||||||
|
# Save conversation trace
|
||||||
|
tracer["chat_model"] = model
|
||||||
|
tracer["temperature"] = temperature
|
||||||
|
if in_debug_mode() or state.verbose > 1:
|
||||||
|
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||||
|
|
||||||
return aggregated_response
|
return aggregated_response
|
||||||
|
|
||||||
|
|
||||||
|
@ -103,26 +114,37 @@ def chat_completion_with_backoff(
|
||||||
api_base_url=None,
|
api_base_url=None,
|
||||||
completion_func=None,
|
completion_func=None,
|
||||||
model_kwargs=None,
|
model_kwargs=None,
|
||||||
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
|
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
|
||||||
t = Thread(
|
t = Thread(
|
||||||
target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key, api_base_url, model_kwargs)
|
target=llm_thread,
|
||||||
|
args=(g, messages, model_name, temperature, openai_api_key, api_base_url, model_kwargs, tracer),
|
||||||
)
|
)
|
||||||
t.start()
|
t.start()
|
||||||
return g
|
return g
|
||||||
|
|
||||||
|
|
||||||
def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_base_url=None, model_kwargs=None):
|
def llm_thread(
|
||||||
|
g,
|
||||||
|
messages,
|
||||||
|
model_name,
|
||||||
|
temperature,
|
||||||
|
openai_api_key=None,
|
||||||
|
api_base_url=None,
|
||||||
|
model_kwargs=None,
|
||||||
|
tracer: dict = {},
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
client_key = f"{openai_api_key}--{api_base_url}"
|
client_key = f"{openai_api_key}--{api_base_url}"
|
||||||
if client_key not in openai_clients:
|
if client_key not in openai_clients:
|
||||||
client: openai.OpenAI = openai.OpenAI(
|
client = openai.OpenAI(
|
||||||
api_key=openai_api_key,
|
api_key=openai_api_key,
|
||||||
base_url=api_base_url,
|
base_url=api_base_url,
|
||||||
)
|
)
|
||||||
openai_clients[client_key] = client
|
openai_clients[client_key] = client
|
||||||
else:
|
else:
|
||||||
client: openai.OpenAI = openai_clients[client_key]
|
client = openai_clients[client_key]
|
||||||
|
|
||||||
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
|
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
|
||||||
stream = True
|
stream = True
|
||||||
|
@ -144,17 +166,29 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_ba
|
||||||
**(model_kwargs or dict()),
|
**(model_kwargs or dict()),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
aggregated_response = ""
|
||||||
if not stream:
|
if not stream:
|
||||||
g.send(chat.choices[0].message.content)
|
aggregated_response = chat.choices[0].message.content
|
||||||
|
g.send(aggregated_response)
|
||||||
else:
|
else:
|
||||||
for chunk in chat:
|
for chunk in chat:
|
||||||
if len(chunk.choices) == 0:
|
if len(chunk.choices) == 0:
|
||||||
continue
|
continue
|
||||||
delta_chunk = chunk.choices[0].delta
|
delta_chunk = chunk.choices[0].delta
|
||||||
|
text_chunk = ""
|
||||||
if isinstance(delta_chunk, str):
|
if isinstance(delta_chunk, str):
|
||||||
g.send(delta_chunk)
|
text_chunk = delta_chunk
|
||||||
elif delta_chunk.content:
|
elif delta_chunk.content:
|
||||||
g.send(delta_chunk.content)
|
text_chunk = delta_chunk.content
|
||||||
|
if text_chunk:
|
||||||
|
aggregated_response += text_chunk
|
||||||
|
g.send(text_chunk)
|
||||||
|
|
||||||
|
# Save conversation trace
|
||||||
|
tracer["chat_model"] = model_name
|
||||||
|
tracer["temperature"] = temperature
|
||||||
|
if in_debug_mode() or state.verbose > 1:
|
||||||
|
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in llm_thread: {e}", exc_info=True)
|
logger.error(f"Error in llm_thread: {e}", exc_info=True)
|
||||||
finally:
|
finally:
|
||||||
|
|
|
@ -2,6 +2,7 @@ import base64
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import mimetypes
|
import mimetypes
|
||||||
|
import os
|
||||||
import queue
|
import queue
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
@ -13,6 +14,8 @@ from typing import Any, Dict, List, Optional
|
||||||
import PIL.Image
|
import PIL.Image
|
||||||
import requests
|
import requests
|
||||||
import tiktoken
|
import tiktoken
|
||||||
|
import yaml
|
||||||
|
from git import Repo
|
||||||
from langchain.schema import ChatMessage
|
from langchain.schema import ChatMessage
|
||||||
from llama_cpp.llama import Llama
|
from llama_cpp.llama import Llama
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
@ -24,7 +27,7 @@ from khoj.search_filter.date_filter import DateFilter
|
||||||
from khoj.search_filter.file_filter import FileFilter
|
from khoj.search_filter.file_filter import FileFilter
|
||||||
from khoj.search_filter.word_filter import WordFilter
|
from khoj.search_filter.word_filter import WordFilter
|
||||||
from khoj.utils import state
|
from khoj.utils import state
|
||||||
from khoj.utils.helpers import is_none_or_empty, merge_dicts
|
from khoj.utils.helpers import in_debug_mode, is_none_or_empty, merge_dicts
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
model_to_prompt_size = {
|
model_to_prompt_size = {
|
||||||
|
@ -178,6 +181,7 @@ def save_to_conversation_log(
|
||||||
conversation_id: str = None,
|
conversation_id: str = None,
|
||||||
automation_id: str = None,
|
automation_id: str = None,
|
||||||
query_images: List[str] = None,
|
query_images: List[str] = None,
|
||||||
|
tracer: Dict[str, Any] = {},
|
||||||
):
|
):
|
||||||
user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
updated_conversation = message_to_log(
|
updated_conversation = message_to_log(
|
||||||
|
@ -204,6 +208,9 @@ def save_to_conversation_log(
|
||||||
user_message=q,
|
user_message=q,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if in_debug_mode() or state.verbose > 1:
|
||||||
|
merge_message_into_conversation_trace(q, chat_response, tracer)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"""
|
f"""
|
||||||
Saved Conversation Turn
|
Saved Conversation Turn
|
||||||
|
@ -415,3 +422,160 @@ def get_image_from_url(image_url: str, type="pil"):
|
||||||
except requests.exceptions.RequestException as e:
|
except requests.exceptions.RequestException as e:
|
||||||
logger.error(f"Failed to get image from URL {image_url}: {e}")
|
logger.error(f"Failed to get image from URL {image_url}: {e}")
|
||||||
return ImageWithType(content=None, type=None)
|
return ImageWithType(content=None, type=None)
|
||||||
|
|
||||||
|
|
||||||
|
def commit_conversation_trace(
|
||||||
|
session: list[ChatMessage],
|
||||||
|
response: str | list[dict],
|
||||||
|
tracer: dict,
|
||||||
|
system_message: str | list[dict] = "",
|
||||||
|
repo_path: str = "/tmp/khoj_promptrace",
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Save trace of conversation step using git. Useful to visualize, compare and debug traces.
|
||||||
|
Returns the path to the repository.
|
||||||
|
"""
|
||||||
|
# Serialize session, system message and response to yaml
|
||||||
|
system_message_yaml = yaml.dump(system_message, allow_unicode=True, sort_keys=False, default_flow_style=False)
|
||||||
|
response_yaml = yaml.dump(response, allow_unicode=True, sort_keys=False, default_flow_style=False)
|
||||||
|
formatted_session = [{"role": message.role, "content": message.content} for message in session]
|
||||||
|
session_yaml = yaml.dump(formatted_session, allow_unicode=True, sort_keys=False, default_flow_style=False)
|
||||||
|
query = (
|
||||||
|
yaml.dump(session[-1].content, allow_unicode=True, sort_keys=False, default_flow_style=False)
|
||||||
|
.strip()
|
||||||
|
.removeprefix("'")
|
||||||
|
.removesuffix("'")
|
||||||
|
) # Extract serialized query from chat session
|
||||||
|
|
||||||
|
# Extract chat metadata for session
|
||||||
|
uid, cid, mid = tracer.get("uid", "main"), tracer.get("cid", "main"), tracer.get("mid")
|
||||||
|
|
||||||
|
# Infer repository path from environment variable or provided path
|
||||||
|
repo_path = os.getenv("PROMPTRACE_DIR", repo_path) or "/tmp/promptrace"
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Prepare git repository
|
||||||
|
os.makedirs(repo_path, exist_ok=True)
|
||||||
|
repo = Repo.init(repo_path)
|
||||||
|
|
||||||
|
# Remove post-commit hook if it exists
|
||||||
|
hooks_dir = os.path.join(repo_path, ".git", "hooks")
|
||||||
|
post_commit_hook = os.path.join(hooks_dir, "post-commit")
|
||||||
|
if os.path.exists(post_commit_hook):
|
||||||
|
os.remove(post_commit_hook)
|
||||||
|
|
||||||
|
# Configure git user if not set
|
||||||
|
if not repo.config_reader().has_option("user", "email"):
|
||||||
|
repo.config_writer().set_value("user", "name", "Prompt Tracer").release()
|
||||||
|
repo.config_writer().set_value("user", "email", "promptracer@khoj.dev").release()
|
||||||
|
|
||||||
|
# Create an initial commit if the repository is newly created
|
||||||
|
if not repo.head.is_valid():
|
||||||
|
repo.index.commit("And then there was a trace")
|
||||||
|
|
||||||
|
# Check out the initial commit
|
||||||
|
initial_commit = repo.commit("HEAD~0")
|
||||||
|
repo.head.reference = initial_commit
|
||||||
|
repo.head.reset(index=True, working_tree=True)
|
||||||
|
|
||||||
|
# Create or switch to user branch from initial commit
|
||||||
|
user_branch = f"u_{uid}"
|
||||||
|
if user_branch not in repo.branches:
|
||||||
|
repo.create_head(user_branch)
|
||||||
|
repo.heads[user_branch].checkout()
|
||||||
|
|
||||||
|
# Create or switch to conversation branch from user branch
|
||||||
|
conv_branch = f"c_{cid}"
|
||||||
|
if conv_branch not in repo.branches:
|
||||||
|
repo.create_head(conv_branch)
|
||||||
|
repo.heads[conv_branch].checkout()
|
||||||
|
|
||||||
|
# Create or switch to message branch from conversation branch
|
||||||
|
msg_branch = f"m_{mid}" if mid else None
|
||||||
|
if msg_branch and msg_branch not in repo.branches:
|
||||||
|
repo.create_head(msg_branch)
|
||||||
|
repo.heads[msg_branch].checkout()
|
||||||
|
|
||||||
|
# Include file with content to commit
|
||||||
|
files_to_commit = {"query": session_yaml, "response": response_yaml, "system_prompt": system_message_yaml}
|
||||||
|
|
||||||
|
# Write files and stage them
|
||||||
|
for filename, content in files_to_commit.items():
|
||||||
|
file_path = os.path.join(repo_path, filename)
|
||||||
|
with open(file_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(content)
|
||||||
|
repo.index.add([filename])
|
||||||
|
|
||||||
|
# Create commit
|
||||||
|
metadata_yaml = yaml.dump(tracer, allow_unicode=True, sort_keys=False, default_flow_style=False)
|
||||||
|
commit_message = f"""
|
||||||
|
{query[:250]}
|
||||||
|
|
||||||
|
Response:
|
||||||
|
---
|
||||||
|
{response[:500]}...
|
||||||
|
|
||||||
|
Metadata
|
||||||
|
---
|
||||||
|
{metadata_yaml}
|
||||||
|
""".strip()
|
||||||
|
|
||||||
|
repo.index.commit(commit_message)
|
||||||
|
|
||||||
|
logger.debug(f"Saved conversation trace to repo at {repo_path}")
|
||||||
|
return repo_path
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to add conversation trace to repo: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def merge_message_into_conversation_trace(query: str, response: str, tracer: dict, repo_path=None) -> bool:
|
||||||
|
"""
|
||||||
|
Merge the message branch into its parent conversation branch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: User query
|
||||||
|
response: Assistant response
|
||||||
|
tracer: Dictionary containing uid, cid and mid
|
||||||
|
repo_path: Path to the git repository
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if merge was successful, False otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Infer repository path from environment variable or provided path
|
||||||
|
repo_path = os.getenv("PROMPTRACE_DIR", repo_path) or "/tmp/promptrace"
|
||||||
|
repo = Repo(repo_path)
|
||||||
|
|
||||||
|
# Extract branch names
|
||||||
|
msg_branch = f"m_{tracer['mid']}"
|
||||||
|
conv_branch = f"c_{tracer['cid']}"
|
||||||
|
|
||||||
|
# Checkout conversation branch
|
||||||
|
repo.heads[conv_branch].checkout()
|
||||||
|
|
||||||
|
# Create commit message
|
||||||
|
metadata_yaml = yaml.dump(tracer, allow_unicode=True, sort_keys=False, default_flow_style=False)
|
||||||
|
commit_message = f"""
|
||||||
|
{query[:250]}
|
||||||
|
|
||||||
|
Response:
|
||||||
|
---
|
||||||
|
{response[:500]}...
|
||||||
|
|
||||||
|
Metadata
|
||||||
|
---
|
||||||
|
{metadata_yaml}
|
||||||
|
""".strip()
|
||||||
|
|
||||||
|
# Merge message branch into conversation branch
|
||||||
|
repo.git.merge(msg_branch, no_ff=True, m=commit_message)
|
||||||
|
|
||||||
|
# Delete message branch after merge
|
||||||
|
repo.delete_head(msg_branch, force=True)
|
||||||
|
|
||||||
|
logger.debug(f"Successfully merged {msg_branch} into {conv_branch}")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to merge message {msg_branch} into conversation {conv_branch}: {str(e)}")
|
||||||
|
return False
|
||||||
|
|
|
@ -28,6 +28,7 @@ async def text_to_image(
|
||||||
send_status_func: Optional[Callable] = None,
|
send_status_func: Optional[Callable] = None,
|
||||||
query_images: Optional[List[str]] = None,
|
query_images: Optional[List[str]] = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
status_code = 200
|
status_code = 200
|
||||||
image = None
|
image = None
|
||||||
|
@ -68,6 +69,7 @@ async def text_to_image(
|
||||||
query_images=query_images,
|
query_images=query_images,
|
||||||
user=user,
|
user=user,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
if send_status_func:
|
if send_status_func:
|
||||||
|
|
|
@ -66,6 +66,7 @@ async def search_online(
|
||||||
max_webpages_to_read: int = DEFAULT_MAX_WEBPAGES_TO_READ,
|
max_webpages_to_read: int = DEFAULT_MAX_WEBPAGES_TO_READ,
|
||||||
query_images: List[str] = None,
|
query_images: List[str] = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
query += " ".join(custom_filters)
|
query += " ".join(custom_filters)
|
||||||
if not is_internet_connected():
|
if not is_internet_connected():
|
||||||
|
@ -75,7 +76,7 @@ async def search_online(
|
||||||
|
|
||||||
# Breakdown the query into subqueries to get the correct answer
|
# Breakdown the query into subqueries to get the correct answer
|
||||||
subqueries = await generate_online_subqueries(
|
subqueries = await generate_online_subqueries(
|
||||||
query, conversation_history, location, user, query_images=query_images, agent=agent
|
query, conversation_history, location, user, query_images=query_images, agent=agent, tracer=tracer
|
||||||
)
|
)
|
||||||
response_dict = {}
|
response_dict = {}
|
||||||
|
|
||||||
|
@ -113,7 +114,7 @@ async def search_online(
|
||||||
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
|
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
|
||||||
yield {ChatEvent.STATUS: event}
|
yield {ChatEvent.STATUS: event}
|
||||||
tasks = [
|
tasks = [
|
||||||
read_webpage_and_extract_content(data["queries"], link, data["content"], user=user, agent=agent)
|
read_webpage_and_extract_content(data["queries"], link, data["content"], user=user, agent=agent, tracer=tracer)
|
||||||
for link, data in webpages.items()
|
for link, data in webpages.items()
|
||||||
]
|
]
|
||||||
results = await asyncio.gather(*tasks)
|
results = await asyncio.gather(*tasks)
|
||||||
|
@ -155,6 +156,7 @@ async def read_webpages(
|
||||||
send_status_func: Optional[Callable] = None,
|
send_status_func: Optional[Callable] = None,
|
||||||
query_images: List[str] = None,
|
query_images: List[str] = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
"Infer web pages to read from the query and extract relevant information from them"
|
"Infer web pages to read from the query and extract relevant information from them"
|
||||||
logger.info(f"Inferring web pages to read")
|
logger.info(f"Inferring web pages to read")
|
||||||
|
@ -168,7 +170,7 @@ async def read_webpages(
|
||||||
webpage_links_str = "\n- " + "\n- ".join(list(urls))
|
webpage_links_str = "\n- " + "\n- ".join(list(urls))
|
||||||
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
|
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
|
||||||
yield {ChatEvent.STATUS: event}
|
yield {ChatEvent.STATUS: event}
|
||||||
tasks = [read_webpage_and_extract_content({query}, url, user=user, agent=agent) for url in urls]
|
tasks = [read_webpage_and_extract_content({query}, url, user=user, agent=agent, tracer=tracer) for url in urls]
|
||||||
results = await asyncio.gather(*tasks)
|
results = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
response: Dict[str, Dict] = defaultdict(dict)
|
response: Dict[str, Dict] = defaultdict(dict)
|
||||||
|
@ -194,7 +196,12 @@ async def read_webpage(
|
||||||
|
|
||||||
|
|
||||||
async def read_webpage_and_extract_content(
|
async def read_webpage_and_extract_content(
|
||||||
subqueries: set[str], url: str, content: str = None, user: KhojUser = None, agent: Agent = None
|
subqueries: set[str],
|
||||||
|
url: str,
|
||||||
|
content: str = None,
|
||||||
|
user: KhojUser = None,
|
||||||
|
agent: Agent = None,
|
||||||
|
tracer: dict = {},
|
||||||
) -> Tuple[set[str], str, Union[None, str]]:
|
) -> Tuple[set[str], str, Union[None, str]]:
|
||||||
# Select the web scrapers to use for reading the web page
|
# Select the web scrapers to use for reading the web page
|
||||||
web_scrapers = await ConversationAdapters.aget_enabled_webscrapers()
|
web_scrapers = await ConversationAdapters.aget_enabled_webscrapers()
|
||||||
|
@ -216,7 +223,9 @@ async def read_webpage_and_extract_content(
|
||||||
# Extract relevant information from the web page
|
# Extract relevant information from the web page
|
||||||
if is_none_or_empty(extracted_info):
|
if is_none_or_empty(extracted_info):
|
||||||
with timer(f"Extracting relevant information from web page at '{url}' took", logger):
|
with timer(f"Extracting relevant information from web page at '{url}' took", logger):
|
||||||
extracted_info = await extract_relevant_info(subqueries, content, user=user, agent=agent)
|
extracted_info = await extract_relevant_info(
|
||||||
|
subqueries, content, user=user, agent=agent, tracer=tracer
|
||||||
|
)
|
||||||
|
|
||||||
# If we successfully extracted information, break the loop
|
# If we successfully extracted information, break the loop
|
||||||
if not is_none_or_empty(extracted_info):
|
if not is_none_or_empty(extracted_info):
|
||||||
|
|
|
@ -35,6 +35,7 @@ async def run_code(
|
||||||
query_images: List[str] = None,
|
query_images: List[str] = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
sandbox_url: str = SANDBOX_URL,
|
sandbox_url: str = SANDBOX_URL,
|
||||||
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
# Generate Code
|
# Generate Code
|
||||||
if send_status_func:
|
if send_status_func:
|
||||||
|
@ -43,7 +44,14 @@ async def run_code(
|
||||||
try:
|
try:
|
||||||
with timer("Chat actor: Generate programs to execute", logger):
|
with timer("Chat actor: Generate programs to execute", logger):
|
||||||
codes = await generate_python_code(
|
codes = await generate_python_code(
|
||||||
query, conversation_history, previous_iterations_history, location_data, user, query_images, agent
|
query,
|
||||||
|
conversation_history,
|
||||||
|
previous_iterations_history,
|
||||||
|
location_data,
|
||||||
|
user,
|
||||||
|
query_images,
|
||||||
|
agent,
|
||||||
|
tracer,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f"Failed to generate code for {query} with error: {e}")
|
raise ValueError(f"Failed to generate code for {query} with error: {e}")
|
||||||
|
@ -72,6 +80,7 @@ async def generate_python_code(
|
||||||
user: KhojUser,
|
user: KhojUser,
|
||||||
query_images: List[str] = None,
|
query_images: List[str] = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
|
tracer: dict = {},
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
location = f"{location_data}" if location_data else "Unknown"
|
location = f"{location_data}" if location_data else "Unknown"
|
||||||
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
|
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
|
||||||
|
@ -98,6 +107,7 @@ async def generate_python_code(
|
||||||
query_images=query_images,
|
query_images=query_images,
|
||||||
response_type="json_object",
|
response_type="json_object",
|
||||||
user=user,
|
user=user,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate that the response is a non-empty, JSON-serializable list
|
# Validate that the response is a non-empty, JSON-serializable list
|
||||||
|
|
|
@ -351,6 +351,7 @@ async def extract_references_and_questions(
|
||||||
send_status_func: Optional[Callable] = None,
|
send_status_func: Optional[Callable] = None,
|
||||||
query_images: Optional[List[str]] = None,
|
query_images: Optional[List[str]] = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
user = request.user.object if request.user.is_authenticated else None
|
user = request.user.object if request.user.is_authenticated else None
|
||||||
|
|
||||||
|
@ -424,6 +425,7 @@ async def extract_references_and_questions(
|
||||||
user=user,
|
user=user,
|
||||||
max_prompt_size=conversation_config.max_prompt_size,
|
max_prompt_size=conversation_config.max_prompt_size,
|
||||||
personality_context=personality_context,
|
personality_context=personality_context,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
||||||
openai_chat_config = conversation_config.openai_config
|
openai_chat_config = conversation_config.openai_config
|
||||||
|
@ -441,6 +443,7 @@ async def extract_references_and_questions(
|
||||||
query_images=query_images,
|
query_images=query_images,
|
||||||
vision_enabled=vision_enabled,
|
vision_enabled=vision_enabled,
|
||||||
personality_context=personality_context,
|
personality_context=personality_context,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
|
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
|
||||||
api_key = conversation_config.openai_config.api_key
|
api_key = conversation_config.openai_config.api_key
|
||||||
|
@ -455,6 +458,7 @@ async def extract_references_and_questions(
|
||||||
user=user,
|
user=user,
|
||||||
vision_enabled=vision_enabled,
|
vision_enabled=vision_enabled,
|
||||||
personality_context=personality_context,
|
personality_context=personality_context,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
|
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
|
||||||
api_key = conversation_config.openai_config.api_key
|
api_key = conversation_config.openai_config.api_key
|
||||||
|
@ -470,6 +474,7 @@ async def extract_references_and_questions(
|
||||||
user=user,
|
user=user,
|
||||||
vision_enabled=vision_enabled,
|
vision_enabled=vision_enabled,
|
||||||
personality_context=personality_context,
|
personality_context=personality_context,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Collate search results as context for GPT
|
# Collate search results as context for GPT
|
||||||
|
|
|
@ -3,6 +3,7 @@ import base64
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
@ -570,6 +571,12 @@ async def chat(
|
||||||
event_delimiter = "␃🔚␗"
|
event_delimiter = "␃🔚␗"
|
||||||
q = unquote(q)
|
q = unquote(q)
|
||||||
nonlocal conversation_id
|
nonlocal conversation_id
|
||||||
|
tracer: dict = {
|
||||||
|
"mid": f"{uuid.uuid4()}",
|
||||||
|
"cid": conversation_id,
|
||||||
|
"uid": user.id,
|
||||||
|
"khoj_version": state.khoj_version,
|
||||||
|
}
|
||||||
|
|
||||||
uploaded_images: list[str] = []
|
uploaded_images: list[str] = []
|
||||||
if images:
|
if images:
|
||||||
|
@ -703,6 +710,7 @@ async def chat(
|
||||||
user_name=user_name,
|
user_name=user_name,
|
||||||
location=location,
|
location=location,
|
||||||
file_filters=conversation.file_filters if conversation else [],
|
file_filters=conversation.file_filters if conversation else [],
|
||||||
|
tracer=tracer,
|
||||||
):
|
):
|
||||||
if isinstance(research_result, InformationCollectionIteration):
|
if isinstance(research_result, InformationCollectionIteration):
|
||||||
if research_result.summarizedResult:
|
if research_result.summarizedResult:
|
||||||
|
@ -732,9 +740,12 @@ async def chat(
|
||||||
user=user,
|
user=user,
|
||||||
query_images=uploaded_images,
|
query_images=uploaded_images,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, user, uploaded_images, agent)
|
mode = await aget_relevant_output_modes(
|
||||||
|
q, meta_log, is_automated_task, user, uploaded_images, agent, tracer=tracer
|
||||||
|
)
|
||||||
async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
|
async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
|
||||||
yield result
|
yield result
|
||||||
if mode not in conversation_commands:
|
if mode not in conversation_commands:
|
||||||
|
@ -778,6 +789,7 @@ async def chat(
|
||||||
query_images=uploaded_images,
|
query_images=uploaded_images,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
send_status_func=partial(send_event, ChatEvent.STATUS),
|
send_status_func=partial(send_event, ChatEvent.STATUS),
|
||||||
|
tracer=tracer,
|
||||||
):
|
):
|
||||||
if isinstance(response, dict) and ChatEvent.STATUS in response:
|
if isinstance(response, dict) and ChatEvent.STATUS in response:
|
||||||
yield result[ChatEvent.STATUS]
|
yield result[ChatEvent.STATUS]
|
||||||
|
@ -796,6 +808,7 @@ async def chat(
|
||||||
client_application=request.user.client_app,
|
client_application=request.user.client_app,
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
query_images=uploaded_images,
|
query_images=uploaded_images,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -817,7 +830,7 @@ async def chat(
|
||||||
if ConversationCommand.Automation in conversation_commands:
|
if ConversationCommand.Automation in conversation_commands:
|
||||||
try:
|
try:
|
||||||
automation, crontime, query_to_run, subject = await create_automation(
|
automation, crontime, query_to_run, subject = await create_automation(
|
||||||
q, timezone, user, request.url, meta_log
|
q, timezone, user, request.url, meta_log, tracer=tracer
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error scheduling task {q} for {user.email}: {e}")
|
logger.error(f"Error scheduling task {q} for {user.email}: {e}")
|
||||||
|
@ -839,6 +852,7 @@ async def chat(
|
||||||
inferred_queries=[query_to_run],
|
inferred_queries=[query_to_run],
|
||||||
automation_id=automation.id,
|
automation_id=automation.id,
|
||||||
query_images=uploaded_images,
|
query_images=uploaded_images,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
async for result in send_llm_response(llm_response):
|
async for result in send_llm_response(llm_response):
|
||||||
yield result
|
yield result
|
||||||
|
@ -860,6 +874,7 @@ async def chat(
|
||||||
partial(send_event, ChatEvent.STATUS),
|
partial(send_event, ChatEvent.STATUS),
|
||||||
query_images=uploaded_images,
|
query_images=uploaded_images,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
|
tracer=tracer,
|
||||||
):
|
):
|
||||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
yield result[ChatEvent.STATUS]
|
yield result[ChatEvent.STATUS]
|
||||||
|
@ -905,6 +920,7 @@ async def chat(
|
||||||
custom_filters,
|
custom_filters,
|
||||||
query_images=uploaded_images,
|
query_images=uploaded_images,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
|
tracer=tracer,
|
||||||
):
|
):
|
||||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
yield result[ChatEvent.STATUS]
|
yield result[ChatEvent.STATUS]
|
||||||
|
@ -930,6 +946,7 @@ async def chat(
|
||||||
partial(send_event, ChatEvent.STATUS),
|
partial(send_event, ChatEvent.STATUS),
|
||||||
query_images=uploaded_images,
|
query_images=uploaded_images,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
|
tracer=tracer,
|
||||||
):
|
):
|
||||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
yield result[ChatEvent.STATUS]
|
yield result[ChatEvent.STATUS]
|
||||||
|
@ -984,6 +1001,7 @@ async def chat(
|
||||||
partial(send_event, ChatEvent.STATUS),
|
partial(send_event, ChatEvent.STATUS),
|
||||||
query_images=uploaded_images,
|
query_images=uploaded_images,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
|
tracer=tracer,
|
||||||
):
|
):
|
||||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
yield result[ChatEvent.STATUS]
|
yield result[ChatEvent.STATUS]
|
||||||
|
@ -1010,6 +1028,7 @@ async def chat(
|
||||||
send_status_func=partial(send_event, ChatEvent.STATUS),
|
send_status_func=partial(send_event, ChatEvent.STATUS),
|
||||||
query_images=uploaded_images,
|
query_images=uploaded_images,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
|
tracer=tracer,
|
||||||
):
|
):
|
||||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
yield result[ChatEvent.STATUS]
|
yield result[ChatEvent.STATUS]
|
||||||
|
@ -1040,6 +1059,7 @@ async def chat(
|
||||||
compiled_references=compiled_references,
|
compiled_references=compiled_references,
|
||||||
online_results=online_results,
|
online_results=online_results,
|
||||||
query_images=uploaded_images,
|
query_images=uploaded_images,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
content_obj = {
|
content_obj = {
|
||||||
"intentType": intent_type,
|
"intentType": intent_type,
|
||||||
|
@ -1068,6 +1088,7 @@ async def chat(
|
||||||
user=user,
|
user=user,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
send_status_func=partial(send_event, ChatEvent.STATUS),
|
send_status_func=partial(send_event, ChatEvent.STATUS),
|
||||||
|
tracer=tracer,
|
||||||
):
|
):
|
||||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
yield result[ChatEvent.STATUS]
|
yield result[ChatEvent.STATUS]
|
||||||
|
@ -1095,6 +1116,7 @@ async def chat(
|
||||||
compiled_references=compiled_references,
|
compiled_references=compiled_references,
|
||||||
online_results=online_results,
|
online_results=online_results,
|
||||||
query_images=uploaded_images,
|
query_images=uploaded_images,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
async for result in send_llm_response(json.dumps(content_obj)):
|
async for result in send_llm_response(json.dumps(content_obj)):
|
||||||
|
@ -1120,6 +1142,7 @@ async def chat(
|
||||||
user_name,
|
user_name,
|
||||||
researched_results,
|
researched_results,
|
||||||
uploaded_images,
|
uploaded_images,
|
||||||
|
tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Send Response
|
# Send Response
|
||||||
|
|
|
@ -291,6 +291,7 @@ async def aget_relevant_information_sources(
|
||||||
user: KhojUser,
|
user: KhojUser,
|
||||||
query_images: List[str] = None,
|
query_images: List[str] = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Given a query, determine which of the available tools the agent should use in order to answer appropriately.
|
Given a query, determine which of the available tools the agent should use in order to answer appropriately.
|
||||||
|
@ -327,6 +328,7 @@ async def aget_relevant_information_sources(
|
||||||
relevant_tools_prompt,
|
relevant_tools_prompt,
|
||||||
response_type="json_object",
|
response_type="json_object",
|
||||||
user=user,
|
user=user,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -368,6 +370,7 @@ async def aget_relevant_output_modes(
|
||||||
user: KhojUser = None,
|
user: KhojUser = None,
|
||||||
query_images: List[str] = None,
|
query_images: List[str] = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Given a query, determine which of the available tools the agent should use in order to answer appropriately.
|
Given a query, determine which of the available tools the agent should use in order to answer appropriately.
|
||||||
|
@ -403,7 +406,9 @@ async def aget_relevant_output_modes(
|
||||||
)
|
)
|
||||||
|
|
||||||
with timer("Chat actor: Infer output mode for chat response", logger):
|
with timer("Chat actor: Infer output mode for chat response", logger):
|
||||||
response = await send_message_to_model_wrapper(relevant_mode_prompt, response_type="json_object", user=user)
|
response = await send_message_to_model_wrapper(
|
||||||
|
relevant_mode_prompt, response_type="json_object", user=user, tracer=tracer
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = response.strip()
|
response = response.strip()
|
||||||
|
@ -434,6 +439,7 @@ async def infer_webpage_urls(
|
||||||
user: KhojUser,
|
user: KhojUser,
|
||||||
query_images: List[str] = None,
|
query_images: List[str] = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
|
tracer: dict = {},
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Infer webpage links from the given query
|
Infer webpage links from the given query
|
||||||
|
@ -458,7 +464,11 @@ async def infer_webpage_urls(
|
||||||
|
|
||||||
with timer("Chat actor: Infer webpage urls to read", logger):
|
with timer("Chat actor: Infer webpage urls to read", logger):
|
||||||
response = await send_message_to_model_wrapper(
|
response = await send_message_to_model_wrapper(
|
||||||
online_queries_prompt, query_images=query_images, response_type="json_object", user=user
|
online_queries_prompt,
|
||||||
|
query_images=query_images,
|
||||||
|
response_type="json_object",
|
||||||
|
user=user,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate that the response is a non-empty, JSON-serializable list of URLs
|
# Validate that the response is a non-empty, JSON-serializable list of URLs
|
||||||
|
@ -481,6 +491,7 @@ async def generate_online_subqueries(
|
||||||
user: KhojUser,
|
user: KhojUser,
|
||||||
query_images: List[str] = None,
|
query_images: List[str] = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
|
tracer: dict = {},
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Generate subqueries from the given query
|
Generate subqueries from the given query
|
||||||
|
@ -505,7 +516,11 @@ async def generate_online_subqueries(
|
||||||
|
|
||||||
with timer("Chat actor: Generate online search subqueries", logger):
|
with timer("Chat actor: Generate online search subqueries", logger):
|
||||||
response = await send_message_to_model_wrapper(
|
response = await send_message_to_model_wrapper(
|
||||||
online_queries_prompt, query_images=query_images, response_type="json_object", user=user
|
online_queries_prompt,
|
||||||
|
query_images=query_images,
|
||||||
|
response_type="json_object",
|
||||||
|
user=user,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate that the response is a non-empty, JSON-serializable list
|
# Validate that the response is a non-empty, JSON-serializable list
|
||||||
|
@ -524,7 +539,7 @@ async def generate_online_subqueries(
|
||||||
|
|
||||||
|
|
||||||
async def schedule_query(
|
async def schedule_query(
|
||||||
q: str, conversation_history: dict, user: KhojUser, query_images: List[str] = None
|
q: str, conversation_history: dict, user: KhojUser, query_images: List[str] = None, tracer: dict = {}
|
||||||
) -> Tuple[str, ...]:
|
) -> Tuple[str, ...]:
|
||||||
"""
|
"""
|
||||||
Schedule the date, time to run the query. Assume the server timezone is UTC.
|
Schedule the date, time to run the query. Assume the server timezone is UTC.
|
||||||
|
@ -537,7 +552,7 @@ async def schedule_query(
|
||||||
)
|
)
|
||||||
|
|
||||||
raw_response = await send_message_to_model_wrapper(
|
raw_response = await send_message_to_model_wrapper(
|
||||||
crontime_prompt, query_images=query_images, response_type="json_object", user=user
|
crontime_prompt, query_images=query_images, response_type="json_object", user=user, tracer=tracer
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate that the response is a non-empty, JSON-serializable list
|
# Validate that the response is a non-empty, JSON-serializable list
|
||||||
|
@ -552,7 +567,7 @@ async def schedule_query(
|
||||||
|
|
||||||
|
|
||||||
async def extract_relevant_info(
|
async def extract_relevant_info(
|
||||||
qs: set[str], corpus: str, user: KhojUser = None, agent: Agent = None
|
qs: set[str], corpus: str, user: KhojUser = None, agent: Agent = None, tracer: dict = {}
|
||||||
) -> Union[str, None]:
|
) -> Union[str, None]:
|
||||||
"""
|
"""
|
||||||
Extract relevant information for a given query from the target corpus
|
Extract relevant information for a given query from the target corpus
|
||||||
|
@ -575,6 +590,7 @@ async def extract_relevant_info(
|
||||||
extract_relevant_information,
|
extract_relevant_information,
|
||||||
prompts.system_prompt_extract_relevant_information,
|
prompts.system_prompt_extract_relevant_information,
|
||||||
user=user,
|
user=user,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
return response.strip()
|
return response.strip()
|
||||||
|
|
||||||
|
@ -586,6 +602,7 @@ async def extract_relevant_summary(
|
||||||
query_images: List[str] = None,
|
query_images: List[str] = None,
|
||||||
user: KhojUser = None,
|
user: KhojUser = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
|
tracer: dict = {},
|
||||||
) -> Union[str, None]:
|
) -> Union[str, None]:
|
||||||
"""
|
"""
|
||||||
Extract relevant information for a given query from the target corpus
|
Extract relevant information for a given query from the target corpus
|
||||||
|
@ -613,6 +630,7 @@ async def extract_relevant_summary(
|
||||||
prompts.system_prompt_extract_relevant_summary,
|
prompts.system_prompt_extract_relevant_summary,
|
||||||
user=user,
|
user=user,
|
||||||
query_images=query_images,
|
query_images=query_images,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
return response.strip()
|
return response.strip()
|
||||||
|
|
||||||
|
@ -625,6 +643,7 @@ async def generate_summary_from_files(
|
||||||
query_images: List[str] = None,
|
query_images: List[str] = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
send_status_func: Optional[Callable] = None,
|
send_status_func: Optional[Callable] = None,
|
||||||
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
file_object = None
|
file_object = None
|
||||||
|
@ -653,6 +672,7 @@ async def generate_summary_from_files(
|
||||||
query_images=query_images,
|
query_images=query_images,
|
||||||
user=user,
|
user=user,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
response_log = str(response)
|
response_log = str(response)
|
||||||
|
|
||||||
|
@ -673,6 +693,7 @@ async def generate_excalidraw_diagram(
|
||||||
user: KhojUser = None,
|
user: KhojUser = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
send_status_func: Optional[Callable] = None,
|
send_status_func: Optional[Callable] = None,
|
||||||
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
if send_status_func:
|
if send_status_func:
|
||||||
async for event in send_status_func("**Enhancing the Diagramming Prompt**"):
|
async for event in send_status_func("**Enhancing the Diagramming Prompt**"):
|
||||||
|
@ -687,6 +708,7 @@ async def generate_excalidraw_diagram(
|
||||||
query_images=query_images,
|
query_images=query_images,
|
||||||
user=user,
|
user=user,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
if send_status_func:
|
if send_status_func:
|
||||||
|
@ -697,6 +719,7 @@ async def generate_excalidraw_diagram(
|
||||||
q=better_diagram_description_prompt,
|
q=better_diagram_description_prompt,
|
||||||
user=user,
|
user=user,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
yield better_diagram_description_prompt, excalidraw_diagram_description
|
yield better_diagram_description_prompt, excalidraw_diagram_description
|
||||||
|
@ -711,6 +734,7 @@ async def generate_better_diagram_description(
|
||||||
query_images: List[str] = None,
|
query_images: List[str] = None,
|
||||||
user: KhojUser = None,
|
user: KhojUser = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
|
tracer: dict = {},
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Generate a diagram description from the given query and context
|
Generate a diagram description from the given query and context
|
||||||
|
@ -748,7 +772,7 @@ async def generate_better_diagram_description(
|
||||||
|
|
||||||
with timer("Chat actor: Generate better diagram description", logger):
|
with timer("Chat actor: Generate better diagram description", logger):
|
||||||
response = await send_message_to_model_wrapper(
|
response = await send_message_to_model_wrapper(
|
||||||
improve_diagram_description_prompt, query_images=query_images, user=user
|
improve_diagram_description_prompt, query_images=query_images, user=user, tracer=tracer
|
||||||
)
|
)
|
||||||
response = response.strip()
|
response = response.strip()
|
||||||
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
|
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
|
||||||
|
@ -761,6 +785,7 @@ async def generate_excalidraw_diagram_from_description(
|
||||||
q: str,
|
q: str,
|
||||||
user: KhojUser = None,
|
user: KhojUser = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
|
tracer: dict = {},
|
||||||
) -> str:
|
) -> str:
|
||||||
personality_context = (
|
personality_context = (
|
||||||
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
|
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
|
||||||
|
@ -772,7 +797,9 @@ async def generate_excalidraw_diagram_from_description(
|
||||||
)
|
)
|
||||||
|
|
||||||
with timer("Chat actor: Generate excalidraw diagram", logger):
|
with timer("Chat actor: Generate excalidraw diagram", logger):
|
||||||
raw_response = await send_message_to_model_wrapper(message=excalidraw_diagram_generation, user=user)
|
raw_response = await send_message_to_model_wrapper(
|
||||||
|
message=excalidraw_diagram_generation, user=user, tracer=tracer
|
||||||
|
)
|
||||||
raw_response = raw_response.strip()
|
raw_response = raw_response.strip()
|
||||||
raw_response = remove_json_codeblock(raw_response)
|
raw_response = remove_json_codeblock(raw_response)
|
||||||
response: Dict[str, str] = json.loads(raw_response)
|
response: Dict[str, str] = json.loads(raw_response)
|
||||||
|
@ -793,6 +820,7 @@ async def generate_better_image_prompt(
|
||||||
query_images: Optional[List[str]] = None,
|
query_images: Optional[List[str]] = None,
|
||||||
user: KhojUser = None,
|
user: KhojUser = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
|
tracer: dict = {},
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Generate a better image prompt from the given query
|
Generate a better image prompt from the given query
|
||||||
|
@ -839,7 +867,9 @@ async def generate_better_image_prompt(
|
||||||
)
|
)
|
||||||
|
|
||||||
with timer("Chat actor: Generate contextual image prompt", logger):
|
with timer("Chat actor: Generate contextual image prompt", logger):
|
||||||
response = await send_message_to_model_wrapper(image_prompt, query_images=query_images, user=user)
|
response = await send_message_to_model_wrapper(
|
||||||
|
image_prompt, query_images=query_images, user=user, tracer=tracer
|
||||||
|
)
|
||||||
response = response.strip()
|
response = response.strip()
|
||||||
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
|
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
|
||||||
response = response[1:-1]
|
response = response[1:-1]
|
||||||
|
@ -853,6 +883,7 @@ async def send_message_to_model_wrapper(
|
||||||
response_type: str = "text",
|
response_type: str = "text",
|
||||||
user: KhojUser = None,
|
user: KhojUser = None,
|
||||||
query_images: List[str] = None,
|
query_images: List[str] = None,
|
||||||
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user)
|
conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user)
|
||||||
vision_available = conversation_config.vision_enabled
|
vision_available = conversation_config.vision_enabled
|
||||||
|
@ -899,6 +930,7 @@ async def send_message_to_model_wrapper(
|
||||||
max_prompt_size=max_tokens,
|
max_prompt_size=max_tokens,
|
||||||
streaming=False,
|
streaming=False,
|
||||||
response_type=response_type,
|
response_type=response_type,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif model_type == ChatModelOptions.ModelType.OPENAI:
|
elif model_type == ChatModelOptions.ModelType.OPENAI:
|
||||||
|
@ -922,6 +954,7 @@ async def send_message_to_model_wrapper(
|
||||||
model=chat_model,
|
model=chat_model,
|
||||||
response_type=response_type,
|
response_type=response_type,
|
||||||
api_base_url=api_base_url,
|
api_base_url=api_base_url,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
elif model_type == ChatModelOptions.ModelType.ANTHROPIC:
|
elif model_type == ChatModelOptions.ModelType.ANTHROPIC:
|
||||||
api_key = conversation_config.openai_config.api_key
|
api_key = conversation_config.openai_config.api_key
|
||||||
|
@ -940,6 +973,7 @@ async def send_message_to_model_wrapper(
|
||||||
messages=truncated_messages,
|
messages=truncated_messages,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
model=chat_model,
|
model=chat_model,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
elif model_type == ChatModelOptions.ModelType.GOOGLE:
|
elif model_type == ChatModelOptions.ModelType.GOOGLE:
|
||||||
api_key = conversation_config.openai_config.api_key
|
api_key = conversation_config.openai_config.api_key
|
||||||
|
@ -955,7 +989,7 @@ async def send_message_to_model_wrapper(
|
||||||
)
|
)
|
||||||
|
|
||||||
return gemini_send_message_to_model(
|
return gemini_send_message_to_model(
|
||||||
messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type
|
messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type, tracer=tracer
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise HTTPException(status_code=500, detail="Invalid conversation config")
|
raise HTTPException(status_code=500, detail="Invalid conversation config")
|
||||||
|
@ -966,6 +1000,7 @@ def send_message_to_model_wrapper_sync(
|
||||||
system_message: str = "",
|
system_message: str = "",
|
||||||
response_type: str = "text",
|
response_type: str = "text",
|
||||||
user: KhojUser = None,
|
user: KhojUser = None,
|
||||||
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
conversation_config: ChatModelOptions = ConversationAdapters.get_default_conversation_config(user)
|
conversation_config: ChatModelOptions = ConversationAdapters.get_default_conversation_config(user)
|
||||||
|
|
||||||
|
@ -998,6 +1033,7 @@ def send_message_to_model_wrapper_sync(
|
||||||
max_prompt_size=max_tokens,
|
max_prompt_size=max_tokens,
|
||||||
streaming=False,
|
streaming=False,
|
||||||
response_type=response_type,
|
response_type=response_type,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
||||||
|
@ -1012,7 +1048,11 @@ def send_message_to_model_wrapper_sync(
|
||||||
)
|
)
|
||||||
|
|
||||||
openai_response = send_message_to_model(
|
openai_response = send_message_to_model(
|
||||||
messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type
|
messages=truncated_messages,
|
||||||
|
api_key=api_key,
|
||||||
|
model=chat_model,
|
||||||
|
response_type=response_type,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
return openai_response
|
return openai_response
|
||||||
|
@ -1032,6 +1072,7 @@ def send_message_to_model_wrapper_sync(
|
||||||
messages=truncated_messages,
|
messages=truncated_messages,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
model=chat_model,
|
model=chat_model,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
|
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
|
||||||
|
@ -1050,6 +1091,7 @@ def send_message_to_model_wrapper_sync(
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
model=chat_model,
|
model=chat_model,
|
||||||
response_type=response_type,
|
response_type=response_type,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise HTTPException(status_code=500, detail="Invalid conversation config")
|
raise HTTPException(status_code=500, detail="Invalid conversation config")
|
||||||
|
@ -1071,6 +1113,7 @@ def generate_chat_response(
|
||||||
user_name: Optional[str] = None,
|
user_name: Optional[str] = None,
|
||||||
meta_research: str = "",
|
meta_research: str = "",
|
||||||
query_images: Optional[List[str]] = None,
|
query_images: Optional[List[str]] = None,
|
||||||
|
tracer: dict = {},
|
||||||
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
|
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
|
||||||
# Initialize Variables
|
# Initialize Variables
|
||||||
chat_response = None
|
chat_response = None
|
||||||
|
@ -1094,6 +1137,7 @@ def generate_chat_response(
|
||||||
client_application=client_application,
|
client_application=client_application,
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
query_images=query_images,
|
query_images=query_images,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
|
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
|
||||||
|
@ -1120,6 +1164,7 @@ def generate_chat_response(
|
||||||
location_data=location_data,
|
location_data=location_data,
|
||||||
user_name=user_name,
|
user_name=user_name,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
||||||
|
@ -1144,6 +1189,7 @@ def generate_chat_response(
|
||||||
user_name=user_name,
|
user_name=user_name,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
vision_available=vision_available,
|
vision_available=vision_available,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
|
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
|
||||||
|
@ -1165,6 +1211,7 @@ def generate_chat_response(
|
||||||
user_name=user_name,
|
user_name=user_name,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
vision_available=vision_available,
|
vision_available=vision_available,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
|
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
|
||||||
api_key = conversation_config.openai_config.api_key
|
api_key = conversation_config.openai_config.api_key
|
||||||
|
@ -1184,6 +1231,7 @@ def generate_chat_response(
|
||||||
user_name=user_name,
|
user_name=user_name,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
vision_available=vision_available,
|
vision_available=vision_available,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
metadata.update({"chat_model": conversation_config.chat_model})
|
metadata.update({"chat_model": conversation_config.chat_model})
|
||||||
|
@ -1540,9 +1588,15 @@ def scheduled_chat(
|
||||||
|
|
||||||
|
|
||||||
async def create_automation(
|
async def create_automation(
|
||||||
q: str, timezone: str, user: KhojUser, calling_url: URL, meta_log: dict = {}, conversation_id: str = None
|
q: str,
|
||||||
|
timezone: str,
|
||||||
|
user: KhojUser,
|
||||||
|
calling_url: URL,
|
||||||
|
meta_log: dict = {},
|
||||||
|
conversation_id: str = None,
|
||||||
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
crontime, query_to_run, subject = await schedule_query(q, meta_log, user)
|
crontime, query_to_run, subject = await schedule_query(q, meta_log, user, tracer=tracer)
|
||||||
job = await schedule_automation(query_to_run, subject, crontime, timezone, q, user, calling_url, conversation_id)
|
job = await schedule_automation(query_to_run, subject, crontime, timezone, q, user, calling_url, conversation_id)
|
||||||
return job, crontime, query_to_run, subject
|
return job, crontime, query_to_run, subject
|
||||||
|
|
||||||
|
|
|
@ -45,6 +45,7 @@ async def apick_next_tool(
|
||||||
previous_iterations_history: str = None,
|
previous_iterations_history: str = None,
|
||||||
max_iterations: int = 5,
|
max_iterations: int = 5,
|
||||||
send_status_func: Optional[Callable] = None,
|
send_status_func: Optional[Callable] = None,
|
||||||
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Given a query, determine which of the available tools the agent should use in order to answer appropriately. One at a time, and it's able to use subsequent iterations to refine the answer.
|
Given a query, determine which of the available tools the agent should use in order to answer appropriately. One at a time, and it's able to use subsequent iterations to refine the answer.
|
||||||
|
@ -93,6 +94,7 @@ async def apick_next_tool(
|
||||||
response_type="json_object",
|
response_type="json_object",
|
||||||
user=user,
|
user=user,
|
||||||
query_images=query_images,
|
query_images=query_images,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -135,6 +137,7 @@ async def execute_information_collection(
|
||||||
user_name: str = None,
|
user_name: str = None,
|
||||||
location: LocationData = None,
|
location: LocationData = None,
|
||||||
file_filters: List[str] = [],
|
file_filters: List[str] = [],
|
||||||
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
current_iteration = 0
|
current_iteration = 0
|
||||||
MAX_ITERATIONS = 5
|
MAX_ITERATIONS = 5
|
||||||
|
@ -159,6 +162,7 @@ async def execute_information_collection(
|
||||||
previous_iterations_history,
|
previous_iterations_history,
|
||||||
MAX_ITERATIONS,
|
MAX_ITERATIONS,
|
||||||
send_status_func,
|
send_status_func,
|
||||||
|
tracer=tracer,
|
||||||
):
|
):
|
||||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
yield result[ChatEvent.STATUS]
|
yield result[ChatEvent.STATUS]
|
||||||
|
@ -180,6 +184,7 @@ async def execute_information_collection(
|
||||||
send_status_func,
|
send_status_func,
|
||||||
query_images,
|
query_images,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
|
tracer=tracer,
|
||||||
):
|
):
|
||||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
yield result[ChatEvent.STATUS]
|
yield result[ChatEvent.STATUS]
|
||||||
|
@ -211,6 +216,7 @@ async def execute_information_collection(
|
||||||
max_webpages_to_read=0,
|
max_webpages_to_read=0,
|
||||||
query_images=query_images,
|
query_images=query_images,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
|
tracer=tracer,
|
||||||
):
|
):
|
||||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
yield result[ChatEvent.STATUS]
|
yield result[ChatEvent.STATUS]
|
||||||
|
@ -228,6 +234,7 @@ async def execute_information_collection(
|
||||||
send_status_func,
|
send_status_func,
|
||||||
query_images=query_images,
|
query_images=query_images,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
|
tracer=tracer,
|
||||||
):
|
):
|
||||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
yield result[ChatEvent.STATUS]
|
yield result[ChatEvent.STATUS]
|
||||||
|
@ -258,6 +265,7 @@ async def execute_information_collection(
|
||||||
send_status_func,
|
send_status_func,
|
||||||
query_images=query_images,
|
query_images=query_images,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
|
tracer=tracer,
|
||||||
):
|
):
|
||||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
yield result[ChatEvent.STATUS]
|
yield result[ChatEvent.STATUS]
|
||||||
|
|
Loading…
Reference in a new issue