diff --git a/src/interface/web/app/common/chatFunctions.ts b/src/interface/web/app/common/chatFunctions.ts
index 6585b4c9..c64e81ba 100644
--- a/src/interface/web/app/common/chatFunctions.ts
+++ b/src/interface/web/app/common/chatFunctions.ts
@@ -1,3 +1,4 @@
+import { AttachedFileText } from "../components/chatInputArea/chatInputArea";
import {
CodeContext,
Context,
@@ -16,6 +17,12 @@ export interface MessageMetadata {
turnId: string;
}
+export interface GeneratedAssetsData {
+ images?: string[];
+ excalidrawDiagram?: string;
+ files?: AttachedFileText[];
+}
+
export interface ResponseWithIntent {
intentType: string;
response: string;
@@ -98,6 +107,20 @@ export function processMessageChunk(
} else if (chunk.type === "metadata") {
const messageMetadata = chunk.data as MessageMetadata;
currentMessage.turnId = messageMetadata.turnId;
+ } else if (chunk.type === "generated_assets") {
+ const generatedAssets = chunk.data as GeneratedAssetsData;
+
+ if (generatedAssets.images) {
+ currentMessage.generatedImages = generatedAssets.images;
+ }
+
+ if (generatedAssets.excalidrawDiagram) {
+ currentMessage.generatedExcalidrawDiagram = generatedAssets.excalidrawDiagram;
+ }
+
+ if (generatedAssets.files) {
+ currentMessage.generatedFiles = generatedAssets.files;
+ }
} else if (chunk.type === "message") {
const chunkData = chunk.data;
// Here, handle if the response is a JSON response with an image, but the intentType is excalidraw
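A quick sketch of the event this handler consumes: the server (see `api_chat.py` below) emits a `generated_assets` event whose payload carries any subset of `images`, `excalidrawDiagram`, and `files`, which is why the client guards each key before copying it onto the current message. The helper name below is hypothetical, for illustration only; the `{"type": ..., "data": ...}` framing matches the chunks handled elsewhere in this diff.

```python
# Minimal sketch of the server-side generated_assets payload shape.
import json
from typing import List, Optional


def generated_assets_chunk(
    images: Optional[List[str]] = None,        # image URLs or b64-encoded strings
    excalidraw_diagram: Optional[str] = None,  # serialized diagram description
    files: Optional[List[dict]] = None,        # FileAttachment dicts
) -> str:
    data = {
        "images": images,
        "excalidrawDiagram": excalidraw_diagram,
        "files": files,
    }
    # Only include the assets that were actually generated this turn
    data = {key: value for key, value in data.items() if value is not None}
    return json.dumps({"type": "generated_assets", "data": data})


print(generated_assets_chunk(images=["https://example.com/painting.webp"]))
```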
diff --git a/src/interface/web/app/components/chatHistory/chatHistory.tsx b/src/interface/web/app/components/chatHistory/chatHistory.tsx
index ea566df4..c5fd8b54 100644
--- a/src/interface/web/app/components/chatHistory/chatHistory.tsx
+++ b/src/interface/web/app/components/chatHistory/chatHistory.tsx
@@ -54,6 +54,12 @@ function TrainOfThoughtComponent(props: TrainOfThoughtComponentProps) {
const lastIndex = props.trainOfThought.length - 1;
const [collapsed, setCollapsed] = useState(props.completed);
+ // useEffect(() => {
+ //     if (props.completed) {
+ //         setCollapsed(true);
+ //     }
+ // }, [props.completed]);
+
return (
void;
conversationId: string;
turnId?: string;
+ generatedImage?: string;
+ excalidrawDiagram?: string;
+ generatedFiles?: AttachedFileText[];
}
interface TrainOfThoughtProps {
@@ -394,6 +402,10 @@ const ChatMessage = forwardRef((props, ref) =>
setExcalidrawData(props.chatMessage.message);
}
+ if (props.chatMessage.excalidrawDiagram) {
+ setExcalidrawData(props.chatMessage.excalidrawDiagram);
+ }
+
// Replace LaTeX delimiters with placeholders
message = message
.replace(/\\\(/g, "LEFTPAREN")
diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py
index 8f6c5c78..ae6a729d 100644
--- a/src/khoj/database/models/__init__.py
+++ b/src/khoj/database/models/__init__.py
@@ -1,7 +1,9 @@
+import logging
import os
import re
import uuid
from random import choice
+from typing import Dict, List, Optional, Union
from django.contrib.auth.models import AbstractUser
from django.contrib.postgres.fields import ArrayField
@@ -11,9 +13,109 @@ from django.db.models.signals import pre_save
from django.dispatch import receiver
from pgvector.django import VectorField
from phonenumber_field.modelfields import PhoneNumberField
+from pydantic import BaseModel as PydanticBaseModel
+from pydantic import Field
+
+logger = logging.getLogger(__name__)
-class BaseModel(models.Model):
+# Pydantic models for validating chat messages
+class Context(PydanticBaseModel):
+ compiled: str
+ file: str
+
+
+class CodeContextFile(PydanticBaseModel):
+ filename: str
+ b64_data: str
+
+
+class CodeContextResult(PydanticBaseModel):
+ success: bool
+ output_files: List[CodeContextFile]
+ std_out: str
+ std_err: str
+ code_runtime: int
+
+
+class CodeContextData(PydanticBaseModel):
+ code: str
+ result: CodeContextResult
+
+
+class WebPage(PydanticBaseModel):
+ link: str
+ query: Optional[str] = None
+ snippet: str
+
+
+class AnswerBox(PydanticBaseModel):
+ link: str
+ snippet: str
+ title: str
+ snippetHighlighted: List[str]
+
+
+class PeopleAlsoAsk(PydanticBaseModel):
+ link: str
+ question: str
+ snippet: str
+ title: str
+
+
+class KnowledgeGraph(PydanticBaseModel):
+ attributes: Dict[str, str]
+ description: str
+ descriptionLink: str
+ descriptionSource: str
+ imageUrl: str
+ title: str
+ type: str
+
+
+class OrganicContext(PydanticBaseModel):
+ snippet: str
+ title: str
+ link: str
+
+
+class OnlineContext(PydanticBaseModel):
+ webpages: Optional[Union[WebPage, List[WebPage]]] = None
+ answerBox: Optional[AnswerBox] = None
+ peopleAlsoAsk: Optional[List[PeopleAlsoAsk]] = None
+ knowledgeGraph: Optional[KnowledgeGraph] = None
+ organicContext: Optional[List[OrganicContext]] = None
+
+
+class Intent(PydanticBaseModel):
+ type: str
+ query: str
+ memory_type: str = Field(alias="memory-type")
+ inferred_queries: Optional[List[str]] = Field(default=None, alias="inferred-queries")
+
+
+class TrainOfThought(PydanticBaseModel):
+ type: str
+ data: str
+
+
+class ChatMessage(PydanticBaseModel):
+ message: str
+ trainOfThought: List[TrainOfThought] = []
+ context: List[Context] = []
+ onlineContext: Dict[str, OnlineContext] = {}
+ codeContext: Dict[str, CodeContextData] = {}
+ created: str
+ images: Optional[List[str]] = None
+ queryFiles: Optional[List[Dict]] = None
+ excalidrawDiagram: Optional[str] = None
+ by: str
+ turnId: Optional[str] = None
+ intent: Optional[Intent] = None
+ automationId: Optional[str] = None
+
+
+class DbBaseModel(models.Model):
created_at = models.DateTimeField(auto_now_add=True)
updated_at = models.DateTimeField(auto_now=True)
@@ -21,7 +123,7 @@ class BaseModel(models.Model):
abstract = True
-class ClientApplication(BaseModel):
+class ClientApplication(DbBaseModel):
name = models.CharField(max_length=200)
client_id = models.CharField(max_length=200)
client_secret = models.CharField(max_length=200)
@@ -67,7 +169,7 @@ class KhojApiUser(models.Model):
accessed_at = models.DateTimeField(null=True, default=None)
-class Subscription(BaseModel):
+class Subscription(DbBaseModel):
class Type(models.TextChoices):
TRIAL = "trial"
STANDARD = "standard"
@@ -79,13 +181,13 @@ class Subscription(BaseModel):
enabled_trial_at = models.DateTimeField(null=True, default=None, blank=True)
-class OpenAIProcessorConversationConfig(BaseModel):
+class OpenAIProcessorConversationConfig(DbBaseModel):
name = models.CharField(max_length=200)
api_key = models.CharField(max_length=200)
api_base_url = models.URLField(max_length=200, default=None, blank=True, null=True)
-class ChatModelOptions(BaseModel):
+class ChatModelOptions(DbBaseModel):
class ModelType(models.TextChoices):
OPENAI = "openai"
OFFLINE = "offline"
@@ -103,12 +205,12 @@ class ChatModelOptions(BaseModel):
)
-class VoiceModelOption(BaseModel):
+class VoiceModelOption(DbBaseModel):
model_id = models.CharField(max_length=200)
name = models.CharField(max_length=200)
-class Agent(BaseModel):
+class Agent(DbBaseModel):
class StyleColorTypes(models.TextChoices):
BLUE = "blue"
GREEN = "green"
@@ -208,7 +310,7 @@ class Agent(BaseModel):
super().save(*args, **kwargs)
-class ProcessLock(BaseModel):
+class ProcessLock(DbBaseModel):
class Operation(models.TextChoices):
INDEX_CONTENT = "index_content"
SCHEDULED_JOB = "scheduled_job"
@@ -231,24 +333,24 @@ def verify_agent(sender, instance, **kwargs):
raise ValidationError(f"A private Agent with the name {instance.name} already exists.")
-class NotionConfig(BaseModel):
+class NotionConfig(DbBaseModel):
token = models.CharField(max_length=200)
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
-class GithubConfig(BaseModel):
+class GithubConfig(DbBaseModel):
pat_token = models.CharField(max_length=200)
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
-class GithubRepoConfig(BaseModel):
+class GithubRepoConfig(DbBaseModel):
name = models.CharField(max_length=200)
owner = models.CharField(max_length=200)
branch = models.CharField(max_length=200)
github_config = models.ForeignKey(GithubConfig, on_delete=models.CASCADE, related_name="githubrepoconfig")
-class WebScraper(BaseModel):
+class WebScraper(DbBaseModel):
class WebScraperType(models.TextChoices):
FIRECRAWL = "Firecrawl"
OLOSTEP = "Olostep"
@@ -321,7 +423,7 @@ class WebScraper(BaseModel):
super().save(*args, **kwargs)
-class ServerChatSettings(BaseModel):
+class ServerChatSettings(DbBaseModel):
chat_default = models.ForeignKey(
ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_default"
)
@@ -333,35 +435,35 @@ class ServerChatSettings(BaseModel):
)
-class LocalOrgConfig(BaseModel):
+class LocalOrgConfig(DbBaseModel):
input_files = models.JSONField(default=list, null=True)
input_filter = models.JSONField(default=list, null=True)
index_heading_entries = models.BooleanField(default=False)
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
-class LocalMarkdownConfig(BaseModel):
+class LocalMarkdownConfig(DbBaseModel):
input_files = models.JSONField(default=list, null=True)
input_filter = models.JSONField(default=list, null=True)
index_heading_entries = models.BooleanField(default=False)
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
-class LocalPdfConfig(BaseModel):
+class LocalPdfConfig(DbBaseModel):
input_files = models.JSONField(default=list, null=True)
input_filter = models.JSONField(default=list, null=True)
index_heading_entries = models.BooleanField(default=False)
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
-class LocalPlaintextConfig(BaseModel):
+class LocalPlaintextConfig(DbBaseModel):
input_files = models.JSONField(default=list, null=True)
input_filter = models.JSONField(default=list, null=True)
index_heading_entries = models.BooleanField(default=False)
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
-class SearchModelConfig(BaseModel):
+class SearchModelConfig(DbBaseModel):
class ModelType(models.TextChoices):
TEXT = "text"
@@ -393,7 +495,7 @@ class SearchModelConfig(BaseModel):
bi_encoder_confidence_threshold = models.FloatField(default=0.18)
-class TextToImageModelConfig(BaseModel):
+class TextToImageModelConfig(DbBaseModel):
class ModelType(models.TextChoices):
OPENAI = "openai"
STABILITYAI = "stability-ai"
@@ -430,7 +532,7 @@ class TextToImageModelConfig(BaseModel):
super().save(*args, **kwargs)
-class SpeechToTextModelOptions(BaseModel):
+class SpeechToTextModelOptions(DbBaseModel):
class ModelType(models.TextChoices):
OPENAI = "openai"
OFFLINE = "offline"
@@ -439,22 +541,22 @@ class SpeechToTextModelOptions(BaseModel):
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
-class UserConversationConfig(BaseModel):
+class UserConversationConfig(DbBaseModel):
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
setting = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True)
-class UserVoiceModelConfig(BaseModel):
+class UserVoiceModelConfig(DbBaseModel):
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
setting = models.ForeignKey(VoiceModelOption, on_delete=models.CASCADE, default=None, null=True, blank=True)
-class UserTextToImageModelConfig(BaseModel):
+class UserTextToImageModelConfig(DbBaseModel):
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
setting = models.ForeignKey(TextToImageModelConfig, on_delete=models.CASCADE)
-class Conversation(BaseModel):
+class Conversation(DbBaseModel):
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
conversation_log = models.JSONField(default=dict)
client = models.ForeignKey(ClientApplication, on_delete=models.CASCADE, default=None, null=True, blank=True)
@@ -468,8 +570,39 @@ class Conversation(BaseModel):
file_filters = models.JSONField(default=list)
id = models.UUIDField(default=uuid.uuid4, editable=False, unique=True, primary_key=True, db_index=True)
+ def clean(self):
+ # Validate conversation_log structure
+ try:
+ messages = self.conversation_log.get("chat", [])
+ for msg in messages:
+ ChatMessage.model_validate(msg)
+ except Exception as e:
+ raise ValidationError(f"Invalid conversation_log format: {e}")
-class PublicConversation(BaseModel):
+ def save(self, *args, **kwargs):
+ self.clean()
+ super().save(*args, **kwargs)
+
+ @property
+ def messages(self) -> List[ChatMessage]:
+ """Type-hinted accessor for conversation messages"""
+ validated_messages = []
+ for msg in self.conversation_log.get("chat", []):
+ try:
+ # Clean up inferred queries if they contain None
+ if msg.get("intent") and msg["intent"].get("inferred-queries"):
+ msg["intent"]["inferred-queries"] = [
+ q for q in msg["intent"]["inferred-queries"] if isinstance(q, str)
+ ]
+ msg["message"] = str(msg.get("message", ""))
+ validated_messages.append(ChatMessage.model_validate(msg))
+ except Exception as e:
+ logger.warning(f"Skipping invalid message in conversation: {e}")
+ continue
+ return validated_messages
+
+
+class PublicConversation(DbBaseModel):
source_owner = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
conversation_log = models.JSONField(default=dict)
slug = models.CharField(max_length=200, default=None, null=True, blank=True)
@@ -499,12 +632,12 @@ def verify_public_conversation(sender, instance, **kwargs):
instance.slug = slug
-class ReflectiveQuestion(BaseModel):
+class ReflectiveQuestion(DbBaseModel):
question = models.CharField(max_length=500)
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True)
-class Entry(BaseModel):
+class Entry(DbBaseModel):
class EntryType(models.TextChoices):
IMAGE = "image"
PDF = "pdf"
@@ -541,7 +674,7 @@ class Entry(BaseModel):
raise ValidationError("An Entry cannot be associated with both a user and an agent.")
-class FileObject(BaseModel):
+class FileObject(DbBaseModel):
# Same as Entry but raw will be a much larger string
file_name = models.CharField(max_length=400, default=None, null=True, blank=True)
raw_text = models.TextField()
@@ -549,7 +682,7 @@ class FileObject(BaseModel):
agent = models.ForeignKey(Agent, on_delete=models.CASCADE, default=None, null=True, blank=True)
-class EntryDates(BaseModel):
+class EntryDates(DbBaseModel):
date = models.DateField()
entry = models.ForeignKey(Entry, on_delete=models.CASCADE, related_name="embeddings_dates")
@@ -559,12 +692,12 @@ class EntryDates(BaseModel):
]
-class UserRequests(BaseModel):
+class UserRequests(DbBaseModel):
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
slug = models.CharField(max_length=200)
-class DataStore(BaseModel):
+class DataStore(DbBaseModel):
key = models.CharField(max_length=200, unique=True)
value = models.JSONField(default=dict)
private = models.BooleanField(default=False)
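The new `Conversation.messages` property validates each stored dict against the Pydantic `ChatMessage` schema and skips entries that do not conform. A minimal sketch of that round trip, assuming the patched `khoj.database.models` module is importable (Django settings configured); note how the `Field` aliases map the hyphenated JSON keys onto snake_case attributes:

```python
# Sketch: validating one raw conversation_log entry with the new schema.
from khoj.database.models import ChatMessage  # assumes the patched module

raw_message = {
    "message": "Here is your itinerary.",
    "by": "khoj",
    "created": "2024-11-18 12:00:00",
    "turnId": "abc-123",
    "intent": {
        "type": "remember",
        "query": "plan a weekend trip",
        "memory-type": "notes",  # populates memory_type via its alias
        "inferred-queries": ["weekend trip itinerary"],
    },
}

validated = ChatMessage.model_validate(raw_message)
print(validated.intent.memory_type)  # -> "notes"
# Dump with aliases to recover the original hyphenated keys
print(validated.model_dump(by_alias=True)["intent"]["inferred-queries"])
```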
diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py
index 65e28d21..a4dc22ed 100644
--- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py
+++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py
@@ -2,7 +2,7 @@ import json
import logging
import re
from datetime import datetime, timedelta
-from typing import Dict, Optional
+from typing import Dict, List, Optional
from langchain.schema import ChatMessage
@@ -158,6 +158,10 @@ def converse_anthropic(
query_images: Optional[list[str]] = None,
vision_available: bool = False,
query_files: str = None,
+ generated_images: Optional[list[str]] = None,
+ generated_files: List[str] = None,
+ generated_excalidraw_diagram: Optional[str] = None,
+ additional_context: Optional[List[str]] = None,
tracer: dict = {},
):
"""
@@ -218,6 +222,10 @@ def converse_anthropic(
vision_enabled=vision_available,
model_type=ChatModelOptions.ModelType.ANTHROPIC,
query_files=query_files,
+ generated_excalidraw_diagram=generated_excalidraw_diagram,
+ generated_files=generated_files,
+ generated_images=generated_images,
+ additional_program_context=additional_context,
)
messages, system_prompt = format_messages_for_anthropic(messages, system_prompt)
diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py
index 965d3010..221f44d2 100644
--- a/src/khoj/processor/conversation/google/gemini_chat.py
+++ b/src/khoj/processor/conversation/google/gemini_chat.py
@@ -2,7 +2,7 @@ import json
import logging
import re
from datetime import datetime, timedelta
-from typing import Dict, Optional
+from typing import Dict, List, Optional
from langchain.schema import ChatMessage
@@ -168,6 +168,10 @@ def converse_gemini(
query_images: Optional[list[str]] = None,
vision_available: bool = False,
query_files: str = None,
+ generated_images: Optional[list[str]] = None,
+ generated_files: List[str] = None,
+ generated_excalidraw_diagram: Optional[str] = None,
+ additional_context: List[str] = None,
tracer={},
):
"""
@@ -229,6 +233,10 @@ def converse_gemini(
vision_enabled=vision_available,
model_type=ChatModelOptions.ModelType.GOOGLE,
query_files=query_files,
+ generated_excalidraw_diagram=generated_excalidraw_diagram,
+ generated_files=generated_files,
+ generated_images=generated_images,
+ additional_program_context=additional_context,
)
messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py
index b1ab77fe..f1746289 100644
--- a/src/khoj/processor/conversation/offline/chat_model.py
+++ b/src/khoj/processor/conversation/offline/chat_model.py
@@ -162,6 +162,8 @@ def converse_offline(
user_name: str = None,
agent: Agent = None,
query_files: str = None,
+ generated_files: List[str] = None,
+ additional_context: List[str] = None,
tracer: dict = {},
) -> Union[ThreadedGenerator, Iterator[str]]:
"""
@@ -229,6 +231,8 @@ def converse_offline(
tokenizer_name=tokenizer_name,
model_type=ChatModelOptions.ModelType.OFFLINE,
query_files=query_files,
+ generated_files=generated_files,
+ additional_program_context=additional_context,
)
logger.debug(f"Conversation Context for {model}: {messages_to_print(messages)}")
diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py
index e525fa75..ef3c3952 100644
--- a/src/khoj/processor/conversation/openai/gpt.py
+++ b/src/khoj/processor/conversation/openai/gpt.py
@@ -1,7 +1,7 @@
import json
import logging
from datetime import datetime, timedelta
-from typing import Dict, Optional
+from typing import Dict, List, Optional
from langchain.schema import ChatMessage
@@ -157,6 +157,10 @@ def converse(
query_images: Optional[list[str]] = None,
vision_available: bool = False,
query_files: str = None,
+ generated_images: Optional[list[str]] = None,
+ generated_files: List[str] = None,
+ generated_excalidraw_diagram: Optional[str] = None,
+ additional_context: List[str] = None,
tracer: dict = {},
):
"""
@@ -219,6 +223,10 @@ def converse(
vision_enabled=vision_available,
model_type=ChatModelOptions.ModelType.OPENAI,
query_files=query_files,
+ generated_excalidraw_diagram=generated_excalidraw_diagram,
+ generated_files=generated_files,
+ generated_images=generated_images,
+ additional_program_context=additional_context,
)
logger.debug(f"Conversation Context for GPT: {messages_to_print(messages)}")
diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py
index 423ec396..d7751944 100644
--- a/src/khoj/processor/conversation/prompts.py
+++ b/src/khoj/processor/conversation/prompts.py
@@ -180,6 +180,20 @@ Improved Prompt:
""".strip()
)
+generated_image_attachment = PromptTemplate.from_template(
+"""
+Here is the image you generated based on my query. You can follow up with a general response to it. Limit to 1-2 sentences.
+""".strip()
+)
+
+generated_diagram_attachment = PromptTemplate.from_template(
+"""
+The AI has successfully created a diagram based on the user's query and handled the request. Good job!
+
+The AI can follow up with a general response or summary. Limit to 1-2 sentences.
+""".strip()
+)
+
## Diagram Generation
## --
@@ -1031,6 +1045,13 @@ A:
""".strip()
)
+additional_program_context = PromptTemplate.from_template(
+ """
+Here's some additional context about what happened while I was executing this query:
+{context}
+ """.strip()
+)
+
personality_prompt_safety_expert_lax = PromptTemplate.from_template(
"""
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index 21a95a29..57fc0f12 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -155,6 +155,9 @@ def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="A
elif chat["by"] == "khoj" and ("text-to-image" in chat["intent"].get("type")):
chat_history += f"User: {chat['intent']['query']}\n"
chat_history += f"{agent_name}: [generated image redacted for space]\n"
+ elif chat["by"] == "khoj" and chat.get("images"):
+ chat_history += f"User: {chat['intent']['query']}\n"
+ chat_history += f"{agent_name}: [generated image redacted for space]\n"
elif chat["by"] == "khoj" and ("excalidraw" in chat["intent"].get("type")):
chat_history += f"User: {chat['intent']['query']}\n"
chat_history += f"{agent_name}: {chat['intent']['inferred-queries'][0]}\n"
@@ -211,6 +214,7 @@ class ChatEvent(Enum):
END_LLM_RESPONSE = "end_llm_response"
MESSAGE = "message"
REFERENCES = "references"
+ GENERATED_ASSETS = "generated_assets"
STATUS = "status"
METADATA = "metadata"
USAGE = "usage"
@@ -223,7 +227,6 @@ def message_to_log(
user_message_metadata={},
khoj_message_metadata={},
conversation_log=[],
- train_of_thought=[],
):
"""Create json logs from messages, metadata for conversation log"""
default_khoj_message_metadata = {
@@ -232,6 +235,10 @@ def message_to_log(
}
khoj_response_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+ # Filter out any fields that are set to None
+ user_message_metadata = {k: v for k, v in user_message_metadata.items() if v is not None}
+ khoj_message_metadata = {k: v for k, v in khoj_message_metadata.items() if v is not None}
+
# Create json log from Human's message
human_log = merge_dicts({"message": user_message, "by": "you"}, user_message_metadata)
@@ -259,6 +266,9 @@ def save_to_conversation_log(
automation_id: str = None,
query_images: List[str] = None,
raw_query_files: List[FileAttachment] = [],
+ generated_images: List[str] = [],
+ raw_generated_files: List[FileAttachment] = [],
+ generated_excalidraw_diagram: str = None,
train_of_thought: List[Any] = [],
tracer: Dict[str, Any] = {},
):
@@ -281,9 +291,11 @@ def save_to_conversation_log(
"automationId": automation_id,
"trainOfThought": train_of_thought,
"turnId": turn_id,
+ "images": generated_images,
+ "queryFiles": [file.model_dump(mode="json") for file in raw_generated_files],
+ "excalidrawDiagram": generated_excalidraw_diagram,
},
conversation_log=meta_log.get("chat", []),
- train_of_thought=train_of_thought,
)
ConversationAdapters.save_conversation(
user,
@@ -307,7 +319,7 @@ Khoj: "{inferred_queries if ("text-to-image" in intent_type) else chat_response}
def construct_structured_message(
- message: str, images: list[str], model_type: str, vision_enabled: bool, attached_file_context: str
+ message: str, images: list[str], model_type: str, vision_enabled: bool, attached_file_context: str = None
):
"""
Format messages into appropriate multimedia format for supported chat model types
@@ -363,6 +375,10 @@ def generate_chatml_messages_with_context(
model_type="",
context_message="",
query_files: str = None,
+ generated_images: Optional[list[str]] = None,
+ generated_files: List[FileAttachment] = None,
+ generated_excalidraw_diagram: str = None,
+ additional_program_context: List[str] = [],
):
"""Generate chat messages with appropriate context from previous conversation to send to the chat model"""
# Set max prompt size from user config or based on pre-configured for model and machine specs
@@ -382,6 +398,7 @@ def generate_chatml_messages_with_context(
message_attached_files = ""
chat_message = chat.get("message")
+ role = "user" if chat["by"] == "you" else "assistant"
if chat["by"] == "khoj" and "excalidraw" in chat["intent"].get("type", ""):
chat_message = chat["intent"].get("inferred-queries")[0]
@@ -402,7 +419,7 @@ def generate_chatml_messages_with_context(
query_files_dict[file["name"]] = file["content"]
message_attached_files = gather_raw_query_files(query_files_dict)
- chatml_messages.append(ChatMessage(content=message_attached_files, role="user"))
+ chatml_messages.append(ChatMessage(content=message_attached_files, role=role))
if not is_none_or_empty(chat.get("onlineContext")):
message_context += f"{prompts.online_search_conversation.format(online_results=chat.get('onlineContext'))}"
@@ -411,9 +428,18 @@ def generate_chatml_messages_with_context(
reconstructed_context_message = ChatMessage(content=message_context, role="user")
chatml_messages.insert(0, reconstructed_context_message)
- role = "user" if chat["by"] == "you" else "assistant"
+ if chat.get("images") and role == "assistant":
+ # Issue: the assistant role cannot accept an image as message content, so send it in a separate user message.
+ file_attachment_message = construct_structured_message(
+ message=prompts.generated_image_attachment.format(),
+ images=chat.get("images"),
+ model_type=model_type,
+ vision_enabled=vision_enabled,
+ )
+ chatml_messages.append(ChatMessage(content=file_attachment_message, role="user"))
+
message_content = construct_structured_message(
- chat_message, chat.get("images"), model_type, vision_enabled, attached_file_context=query_files
+ chat_message, chat.get("images") if role == "user" else [], model_type, vision_enabled
)
reconstructed_message = ChatMessage(content=message_content, role=role)
@@ -423,6 +449,7 @@ def generate_chatml_messages_with_context(
break
messages = []
+
if not is_none_or_empty(user_message):
messages.append(
ChatMessage(
@@ -435,6 +462,31 @@ def generate_chatml_messages_with_context(
if not is_none_or_empty(context_message):
messages.append(ChatMessage(content=context_message, role="user"))
+ if generated_images:
+ messages.append(
+ ChatMessage(
+ content=construct_structured_message(
+ prompts.generated_image_attachment.format(), generated_images, model_type, vision_enabled
+ ),
+ role="user",
+ )
+ )
+
+ if generated_files:
+ message_attached_files = gather_raw_query_files({file.name: file.content for file in generated_files})
+ messages.append(ChatMessage(content=message_attached_files, role="assistant"))
+
+ if generated_excalidraw_diagram:
+ messages.append(ChatMessage(content=prompts.generated_diagram_attachment.format(), role="assistant"))
+
+ if additional_program_context:
+ messages.append(
+ ChatMessage(
+ content=prompts.additional_program_context.format(context="\n".join(additional_program_context)),
+ role="assistant",
+ )
+ )
+
if len(chatml_messages) > 0:
messages += chatml_messages
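A condensed sketch of the asset-message ordering that `generate_chatml_messages_with_context` now applies, mirroring the appends at the end of the function (the `asset_messages` helper is hypothetical, names simplified): generated images ride along as a user message, since assistant messages cannot carry image content in most chat APIs, while generated files, the diagram notice, and program context are appended as assistant messages.

```python
# Simplified model of the new append order; not the actual implementation.
from langchain.schema import ChatMessage

from khoj.processor.conversation import prompts  # assumes the patched prompts module


def asset_messages(generated_images, generated_files, diagram, program_context):
    messages = []
    if generated_images:
        # Workaround: images must be sent under the user role
        messages.append(ChatMessage(content="[generated images attached]", role="user"))
    if generated_files:
        file_names = "\n".join(file.name for file in generated_files)
        messages.append(ChatMessage(content=file_names, role="assistant"))
    if diagram:
        messages.append(
            ChatMessage(content=prompts.generated_diagram_attachment.format(), role="assistant")
        )
    if program_context:
        messages.append(
            ChatMessage(
                content=prompts.additional_program_context.format(context="\n".join(program_context)),
                role="assistant",
            )
        )
    return messages
```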
diff --git a/src/khoj/processor/image/generate.py b/src/khoj/processor/image/generate.py
index 6c1f71b6..1bec7f41 100644
--- a/src/khoj/processor/image/generate.py
+++ b/src/khoj/processor/image/generate.py
@@ -12,7 +12,7 @@ from khoj.database.models import Agent, KhojUser, TextToImageModelConfig
from khoj.routers.helpers import ChatEvent, generate_better_image_prompt
from khoj.routers.storage import upload_image
from khoj.utils import state
-from khoj.utils.helpers import ImageIntentType, convert_image_to_webp, timer
+from khoj.utils.helpers import convert_image_to_webp, timer
from khoj.utils.rawconfig import LocationData
logger = logging.getLogger(__name__)
@@ -34,14 +34,13 @@ async def text_to_image(
status_code = 200
image = None
image_url = None
- intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
text_to_image_config = await ConversationAdapters.aget_user_text_to_image_model(user)
if not text_to_image_config:
# If the user has not configured a text to image model, return an unsupported on server error
status_code = 501
message = "Failed to generate image. Setup image generation on the server."
- yield image_url or image, status_code, message, intent_type.value
+ yield image_url or image, status_code, message
return
text2image_model = text_to_image_config.model_name
@@ -53,6 +52,9 @@ async def text_to_image(
elif chat["by"] == "khoj" and "text-to-image" in chat["intent"].get("type"):
chat_history += f"Q: Prompt: {chat['intent']['query']}\n"
chat_history += f"A: Improved Prompt: {chat['intent']['inferred-queries'][0]}\n"
+ elif chat["by"] == "khoj" and chat.get("images"):
+ chat_history += f"Q: {chat['intent']['query']}\n"
+ chat_history += f"A: Improved Prompt: {chat['intent']['inferred-queries'][0]}\n"
if send_status_func:
async for event in send_status_func("**Enhancing the Painting Prompt**"):
@@ -92,31 +94,29 @@ async def text_to_image(
logger.error(f"Image Generation blocked by OpenAI: {e}")
status_code = e.status_code # type: ignore
message = f"Image generation blocked by OpenAI due to policy violation" # type: ignore
- yield image_url or image, status_code, message, intent_type.value
+ yield image_url or image, status_code, message
return
else:
logger.error(f"Image Generation failed with {e}", exc_info=True)
message = f"Image generation failed using OpenAI" # type: ignore
status_code = e.status_code # type: ignore
- yield image_url or image, status_code, message, intent_type.value
+ yield image_url or image, status_code, message
return
except requests.RequestException as e:
logger.error(f"Image Generation failed with {e}", exc_info=True)
message = f"Image generation using {text2image_model} via {text_to_image_config.model_type} failed due to a network error."
status_code = 502
- yield image_url or image, status_code, message, intent_type.value
+ yield image_url or image, status_code, message
return
# Decide how to store the generated image
with timer("Upload image to S3", logger):
image_url = upload_image(webp_image_bytes, user.uuid)
- if image_url:
- intent_type = ImageIntentType.TEXT_TO_IMAGE2
- else:
- intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
+
+ if not image_url:
image = base64.b64encode(webp_image_bytes).decode("utf-8")
- yield image_url or image, status_code, image_prompt, intent_type.value
+ yield image_url or image, status_code, image_prompt
def generate_image_with_openai(
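With the `intent_type` element dropped, `text_to_image` now yields `(image_or_url, status_code, message)` triples, interleaved with status dicts when `send_status_func` is set. A sketch of the consumer pattern, loosely mirroring the updated call site in `api_chat.py` (the function name here is hypothetical):

```python
# Sketch: consuming the updated text_to_image generator.
async def consume_text_to_image(image_generator):
    async for result in image_generator:
        if isinstance(result, dict):
            # Status event emitted via send_status_func; forward or ignore
            continue
        generated_image, status_code, improved_image_prompt = result
        if generated_image is None or status_code != 200:
            print(f"Image generation failed: {improved_image_prompt}")
        else:
            print(f"Generated image for prompt: {improved_image_prompt}")
```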
diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py
index d0833eec..7d2dc3a6 100644
--- a/src/khoj/routers/api_chat.py
+++ b/src/khoj/routers/api_chat.py
@@ -77,6 +77,7 @@ from khoj.utils.helpers import (
)
from khoj.utils.rawconfig import (
ChatRequestBody,
+ FileAttachment,
FileFilterRequest,
FilesFilterRequest,
LocationData,
@@ -771,6 +772,11 @@ async def chat(
file_filters = conversation.file_filters if conversation and conversation.file_filters else []
attached_file_context = gather_raw_query_files(query_files)
+ generated_images: List[str] = []
+ generated_files: List[FileAttachment] = []
+ generated_excalidraw_diagram: str = None
+ additional_context_for_llm_response: List[str] = []
+
if conversation_commands == [ConversationCommand.Default] or is_automated_task:
chosen_io = await aget_data_sources_and_output_format(
q,
@@ -876,21 +882,17 @@ async def chat(
async for result in send_llm_response(response, tracer.get("usage")):
yield result
- await sync_to_async(save_to_conversation_log)(
- q,
- response_log,
- user,
- meta_log,
- user_message_time,
- intent_type="summarize",
- client_application=request.user.client_app,
- conversation_id=conversation_id,
- query_images=uploaded_images,
- train_of_thought=train_of_thought,
- raw_query_files=raw_query_files,
- tracer=tracer,
+ summarized_document = FileAttachment(
+ name="Summarized Document",
+ content=response_log,
+ type="text/plain",
+ size=len(response_log.encode("utf-8")),
)
- return
+
+ async for result in send_event(ChatEvent.GENERATED_ASSETS, {"files": [summarized_document.model_dump()]}):
+ yield result
+
+ generated_files.append(summarized_document)
custom_filters = []
if conversation_commands == [ConversationCommand.Help]:
@@ -1079,6 +1081,7 @@ async def chat(
async for result in send_event(ChatEvent.STATUS, f"**Ran code snippets**: {len(code_results)}"):
yield result
except ValueError as e:
+ additional_context_for_llm_response.append(f"Failed to run code with error: {e}")
logger.warning(
f"Failed to use code tool: {e}. Attempting to respond without code results",
exc_info=True,
@@ -1116,51 +1119,36 @@ async def chat(
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
- generated_image, status_code, improved_image_prompt, intent_type = result
+ generated_image, status_code, improved_image_prompt = result
+ inferred_queries.append(improved_image_prompt)
if generated_image is None or status_code != 200:
- content_obj = {
- "content-type": "application/json",
- "intentType": intent_type,
- "detail": improved_image_prompt,
- "image": None,
- }
- async for result in send_llm_response(json.dumps(content_obj), tracer.get("usage")):
+ additional_context_for_llm_response.append(f"Failed to generate image with prompt: {improved_image_prompt}")
+ async for result in send_event(ChatEvent.STATUS, "Failed to generate image"):
yield result
- return
+ else:
+ generated_images.append(generated_image)
- await sync_to_async(save_to_conversation_log)(
- q,
- generated_image,
- user,
- meta_log,
- user_message_time,
- intent_type=intent_type,
- inferred_queries=[improved_image_prompt],
- client_application=request.user.client_app,
- conversation_id=conversation_id,
- compiled_references=compiled_references,
- online_results=online_results,
- code_results=code_results,
- query_images=uploaded_images,
- train_of_thought=train_of_thought,
- raw_query_files=raw_query_files,
- tracer=tracer,
- )
- content_obj = {
- "intentType": intent_type,
- "inferredQueries": [improved_image_prompt],
- "image": generated_image,
- }
- async for result in send_llm_response(json.dumps(content_obj), tracer.get("usage")):
- yield result
- return
+ async for result in send_event(
+ ChatEvent.GENERATED_ASSETS,
+ {
+ "images": [generated_image],
+ },
+ ):
+ yield result
if ConversationCommand.Diagram in conversation_commands:
async for result in send_event(ChatEvent.STATUS, f"Creating diagram"):
yield result
- intent_type = "excalidraw"
inferred_queries = []
diagram_description = ""
@@ -1184,62 +1172,59 @@ async def chat(
if better_diagram_description_prompt and excalidraw_diagram_description:
inferred_queries.append(better_diagram_description_prompt)
diagram_description = excalidraw_diagram_description
+
+ generated_excalidraw_diagram = diagram_description
+
+ async for result in send_event(
+ ChatEvent.GENERATED_ASSETS,
+ {
+ "excalidrawDiagram": excalidraw_diagram_description,
+ },
+ ):
+ yield result
else:
error_message = "Failed to generate diagram. Please try again later."
- async for result in send_llm_response(error_message, tracer.get("usage")):
+ additional_context_for_llm_response.append(
+ "The AI attempted to generate a diagram programmatically but hit a system issue this time, though it can normally do so. The AI can offer a text description of the diagram instead, or the user can retry with a simpler prompt."
+ )
+
+ async for result in send_event(ChatEvent.STATUS, error_message):
yield result
- await sync_to_async(save_to_conversation_log)(
- q,
- error_message,
- user,
- meta_log,
- user_message_time,
- inferred_queries=[better_diagram_description_prompt],
- client_application=request.user.client_app,
- conversation_id=conversation_id,
- compiled_references=compiled_references,
- online_results=online_results,
- code_results=code_results,
- query_images=uploaded_images,
- train_of_thought=train_of_thought,
- raw_query_files=raw_query_files,
- tracer=tracer,
- )
- return
- content_obj = {
- "intentType": intent_type,
- "inferredQueries": inferred_queries,
- "image": diagram_description,
- }
- await sync_to_async(save_to_conversation_log)(
- q,
- excalidraw_diagram_description,
- user,
- meta_log,
- user_message_time,
- intent_type="excalidraw",
- inferred_queries=[better_diagram_description_prompt],
- client_application=request.user.client_app,
- conversation_id=conversation_id,
- compiled_references=compiled_references,
- online_results=online_results,
- code_results=code_results,
- query_images=uploaded_images,
- train_of_thought=train_of_thought,
- raw_query_files=raw_query_files,
- tracer=tracer,
- )
-
- async for result in send_llm_response(json.dumps(content_obj), tracer.get("usage")):
- yield result
- return
## Generate Text Output
async for result in send_event(ChatEvent.STATUS, f"**Generating a well-informed response**"):
yield result
+
llm_response, chat_metadata = await agenerate_chat_response(
defiltered_query,
meta_log,
@@ -1259,6 +1244,10 @@ async def chat(
train_of_thought,
attached_file_context,
raw_query_files,
+ generated_images,
+ generated_files,
+ generated_excalidraw_diagram,
+ additional_context_for_llm_response,
tracer,
)
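The behavioral shift in this file: asset-generation failures no longer short-circuit the request with `send_llm_response` and an early `return`; instead they append a note to `additional_context_for_llm_response` and fall through to the text model, which can acknowledge the failure in its reply. A schematic sketch of that pattern, with hypothetical names (`handle_image_command`, `emit_assets`); the real endpoint forwards status and asset events through its `send_event` async generator:

```python
# Schematic of the new fall-through failure handling; not actual endpoint code.
async def handle_image_command(image_generator, emit_assets, additional_context):
    generated_images = []
    async for result in image_generator:
        if isinstance(result, dict):
            continue  # interleaved status event from send_status_func
        generated_image, status_code, improved_image_prompt = result
        if generated_image is None or status_code != 200:
            # Record failure context for the text model; no early return
            additional_context.append(
                f"Failed to generate image with prompt: {improved_image_prompt}"
            )
        else:
            generated_images.append(generated_image)
            await emit_assets({"images": [generated_image]})
    return generated_images  # control continues on to text generation
```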
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index e160b8e3..30deca0c 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -1184,6 +1184,10 @@ def generate_chat_response(
train_of_thought: List[Any] = [],
query_files: str = None,
raw_query_files: List[FileAttachment] = None,
+ generated_images: List[str] = None,
+ raw_generated_files: List[FileAttachment] = [],
+ generated_excalidraw_diagram: str = None,
+ additional_context: List[str] = [],
tracer: dict = {},
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
# Initialize Variables
@@ -1207,6 +1211,9 @@ def generate_chat_response(
query_images=query_images,
train_of_thought=train_of_thought,
raw_query_files=raw_query_files,
+ generated_images=generated_images,
+ raw_generated_files=raw_generated_files,
+ generated_excalidraw_diagram=generated_excalidraw_diagram,
tracer=tracer,
)
@@ -1242,6 +1249,7 @@ def generate_chat_response(
user_name=user_name,
agent=agent,
query_files=query_files,
+ generated_files=raw_generated_files,
tracer=tracer,
)
@@ -1268,6 +1276,10 @@ def generate_chat_response(
agent=agent,
vision_available=vision_available,
query_files=query_files,
+ generated_files=raw_generated_files,
+ generated_images=generated_images,
+ generated_excalidraw_diagram=generated_excalidraw_diagram,
+ additional_context=additional_context,
tracer=tracer,
)
@@ -1291,6 +1303,10 @@ def generate_chat_response(
agent=agent,
vision_available=vision_available,
query_files=query_files,
+ generated_files=raw_generated_files,
+ generated_images=generated_images,
+ generated_excalidraw_diagram=generated_excalidraw_diagram,
+ additional_context=additional_context,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
@@ -1313,6 +1329,10 @@ def generate_chat_response(
query_images=query_images,
vision_available=vision_available,
query_files=query_files,
+ generated_files=raw_generated_files,
+ generated_images=generated_images,
+ generated_excalidraw_diagram=generated_excalidraw_diagram,
+ additional_context=additional_context,
tracer=tracer,
)
@@ -1784,6 +1804,9 @@ class MessageProcessor:
self.references = {}
self.usage = {}
self.raw_response = ""
+ self.generated_images = []
+ self.generated_files = []
+ self.generated_excalidraw_diagram = None
def convert_message_chunk_to_json(self, raw_chunk: str) -> Dict[str, Any]:
if raw_chunk.startswith("{") and raw_chunk.endswith("}"):
@@ -1822,6 +1845,16 @@ class MessageProcessor:
self.raw_response += chunk_data
else:
self.raw_response += chunk_data
+ elif chunk_type == ChatEvent.GENERATED_ASSETS:
+ chunk_data = chunk["data"]
+ if isinstance(chunk_data, dict):
+ for key in chunk_data:
+ if key == "images":
+ self.generated_images = chunk_data[key]
+ elif key == "files":
+ self.generated_files = chunk_data[key]
+ elif key == "excalidrawDiagram":
+ self.generated_excalidraw_diagram = chunk_data[key]
def handle_json_response(self, json_data: Dict[str, str]) -> str | Dict[str, str]:
if "image" in json_data or "details" in json_data:
@@ -1852,7 +1885,14 @@ async def read_chat_stream(response_iterator: AsyncGenerator[str, None]) -> Dict
if buffer:
processor.process_message_chunk(buffer)
- return {"response": processor.raw_response, "references": processor.references, "usage": processor.usage}
+ return {
+ "response": processor.raw_response,
+ "references": processor.references,
+ "usage": processor.usage,
+ "images": processor.generated_images,
+ "files": processor.generated_files,
+ "excalidrawDiagram": processor.generated_excalidraw_diagram,
+ }
def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False):
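End to end, `read_chat_stream` now surfaces generated assets alongside the text response. A minimal sketch of feeding serialized chunks through `MessageProcessor` and reading the aggregate, assuming the patched module and the `{"type": ..., "data": ...}` chunk framing used elsewhere in this diff:

```python
# Sketch: aggregate a generated_assets chunk and a message chunk.
from khoj.routers.helpers import MessageProcessor  # assumes the patched module

processor = MessageProcessor()
processor.process_message_chunk(
    '{"type": "generated_assets", "data": {"images": ["https://example.com/a.webp"]}}'
)
processor.process_message_chunk('{"type": "message", "data": "Here is your image."}')

print(processor.raw_response)      # "Here is your image."
print(processor.generated_images)  # ["https://example.com/a.webp"]
```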
diff --git a/tests/test_offline_chat_actors.py b/tests/test_offline_chat_actors.py
index b404f0e8..e84612a2 100644
--- a/tests/test_offline_chat_actors.py
+++ b/tests/test_offline_chat_actors.py
@@ -22,7 +22,6 @@ from khoj.processor.conversation.offline.chat_model import (
filter_questions,
)
from khoj.processor.conversation.offline.utils import download_model
-from khoj.processor.conversation.utils import message_to_log
from khoj.utils.constants import default_offline_chat_models
diff --git a/tests/test_online_chat_director.py b/tests/test_online_chat_director.py
index 94545b4c..ea3a6e1c 100644
--- a/tests/test_online_chat_director.py
+++ b/tests/test_online_chat_director.py
@@ -6,7 +6,6 @@ from freezegun import freeze_time
from khoj.database.models import Agent, Entry, KhojUser
from khoj.processor.conversation import prompts
-from khoj.processor.conversation.utils import message_to_log
from tests.helpers import ConversationFactory, generate_chat_history, get_chat_api_key
# Initialize variables for tests