Handle \n; dedupe JSON cleaning into a single function for reusability

Use a placeholder for newlines in JSON object values until the JSON is parsed
and the values are extracted. This is useful when research-mode models output
multi-line codeblocks in queries, etc.
This commit is contained in:
Debanjum 2024-10-29 11:47:11 -07:00
parent 83ca820abe
commit 86ffd7a7a2
7 changed files with 19 additions and 27 deletions

View file

@ -14,9 +14,9 @@ from khoj.processor.conversation.anthropic.utils import (
format_messages_for_anthropic, format_messages_for_anthropic,
) )
from khoj.processor.conversation.utils import ( from khoj.processor.conversation.utils import (
clean_json,
construct_structured_message, construct_structured_message,
generate_chatml_messages_with_context, generate_chatml_messages_with_context,
remove_json_codeblock,
) )
from khoj.utils.helpers import ConversationCommand, is_none_or_empty from khoj.utils.helpers import ConversationCommand, is_none_or_empty
from khoj.utils.rawconfig import LocationData from khoj.utils.rawconfig import LocationData
@ -98,7 +98,7 @@ def extract_questions_anthropic(
# Extract, Clean Message from Claude's Response # Extract, Clean Message from Claude's Response
try: try:
response = remove_json_codeblock(response) response = clean_json(response)
response = json.loads(response) response = json.loads(response)
response = [q.strip() for q in response["queries"] if q.strip()] response = [q.strip() for q in response["queries"] if q.strip()]
if not isinstance(response, list) or not response: if not isinstance(response, list) or not response:

View file

@ -14,9 +14,9 @@ from khoj.processor.conversation.google.utils import (
gemini_completion_with_backoff, gemini_completion_with_backoff,
) )
from khoj.processor.conversation.utils import ( from khoj.processor.conversation.utils import (
clean_json,
construct_structured_message, construct_structured_message,
generate_chatml_messages_with_context, generate_chatml_messages_with_context,
remove_json_codeblock,
) )
from khoj.utils.helpers import ConversationCommand, is_none_or_empty from khoj.utils.helpers import ConversationCommand, is_none_or_empty
from khoj.utils.rawconfig import LocationData from khoj.utils.rawconfig import LocationData
@ -93,7 +93,7 @@ def extract_questions_gemini(
# Extract, Clean Message from Gemini's Response # Extract, Clean Message from Gemini's Response
try: try:
response = remove_json_codeblock(response) response = clean_json(response)
response = json.loads(response) response = json.loads(response)
response = [q.strip() for q in response["queries"] if q.strip()] response = [q.strip() for q in response["queries"] if q.strip()]
if not isinstance(response, list) or not response: if not isinstance(response, list) or not response:

View file

@ -12,9 +12,9 @@ from khoj.processor.conversation.openai.utils import (
completion_with_backoff, completion_with_backoff,
) )
from khoj.processor.conversation.utils import ( from khoj.processor.conversation.utils import (
clean_json,
construct_structured_message, construct_structured_message,
generate_chatml_messages_with_context, generate_chatml_messages_with_context,
remove_json_codeblock,
) )
from khoj.utils.helpers import ConversationCommand, is_none_or_empty from khoj.utils.helpers import ConversationCommand, is_none_or_empty
from khoj.utils.rawconfig import LocationData from khoj.utils.rawconfig import LocationData
@ -95,8 +95,7 @@ def extract_questions(
# Extract, Clean Message from GPT's Response # Extract, Clean Message from GPT's Response
try: try:
response = response.strip() response = clean_json(response)
response = remove_json_codeblock(response)
response = json.loads(response) response = json.loads(response)
response = [q.strip() for q in response["queries"] if q.strip()] response = [q.strip() for q in response["queries"] if q.strip()]
if not isinstance(response, list) or not response: if not isinstance(response, list) or not response:

View file

@ -442,9 +442,9 @@ def reciprocal_conversation_to_chatml(message_pair):
return [ChatMessage(content=message, role=role) for message, role in zip(message_pair, ["user", "assistant"])] return [ChatMessage(content=message, role=role) for message, role in zip(message_pair, ["user", "assistant"])]
# Previous helper (left-hand side of this hunk; replaced by `clean_json`).
def remove_json_codeblock(response: str) -> str:
    """Remove any markdown json codeblock formatting if present. Useful for non schema enforceable models

    Only an exact leading "```json" and an exact trailing "```" are dropped.
    The input is not stripped first, so surrounding whitespace or newlines
    prevent the fence from being matched.
    """
    return response.removeprefix("```json").removesuffix("```")


# Replacement helper (right-hand side of this hunk).
def clean_json(response: str) -> str:
    """Remove any markdown json codeblock and newline formatting if present. Useful for non schema enforceable models

    Steps, in order:
      1. Strip leading/trailing whitespace.
      2. Delete every newline character. This includes newlines that sit
         inside string values, so text on either side of them is joined.
      3. Drop a leading "```json" fence and a trailing "```" fence.

    Newlines are removed before the fences so that a fence written on its own
    line is still matched. A bare "```" opening fence (without the "json" tag)
    is left in place.

    Args:
        response: Raw model output expected to contain a JSON document.

    Returns:
        The cleaned string. It is not parsed or validated here.
    """
    without_newlines = response.strip().replace("\n", "")
    return without_newlines.removeprefix("```json").removesuffix("```")
def defilter_query(query: str): def defilter_query(query: str):

View file

@ -12,8 +12,8 @@ from khoj.database.models import Agent, KhojUser
from khoj.processor.conversation import prompts from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import ( from khoj.processor.conversation.utils import (
ChatEvent, ChatEvent,
clean_json,
construct_chat_history, construct_chat_history,
remove_json_codeblock,
) )
from khoj.routers.helpers import send_message_to_model_wrapper from khoj.routers.helpers import send_message_to_model_wrapper
from khoj.utils.helpers import timer from khoj.utils.helpers import timer
@ -111,8 +111,7 @@ async def generate_python_code(
) )
# Validate that the response is a non-empty, JSON-serializable list # Validate that the response is a non-empty, JSON-serializable list
response = response.strip() response = clean_json(response)
response = remove_json_codeblock(response)
response = json.loads(response) response = json.loads(response)
codes = [code.strip() for code in response["codes"] if code.strip()] codes = [code.strip() for code in response["codes"] if code.strip()]

