mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 15:38:55 +01:00
Encourage Anthropic models to output json object using { prefill
The Anthropic API doesn't have the ability to enforce a response with a valid JSON object, unlike all the other model types. While the model will usually adhere to JSON output instructions, this step is meant to more strongly encourage it to output just a JSON object when a response_type of json_object is requested.
This commit is contained in:
parent
dc8e89b5de
commit
83ca820abe
4 changed files with 22 additions and 11 deletions
|
@ -16,6 +16,7 @@ from khoj.processor.conversation.anthropic.utils import (
|
|||
from khoj.processor.conversation.utils import (
|
||||
construct_structured_message,
|
||||
generate_chatml_messages_with_context,
|
||||
remove_json_codeblock,
|
||||
)
|
||||
from khoj.utils.helpers import ConversationCommand, is_none_or_empty
|
||||
from khoj.utils.rawconfig import LocationData
|
||||
|
@ -91,15 +92,13 @@ def extract_questions_anthropic(
|
|||
model_name=model,
|
||||
temperature=temperature,
|
||||
api_key=api_key,
|
||||
response_type="json_object",
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
# Extract, Clean Message from Claude's Response
|
||||
try:
|
||||
response = response.strip()
|
||||
match = re.search(r"\{.*?\}", response)
|
||||
if match:
|
||||
response = match.group()
|
||||
response = remove_json_codeblock(response)
|
||||
response = json.loads(response)
|
||||
response = [q.strip() for q in response["queries"] if q.strip()]
|
||||
if not isinstance(response, list) or not response:
|
||||
|
@ -113,7 +112,7 @@ def extract_questions_anthropic(
|
|||
return questions
|
||||
|
||||
|
||||
def anthropic_send_message_to_model(messages, api_key, model, tracer={}):
|
||||
def anthropic_send_message_to_model(messages, api_key, model, response_type="text", tracer={}):
|
||||
"""
|
||||
Send message to model
|
||||
"""
|
||||
|
@ -125,6 +124,7 @@ def anthropic_send_message_to_model(messages, api_key, model, tracer={}):
|
|||
system_prompt=system_prompt,
|
||||
model_name=model,
|
||||
api_key=api_key,
|
||||
response_type=response_type,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
|
|
|
@ -35,7 +35,15 @@ DEFAULT_MAX_TOKENS_ANTHROPIC = 3000
|
|||
reraise=True,
|
||||
)
|
||||
def anthropic_completion_with_backoff(
|
||||
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, max_tokens=None, tracer={}
|
||||
messages,
|
||||
system_prompt,
|
||||
model_name,
|
||||
temperature=0,
|
||||
api_key=None,
|
||||
model_kwargs=None,
|
||||
max_tokens=None,
|
||||
response_type="text",
|
||||
tracer={},
|
||||
) -> str:
|
||||
if api_key not in anthropic_clients:
|
||||
client: anthropic.Anthropic = anthropic.Anthropic(api_key=api_key)
|
||||
|
@ -44,8 +52,11 @@ def anthropic_completion_with_backoff(
|
|||
client = anthropic_clients[api_key]
|
||||
|
||||
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
|
||||
if response_type == "json_object":
|
||||
# Prefill model response with '{' to make it output a valid JSON object
|
||||
formatted_messages += [{"role": "assistant", "content": "{"}]
|
||||
|
||||
aggregated_response = ""
|
||||
aggregated_response = "{" if response_type == "json_object" else ""
|
||||
max_tokens = max_tokens or DEFAULT_MAX_TOKENS_ANTHROPIC
|
||||
|
||||
model_kwargs = model_kwargs or dict()
|
||||
|
|
|
@ -16,6 +16,7 @@ from khoj.processor.conversation.google.utils import (
|
|||
from khoj.processor.conversation.utils import (
|
||||
construct_structured_message,
|
||||
generate_chatml_messages_with_context,
|
||||
remove_json_codeblock,
|
||||
)
|
||||
from khoj.utils.helpers import ConversationCommand, is_none_or_empty
|
||||
from khoj.utils.rawconfig import LocationData
|
||||
|
@ -92,10 +93,7 @@ def extract_questions_gemini(
|
|||
|
||||
# Extract, Clean Message from Gemini's Response
|
||||
try:
|
||||
response = response.strip()
|
||||
match = re.search(r"\{.*?\}", response)
|
||||
if match:
|
||||
response = match.group()
|
||||
response = remove_json_codeblock(response)
|
||||
response = json.loads(response)
|
||||
response = [q.strip() for q in response["queries"] if q.strip()]
|
||||
if not isinstance(response, list) or not response:
|
||||
|
|
|
@ -978,6 +978,7 @@ async def send_message_to_model_wrapper(
|
|||
messages=truncated_messages,
|
||||
api_key=api_key,
|
||||
model=chat_model,
|
||||
response_type=response_type,
|
||||
tracer=tracer,
|
||||
)
|
||||
elif model_type == ChatModelOptions.ModelType.GOOGLE:
|
||||
|
@ -1078,6 +1079,7 @@ def send_message_to_model_wrapper_sync(
|
|||
messages=truncated_messages,
|
||||
api_key=api_key,
|
||||
model=chat_model,
|
||||
response_type=response_type,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
|
|
Loading…
Reference in a new issue