From 83ca820abee2b9800345f1cfdaa94d0452919275 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Tue, 29 Oct 2024 11:22:29 -0700 Subject: [PATCH] Encourage Anthropic models to output json object using { prefill Anthropic API doesn't have ability to enforce response with valid json object, unlike all the other model types. While the model will usually adhere to json output instructions. This step is meant to more strongly encourage it to just output json object when response_type of json_object is requested. --- .../conversation/anthropic/anthropic_chat.py | 10 +++++----- .../processor/conversation/anthropic/utils.py | 15 +++++++++++++-- .../processor/conversation/google/gemini_chat.py | 6 ++---- src/khoj/routers/helpers.py | 2 ++ 4 files changed, 22 insertions(+), 11 deletions(-) diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py index f0568ccc..feb587b2 100644 --- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py +++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py @@ -16,6 +16,7 @@ from khoj.processor.conversation.anthropic.utils import ( from khoj.processor.conversation.utils import ( construct_structured_message, generate_chatml_messages_with_context, + remove_json_codeblock, ) from khoj.utils.helpers import ConversationCommand, is_none_or_empty from khoj.utils.rawconfig import LocationData @@ -91,15 +92,13 @@ def extract_questions_anthropic( model_name=model, temperature=temperature, api_key=api_key, + response_type="json_object", tracer=tracer, ) # Extract, Clean Message from Claude's Response try: - response = response.strip() - match = re.search(r"\{.*?\}", response) - if match: - response = match.group() + response = remove_json_codeblock(response) response = json.loads(response) response = [q.strip() for q in response["queries"] if q.strip()] if not isinstance(response, list) or not response: @@ -113,7 +112,7 @@ def extract_questions_anthropic( return questions -def anthropic_send_message_to_model(messages, api_key, model, tracer={}): +def anthropic_send_message_to_model(messages, api_key, model, response_type="text", tracer={}): """ Send message to model """ @@ -125,6 +124,7 @@ def anthropic_send_message_to_model(messages, api_key, model, tracer={}): system_prompt=system_prompt, model_name=model, api_key=api_key, + response_type=response_type, tracer=tracer, ) diff --git a/src/khoj/processor/conversation/anthropic/utils.py b/src/khoj/processor/conversation/anthropic/utils.py index 6673555b..cdce63c6 100644 --- a/src/khoj/processor/conversation/anthropic/utils.py +++ b/src/khoj/processor/conversation/anthropic/utils.py @@ -35,7 +35,15 @@ DEFAULT_MAX_TOKENS_ANTHROPIC = 3000 reraise=True, ) def anthropic_completion_with_backoff( - messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, max_tokens=None, tracer={} + messages, + system_prompt, + model_name, + temperature=0, + api_key=None, + model_kwargs=None, + max_tokens=None, + response_type="text", + tracer={}, ) -> str: if api_key not in anthropic_clients: client: anthropic.Anthropic = anthropic.Anthropic(api_key=api_key) @@ -44,8 +52,11 @@ def anthropic_completion_with_backoff( client = anthropic_clients[api_key] formatted_messages = [{"role": message.role, "content": message.content} for message in messages] + if response_type == "json_object": + # Prefill model response with '{' to make it output a valid JSON object + formatted_messages += [{"role": "assistant", "content": "{"}] - aggregated_response = "" + aggregated_response = "{" if response_type == "json_object" else "" max_tokens = max_tokens or DEFAULT_MAX_TOKENS_ANTHROPIC model_kwargs = model_kwargs or dict() diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py index 44f24d4b..e9538c0c 100644 --- a/src/khoj/processor/conversation/google/gemini_chat.py +++ b/src/khoj/processor/conversation/google/gemini_chat.py @@ -16,6 +16,7 @@ from khoj.processor.conversation.google.utils import ( from khoj.processor.conversation.utils import ( construct_structured_message, generate_chatml_messages_with_context, + remove_json_codeblock, ) from khoj.utils.helpers import ConversationCommand, is_none_or_empty from khoj.utils.rawconfig import LocationData @@ -92,10 +93,7 @@ def extract_questions_gemini( # Extract, Clean Message from Gemini's Response try: - response = response.strip() - match = re.search(r"\{.*?\}", response) - if match: - response = match.group() + response = remove_json_codeblock(response) response = json.loads(response) response = [q.strip() for q in response["queries"] if q.strip()] if not isinstance(response, list) or not response: diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 52958cc1..3fd0aeb1 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -978,6 +978,7 @@ async def send_message_to_model_wrapper( messages=truncated_messages, api_key=api_key, model=chat_model, + response_type=response_type, tracer=tracer, ) elif model_type == ChatModelOptions.ModelType.GOOGLE: @@ -1078,6 +1079,7 @@ def send_message_to_model_wrapper_sync( messages=truncated_messages, api_key=api_key, model=chat_model, + response_type=response_type, tracer=tracer, )