View file

@ -90,9 +90,9 @@ from khoj.processor.conversation.openai.gpt import converse, send_message_to_mod
from khoj.processor.conversation.utils import ( from khoj.processor.conversation.utils import (
ChatEvent, ChatEvent,
ThreadedGenerator, ThreadedGenerator,
clean_json,
construct_chat_history, construct_chat_history,
generate_chatml_messages_with_context, generate_chatml_messages_with_context,
remove_json_codeblock,
save_to_conversation_log, save_to_conversation_log,
) )
from khoj.processor.speech.text_to_speech import is_eleven_labs_enabled from khoj.processor.speech.text_to_speech import is_eleven_labs_enabled
@ -334,8 +334,7 @@ async def aget_relevant_information_sources(
) )
try: try:
response = response.strip() response = clean_json(response)
response = remove_json_codeblock(response)
response = json.loads(response) response = json.loads(response)
response = [q.strip() for q in response["source"] if q.strip()] response = [q.strip() for q in response["source"] if q.strip()]
if not isinstance(response, list) or not response or len(response) == 0: if not isinstance(response, list) or not response or len(response) == 0:
@ -413,8 +412,7 @@ async def aget_relevant_output_modes(
) )
try: try:
response = response.strip() response = clean_json(response)
response = remove_json_codeblock(response)
response = json.loads(response) response = json.loads(response)
if is_none_or_empty(response): if is_none_or_empty(response):
@ -475,8 +473,7 @@ async def infer_webpage_urls(
# Validate that the response is a non-empty, JSON-serializable list of URLs # Validate that the response is a non-empty, JSON-serializable list of URLs
try: try:
response = response.strip() response = clean_json(response)
response = remove_json_codeblock(response)
urls = json.loads(response) urls = json.loads(response)
valid_unique_urls = {str(url).strip() for url in urls["links"] if is_valid_url(url)} valid_unique_urls = {str(url).strip() for url in urls["links"] if is_valid_url(url)}
if is_none_or_empty(valid_unique_urls): if is_none_or_empty(valid_unique_urls):
@ -527,8 +524,7 @@ async def generate_online_subqueries(
# Validate that the response is a non-empty, JSON-serializable list # Validate that the response is a non-empty, JSON-serializable list
try: try:
response = response.strip() response = clean_json(response)
response = remove_json_codeblock(response)
response = json.loads(response) response = json.loads(response)
response = [q.strip() for q in response["queries"] if q.strip()] response = [q.strip() for q in response["queries"] if q.strip()]
if not isinstance(response, list) or not response or len(response) == 0: if not isinstance(response, list) or not response or len(response) == 0:
@ -801,8 +797,7 @@ async def generate_excalidraw_diagram_from_description(
raw_response = await send_message_to_model_wrapper( raw_response = await send_message_to_model_wrapper(
query=excalidraw_diagram_generation, user=user, tracer=tracer query=excalidraw_diagram_generation, user=user, tracer=tracer
) )
raw_response = raw_response.strip() raw_response = clean_json(raw_response)
raw_response = remove_json_codeblock(raw_response)
response: Dict[str, str] = json.loads(raw_response) response: Dict[str, str] = json.loads(raw_response)
if not response or not isinstance(response, List) or not isinstance(response[0], Dict): if not response or not isinstance(response, List) or not isinstance(response[0], Dict):
# TODO Some additional validation here that it's a valid Excalidraw diagram # TODO Some additional validation here that it's a valid Excalidraw diagram

View file

@ -11,9 +11,9 @@ from khoj.database.models import Agent, KhojUser
from khoj.processor.conversation import prompts from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import ( from khoj.processor.conversation.utils import (
InformationCollectionIteration, InformationCollectionIteration,
clean_json,
construct_iteration_history, construct_iteration_history,
construct_tool_chat_history, construct_tool_chat_history,
remove_json_codeblock,
) )
from khoj.processor.tools.online_search import read_webpages, search_online from khoj.processor.tools.online_search import read_webpages, search_online
from khoj.processor.tools.run_code import run_code from khoj.processor.tools.run_code import run_code
@ -99,8 +99,7 @@ async def apick_next_tool(
) )
try: try:
response = response.strip() response = clean_json(response)
response = remove_json_codeblock(response)
response = json.loads(response) response = json.loads(response)
selected_tool = response.get("tool", None) selected_tool = response.get("tool", None)
generated_query = response.get("query", None) generated_query = response.get("query", None)