mirror of
https://github.com/khoj-ai/khoj.git
synced 2025-02-20 06:55:08 +00:00
Add better error handling for diagram output, and fix chat history construct
- Make the `clean_json` method more robust as well
This commit is contained in:
parent
7bd2f83f97
commit
1cab6c081f
4 changed files with 50 additions and 13 deletions
|
@ -215,6 +215,10 @@ export function getIconForSlashCommand(command: string, customClassName: string
|
||||||
return <PencilLine className={className} />;
|
return <PencilLine className={className} />;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (command.includes("code")) {
|
||||||
|
return <Code className={className} />;
|
||||||
|
}
|
||||||
|
|
||||||
return <ArrowRight className={className} />;
|
return <ArrowRight className={className} />;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,7 @@ import math
|
||||||
import mimetypes
|
import mimetypes
|
||||||
import os
|
import os
|
||||||
import queue
|
import queue
|
||||||
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
@ -377,8 +378,10 @@ def generate_chatml_messages_with_context(
|
||||||
message_context = ""
|
message_context = ""
|
||||||
message_attached_files = ""
|
message_attached_files = ""
|
||||||
|
|
||||||
|
chat_message = chat.get("message")
|
||||||
|
|
||||||
if chat["by"] == "khoj" and "excalidraw" in chat["intent"].get("type", ""):
|
if chat["by"] == "khoj" and "excalidraw" in chat["intent"].get("type", ""):
|
||||||
message_context += chat.get("intent").get("inferred-queries")[0]
|
chat_message = chat["intent"].get("inferred-queries")[0]
|
||||||
if not is_none_or_empty(chat.get("context")):
|
if not is_none_or_empty(chat.get("context")):
|
||||||
references = "\n\n".join(
|
references = "\n\n".join(
|
||||||
{
|
{
|
||||||
|
@ -407,7 +410,7 @@ def generate_chatml_messages_with_context(
|
||||||
|
|
||||||
role = "user" if chat["by"] == "you" else "assistant"
|
role = "user" if chat["by"] == "you" else "assistant"
|
||||||
message_content = construct_structured_message(
|
message_content = construct_structured_message(
|
||||||
chat["message"], chat.get("images"), model_type, vision_enabled, attached_file_context=query_files
|
chat_message, chat.get("images"), model_type, vision_enabled, attached_file_context=query_files
|
||||||
)
|
)
|
||||||
|
|
||||||
reconstructed_message = ChatMessage(content=message_content, role=role)
|
reconstructed_message = ChatMessage(content=message_content, role=role)
|
||||||
|
@ -524,7 +527,25 @@ def reciprocal_conversation_to_chatml(message_pair):
|
||||||
|
|
||||||
def clean_json(response: str):
|
def clean_json(response: str):
|
||||||
"""Remove any markdown json codeblock and newline formatting if present. Useful for non schema enforceable models"""
|
"""Remove any markdown json codeblock and newline formatting if present. Useful for non schema enforceable models"""
|
||||||
return response.strip().replace("\n", "").removeprefix("```json").removesuffix("```")
|
try:
|
||||||
|
# Remove markdown code blocks
|
||||||
|
cleaned = response.strip().replace("```json", "").replace("```", "")
|
||||||
|
|
||||||
|
# Find JSON array/object pattern
|
||||||
|
json_match = re.search(r"\[.*\]|\{.*\}", cleaned, re.DOTALL)
|
||||||
|
if not json_match:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# Extract matched JSON
|
||||||
|
json_str = json_match.group()
|
||||||
|
|
||||||
|
# Validate by parsing
|
||||||
|
json.loads(json_str)
|
||||||
|
|
||||||
|
return json_str.strip()
|
||||||
|
|
||||||
|
except (json.JSONDecodeError, AttributeError):
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
def clean_code_python(code: str):
|
def clean_code_python(code: str):
|
||||||
|
|
|
@ -1171,8 +1171,13 @@ async def chat(
|
||||||
yield result[ChatEvent.STATUS]
|
yield result[ChatEvent.STATUS]
|
||||||
else:
|
else:
|
||||||
better_diagram_description_prompt, excalidraw_diagram_description = result
|
better_diagram_description_prompt, excalidraw_diagram_description = result
|
||||||
inferred_queries.append(better_diagram_description_prompt)
|
if better_diagram_description_prompt and excalidraw_diagram_description:
|
||||||
diagram_description = excalidraw_diagram_description
|
inferred_queries.append(better_diagram_description_prompt)
|
||||||
|
diagram_description = excalidraw_diagram_description
|
||||||
|
else:
|
||||||
|
async for result in send_llm_response(f"Failed to generate diagram. Please try again later."):
|
||||||
|
yield result
|
||||||
|
return
|
||||||
|
|
||||||
content_obj = {
|
content_obj = {
|
||||||
"intentType": intent_type,
|
"intentType": intent_type,
|
||||||
|
|
|
@ -784,13 +784,17 @@ async def generate_excalidraw_diagram(
|
||||||
if send_status_func:
|
if send_status_func:
|
||||||
async for event in send_status_func(f"**Diagram to Create:**:\n{better_diagram_description_prompt}"):
|
async for event in send_status_func(f"**Diagram to Create:**:\n{better_diagram_description_prompt}"):
|
||||||
yield {ChatEvent.STATUS: event}
|
yield {ChatEvent.STATUS: event}
|
||||||
|
try:
|
||||||
excalidraw_diagram_description = await generate_excalidraw_diagram_from_description(
|
excalidraw_diagram_description = await generate_excalidraw_diagram_from_description(
|
||||||
q=better_diagram_description_prompt,
|
q=better_diagram_description_prompt,
|
||||||
user=user,
|
user=user,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error generating Excalidraw diagram for {user.email}: {e}", exc_info=True)
|
||||||
|
yield None, None
|
||||||
|
return
|
||||||
|
|
||||||
yield better_diagram_description_prompt, excalidraw_diagram_description
|
yield better_diagram_description_prompt, excalidraw_diagram_description
|
||||||
|
|
||||||
|
@ -876,7 +880,10 @@ async def generate_excalidraw_diagram_from_description(
|
||||||
query=excalidraw_diagram_generation, user=user, tracer=tracer
|
query=excalidraw_diagram_generation, user=user, tracer=tracer
|
||||||
)
|
)
|
||||||
raw_response = clean_json(raw_response)
|
raw_response = clean_json(raw_response)
|
||||||
response: Dict[str, str] = json.loads(raw_response)
|
try:
|
||||||
|
response: Dict[str, str] = json.loads(raw_response)
|
||||||
|
except Exception:
|
||||||
|
raise AssertionError(f"Invalid response for generating Excalidraw diagram: {raw_response}")
|
||||||
if not response or not isinstance(response, List) or not isinstance(response[0], Dict):
|
if not response or not isinstance(response, List) or not isinstance(response[0], Dict):
|
||||||
# TODO Some additional validation here that it's a valid Excalidraw diagram
|
# TODO Some additional validation here that it's a valid Excalidraw diagram
|
||||||
raise AssertionError(f"Invalid response for improving diagram description: {response}")
|
raise AssertionError(f"Invalid response for improving diagram description: {response}")
|
||||||
|
|
Loading…
Add table
Reference in a new issue