Enforce JSON responses from offline chat models when requested by chat actors

- Background
  llama.cpp supports enforcing the response format as a JSON object,
  much like the OpenAI API does. Pass the expected response format
  through to offline chat models as well (see the sketch below).
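  For reference, a minimal sketch of JSON mode in llama-cpp-python;
  the model path and prompt here are illustrative, not part of this
  commit:

    from llama_cpp import Llama

    # Illustrative GGUF model path; any chat-capable model works
    llm = Llama(model_path="models/phi-3.5-mini-instruct.Q4_K_M.gguf")

    # response_format mirrors the OpenAI API's JSON mode
    response = llm.create_chat_completion(
        messages=[{"role": "user", "content": "List three fruits as a JSON object."}],
        response_format={"type": "json_object"},
    )
    print(response["choices"][0]["message"]["content"])  # a valid JSON string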

- Overview
  Enforce JSON output to improve the reliability of intermediate steps
  taken by offline chat models. This is especially helpful when working
  with smaller models like Phi-3.5-mini and Gemma-2 2B, which do not
  consistently respond with structured output even when asked to.

- Details
  Enforce JSON responses from the extract questions and infer output
  offline chat actors
  - Convert prompts to output JSON objects when offline chat models
    extract document search questions or infer the output mode
    (example response shapes below)
  - Make llama.cpp enforce the response format as a JSON object
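  With these changes the two actors expect replies shaped like the
  following; the values are illustrative, the shapes come from the
  updated prompts in the diff:

    {"queries": ["When was Alice born?", "What is Bob's age?"]}  # extract questions
    {"output": "text"}                                           # infer output mode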

- Result
  - Improve all intermediate steps taken by offline chat actors via
    JSON response enforcement
  - Avoid manual, ad hoc and flaky output schema enforcement and
    simplify the code
Author: Debanjum Singh Solanky
Date: 2024-08-22 16:54:24 -07:00
parent ab7fb5117c
commit 8a4c20d59a
3 changed files with 30 additions and 36 deletions


@@ -87,26 +87,16 @@ def extract_questions_offline(
             model=model,
             max_prompt_size=max_prompt_size,
             temperature=temperature,
+            response_type="json_object",
         )
     finally:
         state.chat_lock.release()

-    # Extract, Clean Message from GPT's Response
+    # Extract and clean the chat model's response
     try:
-        # This will expect to be a list with a single string with a list of questions
-        questions_str = (
-            str(response)
-            .strip(empty_escape_sequences)
-            .replace("['", '["')
-            .replace("<s>", "")
-            .replace("</s>", "")
-            .replace("']", '"]')
-            .replace("', '", '", "')
-        )
-        # Remove any markdown json codeblock formatting if present (useful for gemma-2)
-        if response.startswith("```json"):
-            response = response[7:-3]
-        questions: List[str] = json.loads(questions_str)
+        response = response.strip(empty_escape_sequences)
+        response = json.loads(response)
+        questions = [q.strip() for q in response["queries"] if q.strip()]
         questions = filter_questions(questions)
     except:
         logger.warning(f"Llama returned invalid JSON. Falling back to using user message as search query.\n{response}")
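With JSON mode enforced, the model reply is already valid JSON, so the old string surgery (stripping <s>/</s> tokens, rewriting quote styles, trimming markdown ```json fences for gemma-2) drops out. A minimal sketch of the new parse path, with an illustrative reply:

    import json

    raw = '{"queries": ["When was Alice born?", "What is Bob\'s age?"]}'  # illustrative
    queries = [q.strip() for q in json.loads(raw)["queries"] if q.strip()]
    print(queries)  # ['When was Alice born?', "What is Bob's age?"]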
@@ -245,12 +235,13 @@ def send_message_to_model_offline(
     streaming=False,
     stop=[],
     max_prompt_size: int = None,
+    response_type: str = "text",
 ):
     assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
     offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
     messages_dict = [{"role": message.role, "content": message.content} for message in messages]
     response = offline_chat_model.create_chat_completion(
-        messages_dict, stop=stop, stream=streaming, temperature=temperature
+        messages_dict, stop=stop, stream=streaming, temperature=temperature, response_format={"type": response_type}
     )
     if streaming:
         return response
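Callers opt in per request, as extract_questions_offline does above; response_type defaults to "text", so existing call sites keep their old behaviour. A condensed, illustrative version of that call:

    response = send_message_to_model_offline(
        messages,
        loaded_model=offline_chat_model,
        model=model,
        response_type="json_object",
    )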


@@ -217,31 +217,31 @@ User's Location: {location}

 Examples:
 Q: How was my trip to Cambodia?
-Khoj: ["How was my trip to Cambodia?"]
+Khoj: {{"queries": ["How was my trip to Cambodia?"]}}

 Q: Who did I visit the temple with on that trip?
-Khoj: ["Who did I visit the temple with in Cambodia?"]
+Khoj: {{"queries": ["Who did I visit the temple with in Cambodia?"]}}

 Q: Which of them is older?
-Khoj: ["When was Alice born?", "What is Bob's age?"]
+Khoj: {{"queries": ["When was Alice born?", "What is Bob's age?"]}}

 Q: Where did John say he was? He mentioned it in our call last week.
-Khoj: ["Where is John? dt>='{last_year}-12-25' dt<'{last_year}-12-26'", "John's location in call notes"]
+Khoj: {{"queries": ["Where is John? dt>='{last_year}-12-25' dt<'{last_year}-12-26'", "John's location in call notes"]}}

 Q: How can you help me?
-Khoj: ["Social relationships", "Physical and mental health", "Education and career", "Personal life goals and habits"]
+Khoj: {{"queries": ["Social relationships", "Physical and mental health", "Education and career", "Personal life goals and habits"]}}

 Q: What did I do for Christmas last year?
-Khoj: ["What did I do for Christmas {last_year} dt>='{last_year}-12-25' dt<'{last_year}-12-26'"]
+Khoj: {{"queries": ["What did I do for Christmas {last_year} dt>='{last_year}-12-25' dt<'{last_year}-12-26'"]}}

 Q: How should I take care of my plants?
-Khoj: ["What kind of plants do I have?", "What issues do my plants have?"]
+Khoj: {{"queries": ["What kind of plants do I have?", "What issues do my plants have?"]}}

 Q: Who all did I meet here yesterday?
-Khoj: ["Met in {location} on {yesterday_date} dt>='{yesterday_date}' dt<'{current_date}'"]
+Khoj: {{"queries": ["Met in {location} on {yesterday_date} dt>='{yesterday_date}' dt<'{current_date}'"]}}

 Q: Share some random, interesting experiences from this month
-Khoj: ["Exciting travel adventures from {current_month}", "Fun social events dt>='{current_month}-01' dt<'{current_date}'", "Intense emotional experiences in {current_month}"]
+Khoj: {{"queries": ["Exciting travel adventures from {current_month}", "Fun social events dt>='{current_month}-01' dt<'{current_date}'", "Intense emotional experiences in {current_month}"]}}

 Chat History:
 {chat_history}
@@ -425,7 +425,7 @@ User: I just visited Jerusalem for the first time. Pull up my notes from the tri
 AI: You mention visiting Masjid Al-Aqsa and the Western Wall. You also mention trying the local cuisine and visiting the Dead Sea.

 Q: Draw a picture of my trip to Jerusalem.
-Khoj: image
+Khoj: {{"output": "image"}}

 Example:
 Chat History:
@@ -433,7 +433,7 @@ User: I'm having trouble deciding which laptop to get. I want something with at
 AI: I can help with that. I see online that there is a new model of the Dell XPS 15 that meets your requirements.

 Q: What are the specs of the new Dell XPS 15?
-Khoj: text
+Khoj: {{"output": "text"}}

 Example:
 Chat History:
@@ -441,7 +441,7 @@ User: Where did I go on my last vacation?
 AI: You went to Jordan and visited Petra, the Dead Sea, and Wadi Rum.

 Q: Remind me who did I go with on that trip?
-Khoj: text
+Khoj: {{"output": "text"}}

 Example:
 Chat History:
@@ -449,9 +449,9 @@ User: How's the weather outside? Current Location: Bali, Indonesia
 AI: It's currently 28°C and partly cloudy in Bali.

 Q: Share a painting using the weather for Bali every morning.
-Khoj: automation
+Khoj: {{"output": "automation"}}

-Now it's your turn to pick the mode you would like to use to answer the user's question. Provide your response as a string.
+Now it's your turn to pick the mode you would like to use to answer the user's question. Provide your response as a JSON.

 Chat History:
 {chat_history}
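A note on the doubled braces in these prompt templates: assuming the templates are rendered with Python's str.format, as the {{ }} escapes and placeholders like {last_year} suggest, the doubled braces come out as literal JSON braces. A small illustration:

    template = 'Khoj: {{"queries": ["What did I do on {date}?"]}}'
    print(template.format(date="2024-08-22"))
    # Khoj: {"queries": ["What did I do on 2024-08-22?"]}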


@@ -326,22 +326,23 @@ async def aget_relevant_output_modes(query: str, conversation_history: dict, is_
     )
     with timer("Chat actor: Infer output mode for chat response", logger):
-        response = await send_message_to_model_wrapper(relevant_mode_prompt)
+        response = await send_message_to_model_wrapper(relevant_mode_prompt, response_type="json_object")

     try:
-        response = response.strip().strip('"')
+        response = json.loads(response.strip())

         if is_none_or_empty(response):
             return ConversationCommand.Text

-        if response in mode_options.keys():
+        output_mode = response["output"]
+        if output_mode in mode_options.keys():
             # Check whether the tool exists as a valid ConversationCommand
-            return ConversationCommand(response)
+            return ConversationCommand(output_mode)

-        logger.error(f"Invalid output mode selected: {response}. Defaulting to text.")
+        logger.error(f"Invalid output mode selected: {output_mode}. Defaulting to text.")
         return ConversationCommand.Text
     except Exception:
-        logger.error(f"Invalid response for determining relevant mode: {response}")
+        logger.error(f"Invalid response for determining output mode: {response}")
         return ConversationCommand.Text
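Condensed into a standalone sketch with hypothetical inputs, the new selection logic and its fallback path look like this:

    import json

    def pick_mode(reply: str, mode_options: dict) -> str:
        # Hypothetical standalone rendering of the logic above
        try:
            output_mode = json.loads(reply.strip())["output"]
            if output_mode in mode_options:
                return output_mode
        except Exception:
            pass
        return "text"  # i.e. ConversationCommand.Text

    print(pick_mode('{"output": "automation"}', {"text": 1, "automation": 1}))  # automation
    print(pick_mode("not json", {"text": 1}))  # text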
@@ -595,6 +596,7 @@ async def send_message_to_model_wrapper(
             loaded_model=loaded_model,
             model=chat_model,
             streaming=False,
+            response_type=response_type,
         )

     elif conversation_config.model_type == "openai":
@@ -664,6 +666,7 @@ def send_message_to_model_wrapper_sync(
             loaded_model=loaded_model,
             model=chat_model,
             streaming=False,
+            response_type=response_type,
         )

     elif conversation_config.model_type == "openai":