Allow Anthropic API calling functions to save conversation traces

This commit is contained in:
Debanjum Singh Solanky 2024-10-24 14:26:30 -07:00
parent 6fcd6a5659
commit eb6424f14d
2 changed files with 31 additions and 6 deletions

View file

@ -34,6 +34,7 @@ def extract_questions_anthropic(
query_images: Optional[list[str]] = None,
vision_enabled: bool = False,
personality_context: Optional[str] = None,
tracer: dict = {},
):
"""
Infer search queries to retrieve relevant notes to answer user query
@ -89,6 +90,7 @@ def extract_questions_anthropic(
model_name=model,
temperature=temperature,
api_key=api_key,
tracer=tracer,
)
# Extract, Clean Message from Claude's Response
@ -110,7 +112,7 @@ def extract_questions_anthropic(
return questions
def anthropic_send_message_to_model(messages, api_key, model):
def anthropic_send_message_to_model(messages, api_key, model, tracer={}):
"""
Send message to model
"""
@ -122,6 +124,7 @@ def anthropic_send_message_to_model(messages, api_key, model):
system_prompt=system_prompt,
model_name=model,
api_key=api_key,
tracer=tracer,
)
@ -141,6 +144,7 @@ def converse_anthropic(
agent: Agent = None,
query_images: Optional[list[str]] = None,
vision_available: bool = False,
tracer: dict = {},
):
"""
Converse with user using Anthropic's Claude
@ -215,4 +219,5 @@ def converse_anthropic(
system_prompt=system_prompt,
completion_func=completion_func,
max_prompt_size=max_prompt_size,
tracer=tracer,
)

View file

@ -12,8 +12,13 @@ from tenacity import (
wait_random_exponential,
)
from khoj.processor.conversation.utils import ThreadedGenerator, get_image_from_url
from khoj.utils.helpers import is_none_or_empty
from khoj.processor.conversation.utils import (
ThreadedGenerator,
commit_conversation_trace,
get_image_from_url,
)
from khoj.utils import state
from khoj.utils.helpers import in_debug_mode, is_none_or_empty
logger = logging.getLogger(__name__)
@ -30,7 +35,7 @@ DEFAULT_MAX_TOKENS_ANTHROPIC = 3000
reraise=True,
)
def anthropic_completion_with_backoff(
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, max_tokens=None
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, max_tokens=None, tracer={}
) -> str:
if api_key not in anthropic_clients:
client: anthropic.Anthropic = anthropic.Anthropic(api_key=api_key)
@ -58,6 +63,12 @@ def anthropic_completion_with_backoff(
for text in stream.text_stream:
aggregated_response += text
# Save conversation trace
tracer["chat_model"] = model_name
tracer["temperature"] = temperature
if in_debug_mode() or state.verbose > 1:
commit_conversation_trace(messages, aggregated_response, tracer)
return aggregated_response
@ -78,18 +89,19 @@ def anthropic_chat_completion_with_backoff(
max_prompt_size=None,
completion_func=None,
model_kwargs=None,
tracer={},
):
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
t = Thread(
target=anthropic_llm_thread,
args=(g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size, model_kwargs),
args=(g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size, model_kwargs, tracer),
)
t.start()
return g
def anthropic_llm_thread(
g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size=None, model_kwargs=None
g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size=None, model_kwargs=None, tracer={}
):
try:
if api_key not in anthropic_clients:
@ -102,6 +114,7 @@ def anthropic_llm_thread(
anthropic.types.MessageParam(role=message.role, content=message.content) for message in messages
]
aggregated_response = ""
with client.messages.stream(
messages=formatted_messages,
model=model_name, # type: ignore
@ -112,7 +125,14 @@ def anthropic_llm_thread(
**(model_kwargs or dict()),
) as stream:
for text in stream.text_stream:
aggregated_response += text
g.send(text)
# Save conversation trace
tracer["chat_model"] = model_name
tracer["temperature"] = temperature
if in_debug_mode() or state.verbose > 1:
commit_conversation_trace(messages, aggregated_response, tracer)
except Exception as e:
logger.error(f"Error in anthropic_llm_thread: {e}", exc_info=True)
finally: