mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 23:48:56 +01:00
Allow Anthropic API calling functions to save conversation traces
This commit is contained in:
parent
6fcd6a5659
commit
eb6424f14d
2 changed files with 31 additions and 6 deletions
|
@ -34,6 +34,7 @@ def extract_questions_anthropic(
|
||||||
query_images: Optional[list[str]] = None,
|
query_images: Optional[list[str]] = None,
|
||||||
vision_enabled: bool = False,
|
vision_enabled: bool = False,
|
||||||
personality_context: Optional[str] = None,
|
personality_context: Optional[str] = None,
|
||||||
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Infer search queries to retrieve relevant notes to answer user query
|
Infer search queries to retrieve relevant notes to answer user query
|
||||||
|
@ -89,6 +90,7 @@ def extract_questions_anthropic(
|
||||||
model_name=model,
|
model_name=model,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Extract, Clean Message from Claude's Response
|
# Extract, Clean Message from Claude's Response
|
||||||
|
@ -110,7 +112,7 @@ def extract_questions_anthropic(
|
||||||
return questions
|
return questions
|
||||||
|
|
||||||
|
|
||||||
def anthropic_send_message_to_model(messages, api_key, model):
|
def anthropic_send_message_to_model(messages, api_key, model, tracer={}):
|
||||||
"""
|
"""
|
||||||
Send message to model
|
Send message to model
|
||||||
"""
|
"""
|
||||||
|
@ -122,6 +124,7 @@ def anthropic_send_message_to_model(messages, api_key, model):
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
model_name=model,
|
model_name=model,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -141,6 +144,7 @@ def converse_anthropic(
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
query_images: Optional[list[str]] = None,
|
query_images: Optional[list[str]] = None,
|
||||||
vision_available: bool = False,
|
vision_available: bool = False,
|
||||||
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Converse with user using Anthropic's Claude
|
Converse with user using Anthropic's Claude
|
||||||
|
@ -215,4 +219,5 @@ def converse_anthropic(
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
completion_func=completion_func,
|
completion_func=completion_func,
|
||||||
max_prompt_size=max_prompt_size,
|
max_prompt_size=max_prompt_size,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
|
@ -12,8 +12,13 @@ from tenacity import (
|
||||||
wait_random_exponential,
|
wait_random_exponential,
|
||||||
)
|
)
|
||||||
|
|
||||||
from khoj.processor.conversation.utils import ThreadedGenerator, get_image_from_url
|
from khoj.processor.conversation.utils import (
|
||||||
from khoj.utils.helpers import is_none_or_empty
|
ThreadedGenerator,
|
||||||
|
commit_conversation_trace,
|
||||||
|
get_image_from_url,
|
||||||
|
)
|
||||||
|
from khoj.utils import state
|
||||||
|
from khoj.utils.helpers import in_debug_mode, is_none_or_empty
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -30,7 +35,7 @@ DEFAULT_MAX_TOKENS_ANTHROPIC = 3000
|
||||||
reraise=True,
|
reraise=True,
|
||||||
)
|
)
|
||||||
def anthropic_completion_with_backoff(
|
def anthropic_completion_with_backoff(
|
||||||
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, max_tokens=None
|
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, max_tokens=None, tracer={}
|
||||||
) -> str:
|
) -> str:
|
||||||
if api_key not in anthropic_clients:
|
if api_key not in anthropic_clients:
|
||||||
client: anthropic.Anthropic = anthropic.Anthropic(api_key=api_key)
|
client: anthropic.Anthropic = anthropic.Anthropic(api_key=api_key)
|
||||||
|
@ -58,6 +63,12 @@ def anthropic_completion_with_backoff(
|
||||||
for text in stream.text_stream:
|
for text in stream.text_stream:
|
||||||
aggregated_response += text
|
aggregated_response += text
|
||||||
|
|
||||||
|
# Save conversation trace
|
||||||
|
tracer["chat_model"] = model_name
|
||||||
|
tracer["temperature"] = temperature
|
||||||
|
if in_debug_mode() or state.verbose > 1:
|
||||||
|
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||||
|
|
||||||
return aggregated_response
|
return aggregated_response
|
||||||
|
|
||||||
|
|
||||||
|
@ -78,18 +89,19 @@ def anthropic_chat_completion_with_backoff(
|
||||||
max_prompt_size=None,
|
max_prompt_size=None,
|
||||||
completion_func=None,
|
completion_func=None,
|
||||||
model_kwargs=None,
|
model_kwargs=None,
|
||||||
|
tracer={},
|
||||||
):
|
):
|
||||||
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
|
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
|
||||||
t = Thread(
|
t = Thread(
|
||||||
target=anthropic_llm_thread,
|
target=anthropic_llm_thread,
|
||||||
args=(g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size, model_kwargs),
|
args=(g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size, model_kwargs, tracer),
|
||||||
)
|
)
|
||||||
t.start()
|
t.start()
|
||||||
return g
|
return g
|
||||||
|
|
||||||
|
|
||||||
def anthropic_llm_thread(
|
def anthropic_llm_thread(
|
||||||
g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size=None, model_kwargs=None
|
g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size=None, model_kwargs=None, tracer={}
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
if api_key not in anthropic_clients:
|
if api_key not in anthropic_clients:
|
||||||
|
@ -102,6 +114,7 @@ def anthropic_llm_thread(
|
||||||
anthropic.types.MessageParam(role=message.role, content=message.content) for message in messages
|
anthropic.types.MessageParam(role=message.role, content=message.content) for message in messages
|
||||||
]
|
]
|
||||||
|
|
||||||
|
aggregated_response = ""
|
||||||
with client.messages.stream(
|
with client.messages.stream(
|
||||||
messages=formatted_messages,
|
messages=formatted_messages,
|
||||||
model=model_name, # type: ignore
|
model=model_name, # type: ignore
|
||||||
|
@ -112,7 +125,14 @@ def anthropic_llm_thread(
|
||||||
**(model_kwargs or dict()),
|
**(model_kwargs or dict()),
|
||||||
) as stream:
|
) as stream:
|
||||||
for text in stream.text_stream:
|
for text in stream.text_stream:
|
||||||
|
aggregated_response += text
|
||||||
g.send(text)
|
g.send(text)
|
||||||
|
|
||||||
|
# Save conversation trace
|
||||||
|
tracer["chat_model"] = model_name
|
||||||
|
tracer["temperature"] = temperature
|
||||||
|
if in_debug_mode() or state.verbose > 1:
|
||||||
|
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in anthropic_llm_thread: {e}", exc_info=True)
|
logger.error(f"Error in anthropic_llm_thread: {e}", exc_info=True)
|
||||||
finally:
|
finally:
|
||||||
|
|
Loading…
Reference in a new issue