Encourage Anthropic models to output json object using { prefill

The Anthropic API doesn't have the ability to enforce a response with a valid JSON
object, unlike all the other supported model types.

While the model will usually adhere to JSON output instructions, this
step is meant to more strongly encourage it to output just a JSON
object when a response_type of json_object is requested.
This commit is contained in:
Debanjum 2024-10-29 11:22:29 -07:00
parent dc8e89b5de
commit 83ca820abe
4 changed files with 22 additions and 11 deletions

View file

@ -16,6 +16,7 @@ from khoj.processor.conversation.anthropic.utils import (
from khoj.processor.conversation.utils import (
construct_structured_message,
generate_chatml_messages_with_context,
remove_json_codeblock,
)
from khoj.utils.helpers import ConversationCommand, is_none_or_empty
from khoj.utils.rawconfig import LocationData
@ -91,15 +92,13 @@ def extract_questions_anthropic(
model_name=model,
temperature=temperature,
api_key=api_key,
response_type="json_object",
tracer=tracer,
)
# Extract, Clean Message from Claude's Response
try:
response = response.strip()
match = re.search(r"\{.*?\}", response)
if match:
response = match.group()
response = remove_json_codeblock(response)
response = json.loads(response)
response = [q.strip() for q in response["queries"] if q.strip()]
if not isinstance(response, list) or not response:
@ -113,7 +112,7 @@ def extract_questions_anthropic(
return questions
def anthropic_send_message_to_model(messages, api_key, model, tracer={}):
def anthropic_send_message_to_model(messages, api_key, model, response_type="text", tracer={}):
"""
Send message to model
"""
@ -125,6 +124,7 @@ def anthropic_send_message_to_model(messages, api_key, model, tracer={}):
system_prompt=system_prompt,
model_name=model,
api_key=api_key,
response_type=response_type,
tracer=tracer,
)

View file

@ -35,7 +35,15 @@ DEFAULT_MAX_TOKENS_ANTHROPIC = 3000
reraise=True,
)
def anthropic_completion_with_backoff(
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, max_tokens=None, tracer={}
messages,
system_prompt,
model_name,
temperature=0,
api_key=None,
model_kwargs=None,
max_tokens=None,
response_type="text",
tracer={},
) -> str:
if api_key not in anthropic_clients:
client: anthropic.Anthropic = anthropic.Anthropic(api_key=api_key)
@ -44,8 +52,11 @@ def anthropic_completion_with_backoff(
client = anthropic_clients[api_key]
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
if response_type == "json_object":
# Prefill model response with '{' to make it output a valid JSON object
formatted_messages += [{"role": "assistant", "content": "{"}]
aggregated_response = ""
aggregated_response = "{" if response_type == "json_object" else ""
max_tokens = max_tokens or DEFAULT_MAX_TOKENS_ANTHROPIC
model_kwargs = model_kwargs or dict()

View file

@ -16,6 +16,7 @@ from khoj.processor.conversation.google.utils import (
from khoj.processor.conversation.utils import (
construct_structured_message,
generate_chatml_messages_with_context,
remove_json_codeblock,
)
from khoj.utils.helpers import ConversationCommand, is_none_or_empty
from khoj.utils.rawconfig import LocationData
@ -92,10 +93,7 @@ def extract_questions_gemini(
# Extract, Clean Message from Gemini's Response
try:
response = response.strip()
match = re.search(r"\{.*?\}", response)
if match:
response = match.group()
response = remove_json_codeblock(response)
response = json.loads(response)
response = [q.strip() for q in response["queries"] if q.strip()]
if not isinstance(response, list) or not response:

View file

@ -978,6 +978,7 @@ async def send_message_to_model_wrapper(
messages=truncated_messages,
api_key=api_key,
model=chat_model,
response_type=response_type,
tracer=tracer,
)
elif model_type == ChatModelOptions.ModelType.GOOGLE:
@ -1078,6 +1079,7 @@ def send_message_to_model_wrapper_sync(
messages=truncated_messages,
api_key=api_key,
model=chat_model,
response_type=response_type,
tracer=tracer,
)