mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 23:48:56 +01:00
Allow Anthropic API calling functions to save conversation traces
This commit is contained in:
parent
6fcd6a5659
commit
eb6424f14d
2 changed files with 31 additions and 6 deletions
|
@ -34,6 +34,7 @@ def extract_questions_anthropic(
|
|||
query_images: Optional[list[str]] = None,
|
||||
vision_enabled: bool = False,
|
||||
personality_context: Optional[str] = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
"""
|
||||
Infer search queries to retrieve relevant notes to answer user query
|
||||
|
@ -89,6 +90,7 @@ def extract_questions_anthropic(
|
|||
model_name=model,
|
||||
temperature=temperature,
|
||||
api_key=api_key,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
# Extract, Clean Message from Claude's Response
|
||||
|
@ -110,7 +112,7 @@ def extract_questions_anthropic(
|
|||
return questions
|
||||
|
||||
|
||||
def anthropic_send_message_to_model(messages, api_key, model):
|
||||
def anthropic_send_message_to_model(messages, api_key, model, tracer={}):
|
||||
"""
|
||||
Send message to model
|
||||
"""
|
||||
|
@ -122,6 +124,7 @@ def anthropic_send_message_to_model(messages, api_key, model):
|
|||
system_prompt=system_prompt,
|
||||
model_name=model,
|
||||
api_key=api_key,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
|
||||
|
@ -141,6 +144,7 @@ def converse_anthropic(
|
|||
agent: Agent = None,
|
||||
query_images: Optional[list[str]] = None,
|
||||
vision_available: bool = False,
|
||||
tracer: dict = {},
|
||||
):
|
||||
"""
|
||||
Converse with user using Anthropic's Claude
|
||||
|
@ -215,4 +219,5 @@ def converse_anthropic(
|
|||
system_prompt=system_prompt,
|
||||
completion_func=completion_func,
|
||||
max_prompt_size=max_prompt_size,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
|
|
@ -12,8 +12,13 @@ from tenacity import (
|
|||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from khoj.processor.conversation.utils import ThreadedGenerator, get_image_from_url
|
||||
from khoj.utils.helpers import is_none_or_empty
|
||||
from khoj.processor.conversation.utils import (
|
||||
ThreadedGenerator,
|
||||
commit_conversation_trace,
|
||||
get_image_from_url,
|
||||
)
|
||||
from khoj.utils import state
|
||||
from khoj.utils.helpers import in_debug_mode, is_none_or_empty
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -30,7 +35,7 @@ DEFAULT_MAX_TOKENS_ANTHROPIC = 3000
|
|||
reraise=True,
|
||||
)
|
||||
def anthropic_completion_with_backoff(
|
||||
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, max_tokens=None
|
||||
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, max_tokens=None, tracer={}
|
||||
) -> str:
|
||||
if api_key not in anthropic_clients:
|
||||
client: anthropic.Anthropic = anthropic.Anthropic(api_key=api_key)
|
||||
|
@ -58,6 +63,12 @@ def anthropic_completion_with_backoff(
|
|||
for text in stream.text_stream:
|
||||
aggregated_response += text
|
||||
|
||||
# Save conversation trace
|
||||
tracer["chat_model"] = model_name
|
||||
tracer["temperature"] = temperature
|
||||
if in_debug_mode() or state.verbose > 1:
|
||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||
|
||||
return aggregated_response
|
||||
|
||||
|
||||
|
@ -78,18 +89,19 @@ def anthropic_chat_completion_with_backoff(
|
|||
max_prompt_size=None,
|
||||
completion_func=None,
|
||||
model_kwargs=None,
|
||||
tracer={},
|
||||
):
|
||||
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
|
||||
t = Thread(
|
||||
target=anthropic_llm_thread,
|
||||
args=(g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size, model_kwargs),
|
||||
args=(g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size, model_kwargs, tracer),
|
||||
)
|
||||
t.start()
|
||||
return g
|
||||
|
||||
|
||||
def anthropic_llm_thread(
|
||||
g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size=None, model_kwargs=None
|
||||
g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size=None, model_kwargs=None, tracer={}
|
||||
):
|
||||
try:
|
||||
if api_key not in anthropic_clients:
|
||||
|
@ -102,6 +114,7 @@ def anthropic_llm_thread(
|
|||
anthropic.types.MessageParam(role=message.role, content=message.content) for message in messages
|
||||
]
|
||||
|
||||
aggregated_response = ""
|
||||
with client.messages.stream(
|
||||
messages=formatted_messages,
|
||||
model=model_name, # type: ignore
|
||||
|
@ -112,7 +125,14 @@ def anthropic_llm_thread(
|
|||
**(model_kwargs or dict()),
|
||||
) as stream:
|
||||
for text in stream.text_stream:
|
||||
aggregated_response += text
|
||||
g.send(text)
|
||||
|
||||
# Save conversation trace
|
||||
tracer["chat_model"] = model_name
|
||||
tracer["temperature"] = temperature
|
||||
if in_debug_mode() or state.verbose > 1:
|
||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in anthropic_llm_thread: {e}", exc_info=True)
|
||||
finally:
|
||||
|
|
Loading…
Reference in a new issue