Add better error handling for diagram output, and fix chat history construct

- Make the `clean_json` method more robust as well
This commit is contained in:
sabaimran 2024-11-11 20:44:19 -08:00
parent 7bd2f83f97
commit 1cab6c081f
4 changed files with 50 additions and 13 deletions

View file

@ -215,6 +215,10 @@ export function getIconForSlashCommand(command: string, customClassName: string
return <PencilLine className={className} />; return <PencilLine className={className} />;
} }
if (command.includes("code")) {
return <Code className={className} />;
}
return <ArrowRight className={className} />; return <ArrowRight className={className} />;
} }

View file

@ -5,6 +5,7 @@ import math
import mimetypes import mimetypes
import os import os
import queue import queue
import re
import uuid import uuid
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime from datetime import datetime
@ -377,8 +378,10 @@ def generate_chatml_messages_with_context(
message_context = "" message_context = ""
message_attached_files = "" message_attached_files = ""
chat_message = chat.get("message")
if chat["by"] == "khoj" and "excalidraw" in chat["intent"].get("type", ""): if chat["by"] == "khoj" and "excalidraw" in chat["intent"].get("type", ""):
message_context += chat.get("intent").get("inferred-queries")[0] chat_message = chat["intent"].get("inferred-queries")[0]
if not is_none_or_empty(chat.get("context")): if not is_none_or_empty(chat.get("context")):
references = "\n\n".join( references = "\n\n".join(
{ {
@ -407,7 +410,7 @@ def generate_chatml_messages_with_context(
role = "user" if chat["by"] == "you" else "assistant" role = "user" if chat["by"] == "you" else "assistant"
message_content = construct_structured_message( message_content = construct_structured_message(
chat["message"], chat.get("images"), model_type, vision_enabled, attached_file_context=query_files chat_message, chat.get("images"), model_type, vision_enabled, attached_file_context=query_files
) )
reconstructed_message = ChatMessage(content=message_content, role=role) reconstructed_message = ChatMessage(content=message_content, role=role)
@ -524,7 +527,25 @@ def reciprocal_conversation_to_chatml(message_pair):
def clean_json(response: str): def clean_json(response: str):
"""Remove any markdown json codeblock and newline formatting if present. Useful for non schema enforceable models""" """Remove any markdown json codeblock and newline formatting if present. Useful for non schema enforceable models"""
return response.strip().replace("\n", "").removeprefix("```json").removesuffix("```") try:
# Remove markdown code blocks
cleaned = response.strip().replace("```json", "").replace("```", "")
# Find JSON array/object pattern
json_match = re.search(r"\[.*\]|\{.*\}", cleaned, re.DOTALL)
if not json_match:
return ""
# Extract matched JSON
json_str = json_match.group()
# Validate by parsing
json.loads(json_str)
return json_str.strip()
except (json.JSONDecodeError, AttributeError):
return ""
def clean_code_python(code: str): def clean_code_python(code: str):

View file

@ -1171,8 +1171,13 @@ async def chat(
yield result[ChatEvent.STATUS] yield result[ChatEvent.STATUS]
else: else:
better_diagram_description_prompt, excalidraw_diagram_description = result better_diagram_description_prompt, excalidraw_diagram_description = result
inferred_queries.append(better_diagram_description_prompt) if better_diagram_description_prompt and excalidraw_diagram_description:
diagram_description = excalidraw_diagram_description inferred_queries.append(better_diagram_description_prompt)
diagram_description = excalidraw_diagram_description
else:
async for result in send_llm_response(f"Failed to generate diagram. Please try again later."):
yield result
return
content_obj = { content_obj = {
"intentType": intent_type, "intentType": intent_type,

View file

@ -784,13 +784,17 @@ async def generate_excalidraw_diagram(
if send_status_func: if send_status_func:
async for event in send_status_func(f"**Diagram to Create:**:\n{better_diagram_description_prompt}"): async for event in send_status_func(f"**Diagram to Create:**:\n{better_diagram_description_prompt}"):
yield {ChatEvent.STATUS: event} yield {ChatEvent.STATUS: event}
try:
excalidraw_diagram_description = await generate_excalidraw_diagram_from_description( excalidraw_diagram_description = await generate_excalidraw_diagram_from_description(
q=better_diagram_description_prompt, q=better_diagram_description_prompt,
user=user, user=user,
agent=agent, agent=agent,
tracer=tracer, tracer=tracer,
) )
except Exception as e:
logger.error(f"Error generating Excalidraw diagram for {user.email}: {e}", exc_info=True)
yield None, None
return
yield better_diagram_description_prompt, excalidraw_diagram_description yield better_diagram_description_prompt, excalidraw_diagram_description
@ -876,7 +880,10 @@ async def generate_excalidraw_diagram_from_description(
query=excalidraw_diagram_generation, user=user, tracer=tracer query=excalidraw_diagram_generation, user=user, tracer=tracer
) )
raw_response = clean_json(raw_response) raw_response = clean_json(raw_response)
response: Dict[str, str] = json.loads(raw_response) try:
response: Dict[str, str] = json.loads(raw_response)
except Exception:
raise AssertionError(f"Invalid response for generating Excalidraw diagram: {raw_response}")
if not response or not isinstance(response, List) or not isinstance(response[0], Dict): if not response or not isinstance(response, List) or not isinstance(response[0], Dict):
# TODO Some additional validation here that it's a valid Excalidraw diagram # TODO Some additional validation here that it's a valid Excalidraw diagram
raise AssertionError(f"Invalid response for improving diagram description: {response}") raise AssertionError(f"Invalid response for improving diagram description: {response}")