Allow Anthropic API calling functions to save conversation traces

Debanjum Singh Solanky 2024-10-24 14:26:30 -07:00
parent 6fcd6a5659
commit eb6424f14d
2 changed files with 31 additions and 6 deletions
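The pattern in this commit: each Anthropic-facing function gains an optional `tracer` dict that is threaded down the call chain; the lowest-level completion functions stamp it with the chat model and temperature, then persist the exchange via commit_conversation_trace when debug logging is enabled. A minimal, self-contained sketch of that flow (stand-in names; the real helpers live in khoj.processor.conversation.utils):

    # Standalone sketch of the tracer pattern; persist_trace stands in for
    # khoj's commit_conversation_trace, and DEBUG for its debug-mode checks.
    DEBUG = True

    def persist_trace(messages: list, response: str, tracer: dict) -> None:
        # The real helper saves the trace somewhere durable; printing keeps
        # this sketch runnable.
        print({"messages": messages, "response": response, **tracer})

    def completion_with_backoff(messages: list, model_name: str, temperature: float = 0, tracer: dict = {}) -> str:
        response = "stub model output"  # the real function streams from Anthropic here
        # Stamp the caller-provided tracer with generation metadata, then commit
        tracer["chat_model"] = model_name
        tracer["temperature"] = temperature
        if DEBUG:
            persist_trace(messages, response, tracer)
        return response

    # Callers pass one tracer through the whole call chain:
    tracer = {"session": "demo"}  # hypothetical caller metadata
    completion_with_backoff(["hello"], "claude-3-5-sonnet-20241022", tracer=tracer)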

View file

@@ -34,6 +34,7 @@ def extract_questions_anthropic(
     query_images: Optional[list[str]] = None,
     vision_enabled: bool = False,
     personality_context: Optional[str] = None,
+    tracer: dict = {},
 ):
     """
     Infer search queries to retrieve relevant notes to answer user query
@@ -89,6 +90,7 @@ def extract_questions_anthropic(
         model_name=model,
         temperature=temperature,
         api_key=api_key,
+        tracer=tracer,
     )

     # Extract, Clean Message from Claude's Response
@@ -110,7 +112,7 @@ def extract_questions_anthropic(
     return questions


-def anthropic_send_message_to_model(messages, api_key, model):
+def anthropic_send_message_to_model(messages, api_key, model, tracer={}):
     """
     Send message to model
     """
@@ -122,6 +124,7 @@ def anthropic_send_message_to_model(messages, api_key, model):
         system_prompt=system_prompt,
         model_name=model,
         api_key=api_key,
+        tracer=tracer,
     )
@@ -141,6 +144,7 @@ def converse_anthropic(
     agent: Agent = None,
     query_images: Optional[list[str]] = None,
     vision_available: bool = False,
+    tracer: dict = {},
 ):
     """
     Converse with user using Anthropic's Claude
@@ -215,4 +219,5 @@ def converse_anthropic(
         system_prompt=system_prompt,
         completion_func=completion_func,
         max_prompt_size=max_prompt_size,
+        tracer=tracer,
     )
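With the hunks above applied, a call site could thread a tracer through like this (a hedged sketch: the tracer keys and argument values are illustrative, not from the diff, and messages/api_key are assumed prepared elsewhere as in the existing khoj code):

    # Hypothetical call site for the updated signature.
    tracer = {"uid": "user-123", "cid": "conv-456"}  # illustrative metadata keys

    response = anthropic_send_message_to_model(
        messages=messages,
        api_key=api_key,
        model="claude-3-5-sonnet-20241022",
        tracer=tracer,
    )
    # Reusing the same dict across extract_questions_anthropic and
    # converse_anthropic keeps one consistent trace for the conversation turn.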

View file

@@ -12,8 +12,13 @@ from tenacity import (
     wait_random_exponential,
 )

-from khoj.processor.conversation.utils import ThreadedGenerator, get_image_from_url
-from khoj.utils.helpers import is_none_or_empty
+from khoj.processor.conversation.utils import (
+    ThreadedGenerator,
+    commit_conversation_trace,
+    get_image_from_url,
+)
+from khoj.utils import state
+from khoj.utils.helpers import in_debug_mode, is_none_or_empty

 logger = logging.getLogger(__name__)
@@ -30,7 +35,7 @@ DEFAULT_MAX_TOKENS_ANTHROPIC = 3000
     reraise=True,
 )
 def anthropic_completion_with_backoff(
-    messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, max_tokens=None
+    messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, max_tokens=None, tracer={}
 ) -> str:
     if api_key not in anthropic_clients:
         client: anthropic.Anthropic = anthropic.Anthropic(api_key=api_key)
@@ -58,6 +63,12 @@ def anthropic_completion_with_backoff(
         for text in stream.text_stream:
             aggregated_response += text

+    # Save conversation trace
+    tracer["chat_model"] = model_name
+    tracer["temperature"] = temperature
+    if in_debug_mode() or state.verbose > 1:
+        commit_conversation_trace(messages, aggregated_response, tracer)
+
     return aggregated_response
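One general Python caveat with the `tracer={}` defaults introduced in these signatures: default argument values are evaluated once, so every call that omits `tracer` shares a single dict, and metadata can leak between unrelated calls. A minimal demonstration:

    # Generic Python behavior, not khoj-specific: the default dict persists
    # across calls that do not pass their own tracer.
    def stamp(tracer={}):
        tracer["calls"] = tracer.get("calls", 0) + 1
        return tracer

    print(stamp())  # {'calls': 1}
    print(stamp())  # {'calls': 2} -- same dict object as the first call

In this diff the call sites pass an explicit tracer down the chain, which sidesteps the issue, but `tracer: Optional[dict] = None` with `tracer = tracer or {}` inside the function is the conventional defensive spelling.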
@@ -78,18 +89,19 @@ def anthropic_chat_completion_with_backoff(
     max_prompt_size=None,
     completion_func=None,
     model_kwargs=None,
+    tracer={},
 ):
     g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
     t = Thread(
         target=anthropic_llm_thread,
-        args=(g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size, model_kwargs),
+        args=(g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size, model_kwargs, tracer),
     )
     t.start()
     return g


 def anthropic_llm_thread(
-    g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size=None, model_kwargs=None
+    g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size=None, model_kwargs=None, tracer={}
 ):
     try:
         if api_key not in anthropic_clients:
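Note how `tracer` rides along in the Thread `args` tuple: dicts are passed by reference, so whatever anthropic_llm_thread stamps onto it is visible to the code that created the tracer once the worker has run. A minimal illustration of that sharing (names here are generic, not from the diff):

    from threading import Thread

    def worker(tracer: dict) -> None:
        tracer["chat_model"] = "claude-3-5-sonnet"  # mutation lands in the caller's dict

    tracer: dict = {}
    t = Thread(target=worker, args=(tracer,))
    t.start()
    t.join()
    print(tracer)  # {'chat_model': 'claude-3-5-sonnet'}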
@@ -102,6 +114,7 @@ def anthropic_llm_thread(
             anthropic.types.MessageParam(role=message.role, content=message.content) for message in messages
         ]

+        aggregated_response = ""
         with client.messages.stream(
             messages=formatted_messages,
             model=model_name,  # type: ignore
@@ -112,7 +125,14 @@ def anthropic_llm_thread(
             **(model_kwargs or dict()),
         ) as stream:
             for text in stream.text_stream:
+                aggregated_response += text
                 g.send(text)
+
+        # Save conversation trace
+        tracer["chat_model"] = model_name
+        tracer["temperature"] = temperature
+        if in_debug_mode() or state.verbose > 1:
+            commit_conversation_trace(messages, aggregated_response, tracer)
     except Exception as e:
         logger.error(f"Error in anthropic_llm_thread: {e}", exc_info=True)
     finally:
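commit_conversation_trace itself is imported from khoj.processor.conversation.utils and is not part of this diff. As a rough idea of what such a helper might do (the file layout, field names, and serialization below are assumptions, not khoj's actual implementation):

    import json
    import time
    from pathlib import Path

    def commit_conversation_trace(messages, response: str, tracer: dict) -> None:
        # Assumed shape: dump one JSON record per exchange, folding in whatever
        # metadata (chat_model, temperature, ...) callers stamped on the tracer.
        record = {
            "timestamp": time.time(),
            "messages": [getattr(m, "content", m) for m in messages],
            "response": response,
            **tracer,
        }
        trace_dir = Path("~/.khoj/traces").expanduser()  # hypothetical location
        trace_dir.mkdir(parents=True, exist_ok=True)
        out_file = trace_dir / f"trace_{int(time.time() * 1000)}.json"
        out_file.write_text(json.dumps(record, indent=2, default=str))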