Remove any markdown json codeblock in chat actors expecting json responses

Strip any json md codeblock wrapper if exists before processing
response by output mode, extract questions chat actor. This is similar
to what is already being done by other chat actors

Useful for succesfully interpreting json output in chat actors when
using non (json) schema enforceable models like o1 and gemma-2

Use conversation helper function to centralize the json md codeblock
removal code
This commit is contained in:
Debanjum Singh Solanky 2024-09-12 16:41:40 -07:00
parent 6e660d11c9
commit 0685a79748
3 changed files with 15 additions and 7 deletions

View file

@ -14,6 +14,7 @@ from khoj.processor.conversation.openai.utils import (
from khoj.processor.conversation.utils import ( from khoj.processor.conversation.utils import (
construct_structured_message, construct_structured_message,
generate_chatml_messages_with_context, generate_chatml_messages_with_context,
remove_json_codeblock,
) )
from khoj.utils.helpers import ConversationCommand, is_none_or_empty from khoj.utils.helpers import ConversationCommand, is_none_or_empty
from khoj.utils.rawconfig import LocationData from khoj.utils.rawconfig import LocationData
@ -85,6 +86,7 @@ def extract_questions(
# Extract, Clean Message from GPT's Response # Extract, Clean Message from GPT's Response
try: try:
response = response.strip() response = response.strip()
response = remove_json_codeblock(response)
response = json.loads(response) response = json.loads(response)
response = [q.strip() for q in response["queries"] if q.strip()] response = [q.strip() for q in response["queries"] if q.strip()]
if not isinstance(response, list) or not response: if not isinstance(response, list) or not response:

View file

@ -289,3 +289,10 @@ def truncate_messages(
def reciprocal_conversation_to_chatml(message_pair): def reciprocal_conversation_to_chatml(message_pair):
"""Convert a single back and forth between user and assistant to chatml format""" """Convert a single back and forth between user and assistant to chatml format"""
return [ChatMessage(content=message, role=role) for message, role in zip(message_pair, ["user", "assistant"])] return [ChatMessage(content=message, role=role) for message, role in zip(message_pair, ["user", "assistant"])]
def remove_json_codeblock(response):
"""Remove any markdown json codeblock formatting if present. Useful for non schema enforceable models"""
if response.startswith("```json") and response.endswith("```"):
response = response[7:-3]
return response

View file

@ -88,6 +88,7 @@ from khoj.processor.conversation.openai.gpt import converse, send_message_to_mod
from khoj.processor.conversation.utils import ( from khoj.processor.conversation.utils import (
ThreadedGenerator, ThreadedGenerator,
generate_chatml_messages_with_context, generate_chatml_messages_with_context,
remove_json_codeblock,
save_to_conversation_log, save_to_conversation_log,
) )
from khoj.processor.speech.text_to_speech import is_eleven_labs_enabled from khoj.processor.speech.text_to_speech import is_eleven_labs_enabled
@ -298,9 +299,7 @@ async def aget_relevant_information_sources(
try: try:
response = response.strip() response = response.strip()
# Remove any markdown json codeblock formatting if present (useful for gemma-2) response = remove_json_codeblock(response)
if response.startswith("```json"):
response = response[7:-3]
response = json.loads(response) response = json.loads(response)
response = [q.strip() for q in response["source"] if q.strip()] response = [q.strip() for q in response["source"] if q.strip()]
if not isinstance(response, list) or not response or len(response) == 0: if not isinstance(response, list) or not response or len(response) == 0:
@ -353,7 +352,9 @@ async def aget_relevant_output_modes(
response = await send_message_to_model_wrapper(relevant_mode_prompt, response_type="json_object") response = await send_message_to_model_wrapper(relevant_mode_prompt, response_type="json_object")
try: try:
response = json.loads(response.strip()) response = response.strip()
response = remove_json_codeblock(response)
response = json.loads(response)
if is_none_or_empty(response): if is_none_or_empty(response):
return ConversationCommand.Text return ConversationCommand.Text
@ -433,9 +434,7 @@ async def generate_online_subqueries(
# Validate that the response is a non-empty, JSON-serializable list # Validate that the response is a non-empty, JSON-serializable list
try: try:
response = response.strip() response = response.strip()
# Remove any markdown json codeblock formatting if present (useful for gemma-2) response = remove_json_codeblock(response)
if response.startswith("```json") and response.endswith("```"):
response = response[7:-3]
response = json.loads(response) response = json.loads(response)
response = [q.strip() for q in response["queries"] if q.strip()] response = [q.strip() for q in response["queries"] if q.strip()]
if not isinstance(response, list) or not response or len(response) == 0: if not isinstance(response, list) or not response or len(response) == 0: