From 0685a79748d0b6572b870c66bf7d12863e2cdff3 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Thu, 12 Sep 2024 16:41:40 -0700 Subject: [PATCH] Remove any markdown json codeblock in chat actors expecting json responses Strip any json md codeblock wrapper if exists before processing response by output mode, extract questions chat actor. This is similar to what is already being done by other chat actors Useful for succesfully interpreting json output in chat actors when using non (json) schema enforceable models like o1 and gemma-2 Use conversation helper function to centralize the json md codeblock removal code --- src/khoj/processor/conversation/openai/gpt.py | 2 ++ src/khoj/processor/conversation/utils.py | 7 +++++++ src/khoj/routers/helpers.py | 13 ++++++------- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index c6f744fa..90cd4df9 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -14,6 +14,7 @@ from khoj.processor.conversation.openai.utils import ( from khoj.processor.conversation.utils import ( construct_structured_message, generate_chatml_messages_with_context, + remove_json_codeblock, ) from khoj.utils.helpers import ConversationCommand, is_none_or_empty from khoj.utils.rawconfig import LocationData @@ -85,6 +86,7 @@ def extract_questions( # Extract, Clean Message from GPT's Response try: response = response.strip() + response = remove_json_codeblock(response) response = json.loads(response) response = [q.strip() for q in response["queries"] if q.strip()] if not isinstance(response, list) or not response: diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 6444b14d..03bd17a3 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -289,3 +289,10 @@ def truncate_messages( def reciprocal_conversation_to_chatml(message_pair): """Convert a single back and forth between user and assistant to chatml format""" return [ChatMessage(content=message, role=role) for message, role in zip(message_pair, ["user", "assistant"])] + + +def remove_json_codeblock(response): + """Remove any markdown json codeblock formatting if present. Useful for non schema enforceable models""" + if response.startswith("```json") and response.endswith("```"): + response = response[7:-3] + return response diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 5687937a..f1b8ddd6 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -88,6 +88,7 @@ from khoj.processor.conversation.openai.gpt import converse, send_message_to_mod from khoj.processor.conversation.utils import ( ThreadedGenerator, generate_chatml_messages_with_context, + remove_json_codeblock, save_to_conversation_log, ) from khoj.processor.speech.text_to_speech import is_eleven_labs_enabled @@ -298,9 +299,7 @@ async def aget_relevant_information_sources( try: response = response.strip() - # Remove any markdown json codeblock formatting if present (useful for gemma-2) - if response.startswith("```json"): - response = response[7:-3] + response = remove_json_codeblock(response) response = json.loads(response) response = [q.strip() for q in response["source"] if q.strip()] if not isinstance(response, list) or not response or len(response) == 0: @@ -353,7 +352,9 @@ async def aget_relevant_output_modes( response = await send_message_to_model_wrapper(relevant_mode_prompt, response_type="json_object") try: - response = json.loads(response.strip()) + response = response.strip() + response = remove_json_codeblock(response) + response = json.loads(response) if is_none_or_empty(response): return ConversationCommand.Text @@ -433,9 +434,7 @@ async def generate_online_subqueries( # Validate that the response is a non-empty, JSON-serializable list try: response = response.strip() - # Remove any markdown json codeblock formatting if present (useful for gemma-2) - if response.startswith("```json") and response.endswith("```"): - response = response[7:-3] + response = remove_json_codeblock(response) response = json.loads(response) response = [q.strip() for q in response["queries"] if q.strip()] if not isinstance(response, list) or not response or len(response) == 0: