From d91935c880629805708589dadcb6fdf5355463f0 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Thu, 28 Nov 2024 17:28:23 -0800 Subject: [PATCH 01/25] Initial commit of a functional but not yet elegant prototype for this concept --- src/interface/web/app/common/chatFunctions.ts | 23 +++ .../components/chatHistory/chatHistory.tsx | 9 + .../components/chatMessage/chatMessage.tsx | 12 ++ src/khoj/database/models/__init__.py | 195 +++++++++++++++--- .../conversation/anthropic/anthropic_chat.py | 10 +- .../conversation/google/gemini_chat.py | 10 +- .../conversation/offline/chat_model.py | 4 + src/khoj/processor/conversation/openai/gpt.py | 10 +- src/khoj/processor/conversation/prompts.py | 21 ++ src/khoj/processor/conversation/utils.py | 64 +++++- src/khoj/processor/image/generate.py | 22 +- src/khoj/routers/api_chat.py | 181 ++++++++-------- src/khoj/routers/helpers.py | 42 +++- tests/test_offline_chat_actors.py | 1 - tests/test_online_chat_director.py | 1 - 15 files changed, 455 insertions(+), 150 deletions(-) diff --git a/src/interface/web/app/common/chatFunctions.ts b/src/interface/web/app/common/chatFunctions.ts index 6585b4c9..c64e81ba 100644 --- a/src/interface/web/app/common/chatFunctions.ts +++ b/src/interface/web/app/common/chatFunctions.ts @@ -1,3 +1,4 @@ +import { AttachedFileText } from "../components/chatInputArea/chatInputArea"; import { CodeContext, Context, @@ -16,6 +17,12 @@ export interface MessageMetadata { turnId: string; } +export interface GeneratedAssetsData { + images: string[]; + excalidrawDiagram: string; + files: AttachedFileText[]; +} + export interface ResponseWithIntent { intentType: string; response: string; @@ -84,6 +91,8 @@ export function processMessageChunk( if (!currentMessage || !chunk || !chunk.type) return { context, onlineContext, codeContext }; + console.log(`chunk type: ${chunk.type}`); + if (chunk.type === "status") { console.log(`status: ${chunk.data}`); const statusMessage = chunk.data as string; @@ -98,6 +107,20 @@ export function processMessageChunk( } else if (chunk.type === "metadata") { const messageMetadata = chunk.data as MessageMetadata; currentMessage.turnId = messageMetadata.turnId; + } else if (chunk.type === "generated_assets") { + const generatedAssets = chunk.data as GeneratedAssetsData; + + if (generatedAssets.images) { + currentMessage.generatedImages = generatedAssets.images; + } + + if (generatedAssets.excalidrawDiagram) { + currentMessage.generatedExcalidrawDiagram = generatedAssets.excalidrawDiagram; + } + + if (generatedAssets.files) { + currentMessage.generatedFiles = generatedAssets.files; + } } else if (chunk.type === "message") { const chunkData = chunk.data; // Here, handle if the response is a JSON response with an image, but the intentType is excalidraw diff --git a/src/interface/web/app/components/chatHistory/chatHistory.tsx b/src/interface/web/app/components/chatHistory/chatHistory.tsx index ea566df4..c5fd8b54 100644 --- a/src/interface/web/app/components/chatHistory/chatHistory.tsx +++ b/src/interface/web/app/components/chatHistory/chatHistory.tsx @@ -54,6 +54,12 @@ function TrainOfThoughtComponent(props: TrainOfThoughtComponentProps) { const lastIndex = props.trainOfThought.length - 1; const [collapsed, setCollapsed] = useState(props.completed); + // useEffect(() => { + // if (props.completed) { + // setCollapsed(true); + // } + // }), [props.completed]; + return (
void; conversationId: string; turnId?: string; + generatedImage?: string; + excalidrawDiagram?: string; + generatedFiles?: AttachedFileText[]; } interface TrainOfThoughtProps { @@ -394,6 +402,10 @@ const ChatMessage = forwardRef((props, ref) => setExcalidrawData(props.chatMessage.message); } + if (props.chatMessage.excalidrawDiagram) { + setExcalidrawData(props.chatMessage.excalidrawDiagram); + } + // Replace LaTeX delimiters with placeholders message = message .replace(/\\\(/g, "LEFTPAREN") diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 8f6c5c78..ae6a729d 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -1,7 +1,9 @@ +import logging import os import re import uuid from random import choice +from typing import Dict, List, Optional, Union from django.contrib.auth.models import AbstractUser from django.contrib.postgres.fields import ArrayField @@ -11,9 +13,109 @@ from django.db.models.signals import pre_save from django.dispatch import receiver from pgvector.django import VectorField from phonenumber_field.modelfields import PhoneNumberField +from pydantic import BaseModel as PydanticBaseModel +from pydantic import Field + +logger = logging.getLogger(__name__) -class BaseModel(models.Model): +# Pydantic models for type Chat Message validation +class Context(PydanticBaseModel): + compiled: str + file: str + + +class CodeContextFile(PydanticBaseModel): + filename: str + b64_data: str + + +class CodeContextResult(PydanticBaseModel): + success: bool + output_files: List[CodeContextFile] + std_out: str + std_err: str + code_runtime: int + + +class CodeContextData(PydanticBaseModel): + code: str + result: CodeContextResult + + +class WebPage(PydanticBaseModel): + link: str + query: Optional[str] = None + snippet: str + + +class AnswerBox(PydanticBaseModel): + link: str + snippet: str + title: str + snippetHighlighted: List[str] + + +class PeopleAlsoAsk(PydanticBaseModel): + link: str + question: str + snippet: str + title: str + + +class KnowledgeGraph(PydanticBaseModel): + attributes: Dict[str, str] + description: str + descriptionLink: str + descriptionSource: str + imageUrl: str + title: str + type: str + + +class OrganicContext(PydanticBaseModel): + snippet: str + title: str + link: str + + +class OnlineContext(PydanticBaseModel): + webpages: Optional[Union[WebPage, List[WebPage]]] = None + answerBox: Optional[AnswerBox] = None + peopleAlsoAsk: Optional[List[PeopleAlsoAsk]] = None + knowledgeGraph: Optional[KnowledgeGraph] = None + organicContext: Optional[List[OrganicContext]] = None + + +class Intent(PydanticBaseModel): + type: str + query: str + memory_type: str = Field(alias="memory-type") + inferred_queries: Optional[List[str]] = Field(default=None, alias="inferred-queries") + + +class TrainOfThought(PydanticBaseModel): + type: str + data: str + + +class ChatMessage(PydanticBaseModel): + message: str + trainOfThought: List[TrainOfThought] = [] + context: List[Context] = [] + onlineContext: Dict[str, OnlineContext] = {} + codeContext: Dict[str, CodeContextData] = {} + created: str + images: Optional[List[str]] = None + queryFiles: Optional[List[Dict]] = None + excalidrawDiagram: Optional[str] = None + by: str + turnId: Optional[str] + intent: Optional[Intent] = None + automationId: Optional[str] = None + + +class DbBaseModel(models.Model): created_at = models.DateTimeField(auto_now_add=True) updated_at = models.DateTimeField(auto_now=True) @@ -21,7 +123,7 @@ class BaseModel(models.Model): abstract 
= True -class ClientApplication(BaseModel): +class ClientApplication(DbBaseModel): name = models.CharField(max_length=200) client_id = models.CharField(max_length=200) client_secret = models.CharField(max_length=200) @@ -67,7 +169,7 @@ class KhojApiUser(models.Model): accessed_at = models.DateTimeField(null=True, default=None) -class Subscription(BaseModel): +class Subscription(DbBaseModel): class Type(models.TextChoices): TRIAL = "trial" STANDARD = "standard" @@ -79,13 +181,13 @@ class Subscription(BaseModel): enabled_trial_at = models.DateTimeField(null=True, default=None, blank=True) -class OpenAIProcessorConversationConfig(BaseModel): +class OpenAIProcessorConversationConfig(DbBaseModel): name = models.CharField(max_length=200) api_key = models.CharField(max_length=200) api_base_url = models.URLField(max_length=200, default=None, blank=True, null=True) -class ChatModelOptions(BaseModel): +class ChatModelOptions(DbBaseModel): class ModelType(models.TextChoices): OPENAI = "openai" OFFLINE = "offline" @@ -103,12 +205,12 @@ class ChatModelOptions(BaseModel): ) -class VoiceModelOption(BaseModel): +class VoiceModelOption(DbBaseModel): model_id = models.CharField(max_length=200) name = models.CharField(max_length=200) -class Agent(BaseModel): +class Agent(DbBaseModel): class StyleColorTypes(models.TextChoices): BLUE = "blue" GREEN = "green" @@ -208,7 +310,7 @@ class Agent(BaseModel): super().save(*args, **kwargs) -class ProcessLock(BaseModel): +class ProcessLock(DbBaseModel): class Operation(models.TextChoices): INDEX_CONTENT = "index_content" SCHEDULED_JOB = "scheduled_job" @@ -231,24 +333,24 @@ def verify_agent(sender, instance, **kwargs): raise ValidationError(f"A private Agent with the name {instance.name} already exists.") -class NotionConfig(BaseModel): +class NotionConfig(DbBaseModel): token = models.CharField(max_length=200) user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) -class GithubConfig(BaseModel): +class GithubConfig(DbBaseModel): pat_token = models.CharField(max_length=200) user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) -class GithubRepoConfig(BaseModel): +class GithubRepoConfig(DbBaseModel): name = models.CharField(max_length=200) owner = models.CharField(max_length=200) branch = models.CharField(max_length=200) github_config = models.ForeignKey(GithubConfig, on_delete=models.CASCADE, related_name="githubrepoconfig") -class WebScraper(BaseModel): +class WebScraper(DbBaseModel): class WebScraperType(models.TextChoices): FIRECRAWL = "Firecrawl" OLOSTEP = "Olostep" @@ -321,7 +423,7 @@ class WebScraper(BaseModel): super().save(*args, **kwargs) -class ServerChatSettings(BaseModel): +class ServerChatSettings(DbBaseModel): chat_default = models.ForeignKey( ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_default" ) @@ -333,35 +435,35 @@ class ServerChatSettings(BaseModel): ) -class LocalOrgConfig(BaseModel): +class LocalOrgConfig(DbBaseModel): input_files = models.JSONField(default=list, null=True) input_filter = models.JSONField(default=list, null=True) index_heading_entries = models.BooleanField(default=False) user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) -class LocalMarkdownConfig(BaseModel): +class LocalMarkdownConfig(DbBaseModel): input_files = models.JSONField(default=list, null=True) input_filter = models.JSONField(default=list, null=True) index_heading_entries = models.BooleanField(default=False) user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) -class 
LocalPdfConfig(BaseModel): +class LocalPdfConfig(DbBaseModel): input_files = models.JSONField(default=list, null=True) input_filter = models.JSONField(default=list, null=True) index_heading_entries = models.BooleanField(default=False) user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) -class LocalPlaintextConfig(BaseModel): +class LocalPlaintextConfig(DbBaseModel): input_files = models.JSONField(default=list, null=True) input_filter = models.JSONField(default=list, null=True) index_heading_entries = models.BooleanField(default=False) user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) -class SearchModelConfig(BaseModel): +class SearchModelConfig(DbBaseModel): class ModelType(models.TextChoices): TEXT = "text" @@ -393,7 +495,7 @@ class SearchModelConfig(BaseModel): bi_encoder_confidence_threshold = models.FloatField(default=0.18) -class TextToImageModelConfig(BaseModel): +class TextToImageModelConfig(DbBaseModel): class ModelType(models.TextChoices): OPENAI = "openai" STABILITYAI = "stability-ai" @@ -430,7 +532,7 @@ class TextToImageModelConfig(BaseModel): super().save(*args, **kwargs) -class SpeechToTextModelOptions(BaseModel): +class SpeechToTextModelOptions(DbBaseModel): class ModelType(models.TextChoices): OPENAI = "openai" OFFLINE = "offline" @@ -439,22 +541,22 @@ class SpeechToTextModelOptions(BaseModel): model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE) -class UserConversationConfig(BaseModel): +class UserConversationConfig(DbBaseModel): user = models.OneToOneField(KhojUser, on_delete=models.CASCADE) setting = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True) -class UserVoiceModelConfig(BaseModel): +class UserVoiceModelConfig(DbBaseModel): user = models.OneToOneField(KhojUser, on_delete=models.CASCADE) setting = models.ForeignKey(VoiceModelOption, on_delete=models.CASCADE, default=None, null=True, blank=True) -class UserTextToImageModelConfig(BaseModel): +class UserTextToImageModelConfig(DbBaseModel): user = models.OneToOneField(KhojUser, on_delete=models.CASCADE) setting = models.ForeignKey(TextToImageModelConfig, on_delete=models.CASCADE) -class Conversation(BaseModel): +class Conversation(DbBaseModel): user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) conversation_log = models.JSONField(default=dict) client = models.ForeignKey(ClientApplication, on_delete=models.CASCADE, default=None, null=True, blank=True) @@ -468,8 +570,39 @@ class Conversation(BaseModel): file_filters = models.JSONField(default=list) id = models.UUIDField(default=uuid.uuid4, editable=False, unique=True, primary_key=True, db_index=True) + def clean(self): + # Validate conversation_log structure + try: + messages = self.conversation_log.get("chat", []) + for msg in messages: + ChatMessage.model_validate(msg) + except Exception as e: + raise ValidationError(f"Invalid conversation_log format: {str(e)}") -class PublicConversation(BaseModel): + def save(self, *args, **kwargs): + self.clean() + super().save(*args, **kwargs) + + @property + def messages(self) -> List[ChatMessage]: + """Type-hinted accessor for conversation messages""" + validated_messages = [] + for msg in self.conversation_log.get("chat", []): + try: + # Clean up inferred queries if they contain None + if msg.get("intent") and msg["intent"].get("inferred-queries"): + msg["intent"]["inferred-queries"] = [ + q for q in msg["intent"]["inferred-queries"] if q is not None and isinstance(q, str) + ] + msg["message"] = 
str(msg.get("message", "")) + validated_messages.append(ChatMessage.model_validate(msg)) + except ValidationError as e: + logger.warning(f"Skipping invalid message in conversation: {e}") + continue + return validated_messages + + +class PublicConversation(DbBaseModel): source_owner = models.ForeignKey(KhojUser, on_delete=models.CASCADE) conversation_log = models.JSONField(default=dict) slug = models.CharField(max_length=200, default=None, null=True, blank=True) @@ -499,12 +632,12 @@ def verify_public_conversation(sender, instance, **kwargs): instance.slug = slug -class ReflectiveQuestion(BaseModel): +class ReflectiveQuestion(DbBaseModel): question = models.CharField(max_length=500) user = models.ForeignKey(KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True) -class Entry(BaseModel): +class Entry(DbBaseModel): class EntryType(models.TextChoices): IMAGE = "image" PDF = "pdf" @@ -541,7 +674,7 @@ class Entry(BaseModel): raise ValidationError("An Entry cannot be associated with both a user and an agent.") -class FileObject(BaseModel): +class FileObject(DbBaseModel): # Same as Entry but raw will be a much larger string file_name = models.CharField(max_length=400, default=None, null=True, blank=True) raw_text = models.TextField() @@ -549,7 +682,7 @@ class FileObject(BaseModel): agent = models.ForeignKey(Agent, on_delete=models.CASCADE, default=None, null=True, blank=True) -class EntryDates(BaseModel): +class EntryDates(DbBaseModel): date = models.DateField() entry = models.ForeignKey(Entry, on_delete=models.CASCADE, related_name="embeddings_dates") @@ -559,12 +692,12 @@ class EntryDates(BaseModel): ] -class UserRequests(BaseModel): +class UserRequests(DbBaseModel): user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) slug = models.CharField(max_length=200) -class DataStore(BaseModel): +class DataStore(DbBaseModel): key = models.CharField(max_length=200, unique=True) value = models.JSONField(default=dict) private = models.BooleanField(default=False) diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py index 65e28d21..a4dc22ed 100644 --- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py +++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py @@ -2,7 +2,7 @@ import json import logging import re from datetime import datetime, timedelta -from typing import Dict, Optional +from typing import Dict, List, Optional from langchain.schema import ChatMessage @@ -158,6 +158,10 @@ def converse_anthropic( query_images: Optional[list[str]] = None, vision_available: bool = False, query_files: str = None, + generated_images: Optional[list[str]] = None, + generated_files: List[str] = None, + generated_excalidraw_diagram: Optional[str] = None, + additional_context: Optional[str] = None, tracer: dict = {}, ): """ @@ -218,6 +222,10 @@ def converse_anthropic( vision_enabled=vision_available, model_type=ChatModelOptions.ModelType.ANTHROPIC, query_files=query_files, + generated_excalidraw_diagram=generated_excalidraw_diagram, + generated_files=generated_files, + generated_images=generated_images, + additional_program_context=additional_context, ) messages, system_prompt = format_messages_for_anthropic(messages, system_prompt) diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py index 965d3010..221f44d2 100644 --- a/src/khoj/processor/conversation/google/gemini_chat.py +++ 
b/src/khoj/processor/conversation/google/gemini_chat.py @@ -2,7 +2,7 @@ import json import logging import re from datetime import datetime, timedelta -from typing import Dict, Optional +from typing import Dict, List, Optional from langchain.schema import ChatMessage @@ -168,6 +168,10 @@ def converse_gemini( query_images: Optional[list[str]] = None, vision_available: bool = False, query_files: str = None, + generated_images: Optional[list[str]] = None, + generated_files: List[str] = None, + generated_excalidraw_diagram: Optional[str] = None, + additional_context: List[str] = None, tracer={}, ): """ @@ -229,6 +233,10 @@ def converse_gemini( vision_enabled=vision_available, model_type=ChatModelOptions.ModelType.GOOGLE, query_files=query_files, + generated_excalidraw_diagram=generated_excalidraw_diagram, + generated_files=generated_files, + generated_images=generated_images, + additional_program_context=additional_context, ) messages, system_prompt = format_messages_for_gemini(messages, system_prompt) diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py index b1ab77fe..f1746289 100644 --- a/src/khoj/processor/conversation/offline/chat_model.py +++ b/src/khoj/processor/conversation/offline/chat_model.py @@ -162,6 +162,8 @@ def converse_offline( user_name: str = None, agent: Agent = None, query_files: str = None, + generated_files: List[str] = None, + additional_context: List[str] = None, tracer: dict = {}, ) -> Union[ThreadedGenerator, Iterator[str]]: """ @@ -229,6 +231,8 @@ def converse_offline( tokenizer_name=tokenizer_name, model_type=ChatModelOptions.ModelType.OFFLINE, query_files=query_files, + generated_files=generated_files, + additional_program_context=additional_context, ) logger.debug(f"Conversation Context for {model}: {messages_to_print(messages)}") diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index e525fa75..ef3c3952 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -1,7 +1,7 @@ import json import logging from datetime import datetime, timedelta -from typing import Dict, Optional +from typing import Dict, List, Optional from langchain.schema import ChatMessage @@ -157,6 +157,10 @@ def converse( query_images: Optional[list[str]] = None, vision_available: bool = False, query_files: str = None, + generated_images: Optional[list[str]] = None, + generated_files: List[str] = None, + generated_excalidraw_diagram: Optional[str] = None, + additional_context: List[str] = None, tracer: dict = {}, ): """ @@ -219,6 +223,10 @@ def converse( vision_enabled=vision_available, model_type=ChatModelOptions.ModelType.OPENAI, query_files=query_files, + generated_excalidraw_diagram=generated_excalidraw_diagram, + generated_files=generated_files, + generated_images=generated_images, + additional_program_context=additional_context, ) logger.debug(f"Conversation Context for GPT: {messages_to_print(messages)}") diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index 423ec396..d7751944 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -180,6 +180,20 @@ Improved Prompt: """.strip() ) +generated_image_attachment = PromptTemplate.from_template( + f""" +Here is the image you generated based on my query. You can follow-up with a general response to my query. Limit to 1-2 sentences. 
+""".strip() +) + +generated_diagram_attachment = PromptTemplate.from_template( + f""" +The AI has successfully created a diagram based on the user's query and handled the request. Good job! + +AI can follow-up with a general response or summary. Limit to 1-2 sentences. +""".strip() +) + ## Diagram Generation ## -- @@ -1031,6 +1045,13 @@ A: """.strip() ) +additional_program_context = PromptTemplate.from_template( + """ +Here's some additional context about what happened while I was executing this query: +{context} + """.strip() +) + personality_prompt_safety_expert_lax = PromptTemplate.from_template( """ diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 21a95a29..57fc0f12 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -155,6 +155,9 @@ def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="A elif chat["by"] == "khoj" and ("text-to-image" in chat["intent"].get("type")): chat_history += f"User: {chat['intent']['query']}\n" chat_history += f"{agent_name}: [generated image redacted for space]\n" + elif chat["by"] == "khoj" and chat.get("images"): + chat_history += f"User: {chat['intent']['query']}\n" + chat_history += f"{agent_name}: [generated image redacted for space]\n" elif chat["by"] == "khoj" and ("excalidraw" in chat["intent"].get("type")): chat_history += f"User: {chat['intent']['query']}\n" chat_history += f"{agent_name}: {chat['intent']['inferred-queries'][0]}\n" @@ -211,6 +214,7 @@ class ChatEvent(Enum): END_LLM_RESPONSE = "end_llm_response" MESSAGE = "message" REFERENCES = "references" + GENERATED_ASSETS = "generated_assets" STATUS = "status" METADATA = "metadata" USAGE = "usage" @@ -223,7 +227,6 @@ def message_to_log( user_message_metadata={}, khoj_message_metadata={}, conversation_log=[], - train_of_thought=[], ): """Create json logs from messages, metadata for conversation log""" default_khoj_message_metadata = { @@ -232,6 +235,10 @@ def message_to_log( } khoj_response_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + # Filter out any fields that are set to None + user_message_metadata = {k: v for k, v in user_message_metadata.items() if v is not None} + khoj_message_metadata = {k: v for k, v in khoj_message_metadata.items() if v is not None} + # Create json log from Human's message human_log = merge_dicts({"message": user_message, "by": "you"}, user_message_metadata) @@ -259,6 +266,9 @@ def save_to_conversation_log( automation_id: str = None, query_images: List[str] = None, raw_query_files: List[FileAttachment] = [], + generated_images: List[str] = [], + raw_generated_files: List[FileAttachment] = [], + generated_excalidraw_diagram: str = None, train_of_thought: List[Any] = [], tracer: Dict[str, Any] = {}, ): @@ -281,9 +291,11 @@ def save_to_conversation_log( "automationId": automation_id, "trainOfThought": train_of_thought, "turnId": turn_id, + "images": generated_images, + "queryFiles": [file.model_dump(mode="json") for file in raw_generated_files], + "excalidrawDiagram": str(generated_excalidraw_diagram), }, conversation_log=meta_log.get("chat", []), - train_of_thought=train_of_thought, ) ConversationAdapters.save_conversation( user, @@ -307,7 +319,7 @@ Khoj: "{inferred_queries if ("text-to-image" in intent_type) else chat_response} def construct_structured_message( - message: str, images: list[str], model_type: str, vision_enabled: bool, attached_file_context: str + message: str, images: list[str], model_type: str, vision_enabled: bool, 
attached_file_context: str = None ): """ Format messages into appropriate multimedia format for supported chat model types @@ -363,6 +375,10 @@ def generate_chatml_messages_with_context( model_type="", context_message="", query_files: str = None, + generated_images: Optional[list[str]] = None, + generated_files: List[FileAttachment] = None, + generated_excalidraw_diagram: str = None, + additional_program_context: List[str] = [], ): """Generate chat messages with appropriate context from previous conversation to send to the chat model""" # Set max prompt size from user config or based on pre-configured for model and machine specs @@ -382,6 +398,7 @@ def generate_chatml_messages_with_context( message_attached_files = "" chat_message = chat.get("message") + role = "user" if chat["by"] == "you" else "assistant" if chat["by"] == "khoj" and "excalidraw" in chat["intent"].get("type", ""): chat_message = chat["intent"].get("inferred-queries")[0] @@ -402,7 +419,7 @@ def generate_chatml_messages_with_context( query_files_dict[file["name"]] = file["content"] message_attached_files = gather_raw_query_files(query_files_dict) - chatml_messages.append(ChatMessage(content=message_attached_files, role="user")) + chatml_messages.append(ChatMessage(content=message_attached_files, role=role)) if not is_none_or_empty(chat.get("onlineContext")): message_context += f"{prompts.online_search_conversation.format(online_results=chat.get('onlineContext'))}" @@ -411,9 +428,18 @@ def generate_chatml_messages_with_context( reconstructed_context_message = ChatMessage(content=message_context, role="user") chatml_messages.insert(0, reconstructed_context_message) - role = "user" if chat["by"] == "you" else "assistant" + if chat.get("images") and role == "assistant": + # Issue: the assistant role cannot accept an image as a message content, so send it in a separate user message. 
+ file_attachment_message = construct_structured_message( + message=prompts.generated_image_attachment.format(), + images=chat.get("images"), + model_type=model_type, + vision_enabled=vision_enabled, + ) + chatml_messages.append(ChatMessage(content=file_attachment_message, role="user")) + message_content = construct_structured_message( - chat_message, chat.get("images"), model_type, vision_enabled, attached_file_context=query_files + chat_message, chat.get("images") if role == "user" else [], model_type, vision_enabled ) reconstructed_message = ChatMessage(content=message_content, role=role) @@ -423,6 +449,7 @@ def generate_chatml_messages_with_context( break messages = [] + if not is_none_or_empty(user_message): messages.append( ChatMessage( @@ -435,6 +462,31 @@ def generate_chatml_messages_with_context( if not is_none_or_empty(context_message): messages.append(ChatMessage(content=context_message, role="user")) + if generated_images: + messages.append( + ChatMessage( + content=construct_structured_message( + prompts.generated_image_attachment.format(), generated_images, model_type, vision_enabled + ), + role="user", + ) + ) + + if generated_files: + message_attached_files = gather_raw_query_files({file.name: file.content for file in generated_files}) + messages.append(ChatMessage(content=message_attached_files, role="assistant")) + + if generated_excalidraw_diagram: + messages.append(ChatMessage(content=prompts.generated_diagram_attachment.format(), role="assistant")) + + if additional_program_context: + messages.append( + ChatMessage( + content=prompts.additional_program_context.format(context="\n".join(additional_program_context)), + role="assistant", + ) + ) + if len(chatml_messages) > 0: messages += chatml_messages diff --git a/src/khoj/processor/image/generate.py b/src/khoj/processor/image/generate.py index 6c1f71b6..1bec7f41 100644 --- a/src/khoj/processor/image/generate.py +++ b/src/khoj/processor/image/generate.py @@ -12,7 +12,7 @@ from khoj.database.models import Agent, KhojUser, TextToImageModelConfig from khoj.routers.helpers import ChatEvent, generate_better_image_prompt from khoj.routers.storage import upload_image from khoj.utils import state -from khoj.utils.helpers import ImageIntentType, convert_image_to_webp, timer +from khoj.utils.helpers import convert_image_to_webp, timer from khoj.utils.rawconfig import LocationData logger = logging.getLogger(__name__) @@ -34,14 +34,13 @@ async def text_to_image( status_code = 200 image = None image_url = None - intent_type = ImageIntentType.TEXT_TO_IMAGE_V3 text_to_image_config = await ConversationAdapters.aget_user_text_to_image_model(user) if not text_to_image_config: # If the user has not configured a text to image model, return an unsupported on server error status_code = 501 message = "Failed to generate image. Setup image generation on the server." 
- yield image_url or image, status_code, message, intent_type.value + yield image_url or image, status_code, message return text2image_model = text_to_image_config.model_name @@ -53,6 +52,9 @@ async def text_to_image( elif chat["by"] == "khoj" and "text-to-image" in chat["intent"].get("type"): chat_history += f"Q: Prompt: {chat['intent']['query']}\n" chat_history += f"A: Improved Prompt: {chat['intent']['inferred-queries'][0]}\n" + elif chat["by"] == "khoj" and chat.get("images"): + chat_history += f"Q: {chat['intent']['query']}\n" + chat_history += f"A: Improved Prompt: {chat['intent']['inferred-queries'][0]}\n" if send_status_func: async for event in send_status_func("**Enhancing the Painting Prompt**"): @@ -92,31 +94,29 @@ async def text_to_image( logger.error(f"Image Generation blocked by OpenAI: {e}") status_code = e.status_code # type: ignore message = f"Image generation blocked by OpenAI due to policy violation" # type: ignore - yield image_url or image, status_code, message, intent_type.value + yield image_url or image, status_code, message return else: logger.error(f"Image Generation failed with {e}", exc_info=True) message = f"Image generation failed using OpenAI" # type: ignore status_code = e.status_code # type: ignore - yield image_url or image, status_code, message, intent_type.value + yield image_url or image, status_code, message return except requests.RequestException as e: logger.error(f"Image Generation failed with {e}", exc_info=True) message = f"Image generation using {text2image_model} via {text_to_image_config.model_type} failed due to a network error." status_code = 502 - yield image_url or image, status_code, message, intent_type.value + yield image_url or image, status_code, message return # Decide how to store the generated image with timer("Upload image to S3", logger): image_url = upload_image(webp_image_bytes, user.uuid) - if image_url: - intent_type = ImageIntentType.TEXT_TO_IMAGE2 - else: - intent_type = ImageIntentType.TEXT_TO_IMAGE_V3 + + if not image_url: image = base64.b64encode(webp_image_bytes).decode("utf-8") - yield image_url or image, status_code, image_prompt, intent_type.value + yield image_url or image, status_code, image_prompt def generate_image_with_openai( diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index d0833eec..7d2dc3a6 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -77,6 +77,7 @@ from khoj.utils.helpers import ( ) from khoj.utils.rawconfig import ( ChatRequestBody, + FileAttachment, FileFilterRequest, FilesFilterRequest, LocationData, @@ -771,6 +772,11 @@ async def chat( file_filters = conversation.file_filters if conversation and conversation.file_filters else [] attached_file_context = gather_raw_query_files(query_files) + generated_images: List[str] = [] + generated_files: List[FileAttachment] = [] + generated_excalidraw_diagram: str = None + additional_context_for_llm_response: List[str] = [] + if conversation_commands == [ConversationCommand.Default] or is_automated_task: chosen_io = await aget_data_sources_and_output_format( q, @@ -876,21 +882,17 @@ async def chat( async for result in send_llm_response(response, tracer.get("usage")): yield result - await sync_to_async(save_to_conversation_log)( - q, - response_log, - user, - meta_log, - user_message_time, - intent_type="summarize", - client_application=request.user.client_app, - conversation_id=conversation_id, - query_images=uploaded_images, - train_of_thought=train_of_thought, - raw_query_files=raw_query_files, - 
tracer=tracer, + summarized_document = FileAttachment( + name="Summarized Document", + content=response_log, + type="text/plain", + size=len(response_log.encode("utf-8")), ) - return + + async for result in send_event(ChatEvent.GENERATED_ASSETS, {"files": [summarized_document.model_dump()]}): + yield result + + generated_files.append(summarized_document) custom_filters = [] if conversation_commands == [ConversationCommand.Help]: @@ -1079,6 +1081,7 @@ async def chat( async for result in send_event(ChatEvent.STATUS, f"**Ran code snippets**: {len(code_results)}"): yield result except ValueError as e: + additional_context_for_llm_response.append(f"Failed to run code") logger.warning( f"Failed to use code tool: {e}. Attempting to respond without code results", exc_info=True, @@ -1116,51 +1119,36 @@ async def chat( if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] else: - generated_image, status_code, improved_image_prompt, intent_type = result + generated_image, status_code, improved_image_prompt = result + inferred_queries.append(improved_image_prompt) if generated_image is None or status_code != 200: - content_obj = { - "content-type": "application/json", - "intentType": intent_type, - "detail": improved_image_prompt, - "image": None, - } - async for result in send_llm_response(json.dumps(content_obj), tracer.get("usage")): + additional_context_for_llm_response.append(f"Failed to generate image with {improved_image_prompt}") + async for result in send_event(ChatEvent.STATUS, f"Failed to generate image"): yield result - return + else: + generated_images.append(generated_image) + # content_obj = { + # "intentType": intent_type, + # "inferredQueries": [improved_image_prompt], + # "image": generated_image, + # } + # async for result in send_llm_response(json.dumps(content_obj), tracer.get("usage")): + # yield result + # return - await sync_to_async(save_to_conversation_log)( - q, - generated_image, - user, - meta_log, - user_message_time, - intent_type=intent_type, - inferred_queries=[improved_image_prompt], - client_application=request.user.client_app, - conversation_id=conversation_id, - compiled_references=compiled_references, - online_results=online_results, - code_results=code_results, - query_images=uploaded_images, - train_of_thought=train_of_thought, - raw_query_files=raw_query_files, - tracer=tracer, - ) - content_obj = { - "intentType": intent_type, - "inferredQueries": [improved_image_prompt], - "image": generated_image, - } - async for result in send_llm_response(json.dumps(content_obj), tracer.get("usage")): - yield result - return + async for result in send_event( + ChatEvent.GENERATED_ASSETS, + { + "images": [generated_image], + }, + ): + yield result if ConversationCommand.Diagram in conversation_commands: async for result in send_event(ChatEvent.STATUS, f"Creating diagram"): yield result - intent_type = "excalidraw" inferred_queries = [] diagram_description = "" @@ -1184,62 +1172,59 @@ async def chat( if better_diagram_description_prompt and excalidraw_diagram_description: inferred_queries.append(better_diagram_description_prompt) diagram_description = excalidraw_diagram_description + + generated_excalidraw_diagram = diagram_description + + async for result in send_event( + ChatEvent.GENERATED_ASSETS, + { + "excalidrawDiagram": excalidraw_diagram_description, + }, + ): + yield result else: error_message = "Failed to generate diagram. Please try again later." 
- async for result in send_llm_response(error_message, tracer.get("usage")): + additional_context_for_llm_response.append( + f"AI attempted to programmatically generate a diagram but failed due to a program issue. Generally, it is able to do so, but encountered a system issue this time. AI can suggest text description or rendering of the diagram or user can try again with a simpler prompt." + ) + + async for result in send_event(ChatEvent.STATUS, error_message): yield result - await sync_to_async(save_to_conversation_log)( - q, - error_message, - user, - meta_log, - user_message_time, - inferred_queries=[better_diagram_description_prompt], - client_application=request.user.client_app, - conversation_id=conversation_id, - compiled_references=compiled_references, - online_results=online_results, - code_results=code_results, - query_images=uploaded_images, - train_of_thought=train_of_thought, - raw_query_files=raw_query_files, - tracer=tracer, - ) - return + # content_obj = { + # "intentType": intent_type, + # "inferredQueries": inferred_queries, + # "image": diagram_description, + # } - content_obj = { - "intentType": intent_type, - "inferredQueries": inferred_queries, - "image": diagram_description, - } + # await sync_to_async(save_to_conversation_log)( + # q, + # excalidraw_diagram_description, + # user, + # meta_log, + # user_message_time, + # intent_type="excalidraw", + # inferred_queries=[better_diagram_description_prompt], + # client_application=request.user.client_app, + # conversation_id=conversation_id, + # compiled_references=compiled_references, + # online_results=online_results, + # code_results=code_results, + # query_images=uploaded_images, + # train_of_thought=train_of_thought, + # raw_query_files=raw_query_files, + # generated_images=generated_images, + # tracer=tracer, + # ) - await sync_to_async(save_to_conversation_log)( - q, - excalidraw_diagram_description, - user, - meta_log, - user_message_time, - intent_type="excalidraw", - inferred_queries=[better_diagram_description_prompt], - client_application=request.user.client_app, - conversation_id=conversation_id, - compiled_references=compiled_references, - online_results=online_results, - code_results=code_results, - query_images=uploaded_images, - train_of_thought=train_of_thought, - raw_query_files=raw_query_files, - tracer=tracer, - ) - - async for result in send_llm_response(json.dumps(content_obj), tracer.get("usage")): - yield result - return + # async for result in send_llm_response(json.dumps(content_obj), tracer.get("usage")): + # yield result + # return ## Generate Text Output async for result in send_event(ChatEvent.STATUS, f"**Generating a well-informed response**"): yield result + llm_response, chat_metadata = await agenerate_chat_response( defiltered_query, meta_log, @@ -1259,6 +1244,10 @@ async def chat( train_of_thought, attached_file_context, raw_query_files, + generated_images, + generated_files, + generated_excalidraw_diagram, + additional_context_for_llm_response, tracer, ) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index e160b8e3..30deca0c 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -1184,6 +1184,10 @@ def generate_chat_response( train_of_thought: List[Any] = [], query_files: str = None, raw_query_files: List[FileAttachment] = None, + generated_images: List[str] = None, + raw_generated_files: List[FileAttachment] = [], + generated_excalidraw_diagram: str = None, + additional_context: List[str] = [], tracer: dict = {}, ) -> 
Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]: # Initialize Variables @@ -1207,6 +1211,9 @@ def generate_chat_response( query_images=query_images, train_of_thought=train_of_thought, raw_query_files=raw_query_files, + generated_images=generated_images, + raw_generated_files=raw_generated_files, + generated_excalidraw_diagram=generated_excalidraw_diagram, tracer=tracer, ) @@ -1242,6 +1249,7 @@ def generate_chat_response( user_name=user_name, agent=agent, query_files=query_files, + generated_files=raw_generated_files, tracer=tracer, ) @@ -1268,6 +1276,10 @@ def generate_chat_response( agent=agent, vision_available=vision_available, query_files=query_files, + generated_files=raw_generated_files, + generated_images=generated_images, + generated_excalidraw_diagram=generated_excalidraw_diagram, + additional_context=additional_context, tracer=tracer, ) @@ -1291,6 +1303,10 @@ def generate_chat_response( agent=agent, vision_available=vision_available, query_files=query_files, + generated_files=raw_generated_files, + generated_images=generated_images, + generated_excalidraw_diagram=generated_excalidraw_diagram, + additional_context=additional_context, tracer=tracer, ) elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE: @@ -1313,6 +1329,10 @@ def generate_chat_response( query_images=query_images, vision_available=vision_available, query_files=query_files, + generated_files=raw_generated_files, + generated_images=generated_images, + generated_excalidraw_diagram=generated_excalidraw_diagram, + additional_context=additional_context, tracer=tracer, ) @@ -1784,6 +1804,9 @@ class MessageProcessor: self.references = {} self.usage = {} self.raw_response = "" + self.generated_images = [] + self.generated_files = [] + self.generated_excalidraw_diagrams = [] def convert_message_chunk_to_json(self, raw_chunk: str) -> Dict[str, Any]: if raw_chunk.startswith("{") and raw_chunk.endswith("}"): @@ -1822,6 +1845,16 @@ class MessageProcessor: self.raw_response += chunk_data else: self.raw_response += chunk_data + elif chunk_type == ChatEvent.GENERATED_ASSETS: + chunk_data = chunk["data"] + if isinstance(chunk_data, dict): + for key in chunk_data: + if key == "images": + self.generated_images = chunk_data[key] + elif key == "files": + self.generated_files = chunk_data[key] + elif key == "excalidraw_diagrams": + self.generated_excalidraw_diagrams = chunk_data[key] def handle_json_response(self, json_data: Dict[str, str]) -> str | Dict[str, str]: if "image" in json_data or "details" in json_data: @@ -1852,7 +1885,14 @@ async def read_chat_stream(response_iterator: AsyncGenerator[str, None]) -> Dict if buffer: processor.process_message_chunk(buffer) - return {"response": processor.raw_response, "references": processor.references, "usage": processor.usage} + return { + "response": processor.raw_response, + "references": processor.references, + "usage": processor.usage, + "images": processor.generated_images, + "files": processor.generated_files, + "excalidraw_diagrams": processor.generated_excalidraw_diagrams, + } def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False): diff --git a/tests/test_offline_chat_actors.py b/tests/test_offline_chat_actors.py index b404f0e8..e84612a2 100644 --- a/tests/test_offline_chat_actors.py +++ b/tests/test_offline_chat_actors.py @@ -22,7 +22,6 @@ from khoj.processor.conversation.offline.chat_model import ( filter_questions, ) from khoj.processor.conversation.offline.utils import download_model -from 
khoj.processor.conversation.utils import message_to_log
 from khoj.utils.constants import default_offline_chat_models

diff --git a/tests/test_online_chat_director.py b/tests/test_online_chat_director.py
index 94545b4c..ea3a6e1c 100644
--- a/tests/test_online_chat_director.py
+++ b/tests/test_online_chat_director.py
@@ -6,7 +6,6 @@ from freezegun import freeze_time

 from khoj.database.models import Agent, Entry, KhojUser
 from khoj.processor.conversation import prompts
-from khoj.processor.conversation.utils import message_to_log
 from tests.helpers import ConversationFactory, generate_chat_history, get_chat_api_key

 # Initialize variables for tests

From 6f408948d3ccf8fc6134d5a5f8b6427e69276163 Mon Sep 17 00:00:00 2001
From: sabaimran
Date: Thu, 28 Nov 2024 20:15:10 -0800
Subject: [PATCH 02/25] Fix typing of generated_files parameters

---
 src/khoj/processor/conversation/anthropic/anthropic_chat.py | 4 ++--
 src/khoj/processor/conversation/google/gemini_chat.py       | 4 ++--
 src/khoj/processor/conversation/offline/chat_model.py       | 4 ++--
 src/khoj/processor/conversation/openai/gpt.py               | 4 ++--
 4 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py
index a4dc22ed..b71c7684 100644
--- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py
+++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py
@@ -24,7 +24,7 @@ from khoj.utils.helpers import (
     is_none_or_empty,
     truncate_code_context,
 )
-from khoj.utils.rawconfig import LocationData
+from khoj.utils.rawconfig import FileAttachment, LocationData
 from khoj.utils.yaml import yaml_dump

 logger = logging.getLogger(__name__)
@@ -159,7 +159,7 @@ def converse_anthropic(
     vision_available: bool = False,
     query_files: str = None,
     generated_images: Optional[list[str]] = None,
-    generated_files: List[str] = None,
+    generated_files: List[FileAttachment] = None,
     generated_excalidraw_diagram: Optional[str] = None,
     additional_context: Optional[str] = None,
     tracer: dict = {},
diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py
index 221f44d2..4f211d23 100644
--- a/src/khoj/processor/conversation/google/gemini_chat.py
+++ b/src/khoj/processor/conversation/google/gemini_chat.py
@@ -24,7 +24,7 @@ from khoj.utils.helpers import (
     is_none_or_empty,
     truncate_code_context,
 )
-from khoj.utils.rawconfig import LocationData
+from khoj.utils.rawconfig import FileAttachment, LocationData
 from khoj.utils.yaml import yaml_dump

 logger = logging.getLogger(__name__)
@@ -169,7 +169,7 @@ def converse_gemini(
     vision_available: bool = False,
     query_files: str = None,
     generated_images: Optional[list[str]] = None,
-    generated_files: List[str] = None,
+    generated_files: List[FileAttachment] = None,
     generated_excalidraw_diagram: Optional[str] = None,
     additional_context: List[str] = None,
     tracer={},
diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py
index f1746289..bbe54b2c 100644
--- a/src/khoj/processor/conversation/offline/chat_model.py
+++ b/src/khoj/processor/conversation/offline/chat_model.py
@@ -26,7 +26,7 @@ from khoj.utils.helpers import (
     is_promptrace_enabled,
     truncate_code_context,
 )
-from khoj.utils.rawconfig import LocationData
+from khoj.utils.rawconfig import FileAttachment, LocationData
 from khoj.utils.yaml import yaml_dump

 logger = logging.getLogger(__name__)
@@ -162,7 +162,7 @@ def converse_offline(
     user_name: str = None,
     agent: Agent = None,
     query_files: str = None,
-    generated_files: List[str] = None,
+    generated_files: List[FileAttachment] = None,
     additional_context: List[str] = None,
     tracer: dict = {},
 ) -> Union[ThreadedGenerator, Iterator[str]]:
diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py
index ef3c3952..fcb9229a 100644
--- a/src/khoj/processor/conversation/openai/gpt.py
+++ b/src/khoj/processor/conversation/openai/gpt.py
@@ -22,7 +22,7 @@ from khoj.utils.helpers import (
     is_none_or_empty,
     truncate_code_context,
 )
-from khoj.utils.rawconfig import LocationData
+from khoj.utils.rawconfig import FileAttachment, LocationData
 from khoj.utils.yaml import yaml_dump

 logger = logging.getLogger(__name__)
@@ -158,7 +158,7 @@ def converse(
     vision_available: bool = False,
     query_files: str = None,
     generated_images: Optional[list[str]] = None,
-    generated_files: List[str] = None,
+    generated_files: List[FileAttachment] = None,
     generated_excalidraw_diagram: Optional[str] = None,
     additional_context: List[str] = None,
     tracer: dict = {},

From 4f6d1211ba8849b9dc26be39f7f3ea44e7b1b1f1 Mon Sep 17 00:00:00 2001
From: sabaimran
Date: Thu, 28 Nov 2024 20:16:36 -0800
Subject: [PATCH 03/25] Fix additional context type in anthropic chat

---
 src/khoj/processor/conversation/anthropic/anthropic_chat.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py
index b71c7684..e83d604c 100644
--- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py
+++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py
@@ -161,7 +161,7 @@ def converse_anthropic(
     generated_images: Optional[list[str]] = None,
     generated_files: List[FileAttachment] = None,
     generated_excalidraw_diagram: Optional[str] = None,
-    additional_context: Optional[str] = None,
+    additional_context: Optional[List[str]] = None,
     tracer: dict = {},
 ):
     """

From 46f647d91da112fd4158f962cbc34b9ed03128ba Mon Sep 17 00:00:00 2001
From: sabaimran
Date: Fri, 29 Nov 2024 14:11:48 -0800
Subject: [PATCH 04/25] Improve image rendering for khoj generated images. Fix
 typing of stored excalidraw image.
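
Why the typing fix in utils.py matters: when no diagram was generated,
str(generated_excalidraw_diagram) stringified None into the literal text
"None", so the khoj message metadata saved by save_to_conversation_log
carried a truthy excalidrawDiagram value and clients checking truthiness
would treat it as a real diagram. A minimal sketch of the failure mode
(illustrative only; "metadata" stands in for the metadata dict built in
save_to_conversation_log):

    generated_excalidraw_diagram = None

    # Before: unconditional stringification stores the text "None"
    metadata = {"excalidrawDiagram": str(generated_excalidraw_diagram)}
    assert metadata["excalidrawDiagram"] == "None"  # truthy, looks like a diagram

    # After: stringify only when a diagram actually exists
    metadata = {
        "excalidrawDiagram": str(generated_excalidraw_diagram)
        if generated_excalidraw_diagram
        else None
    }
    assert metadata["excalidrawDiagram"] is None  # falsy, correctly omitted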
--- .../chatMessage/chatMessage.module.css | 20 +++++++++++++++++++ src/khoj/processor/conversation/utils.py | 2 +- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/src/interface/web/app/components/chatMessage/chatMessage.module.css b/src/interface/web/app/components/chatMessage/chatMessage.module.css index b055d0a5..2abfdda7 100644 --- a/src/interface/web/app/components/chatMessage/chatMessage.module.css +++ b/src/interface/web/app/components/chatMessage/chatMessage.module.css @@ -77,6 +77,21 @@ div.imageWrapper img { border-radius: 8px; } +div.khoj div.imageWrapper img { + height: 512px; +} + +div.khoj div.imageWrapper { + flex: 1 1 auto; +} + +div.khoj div.imagesContainer { + display: flex; + flex-wrap: wrap; + flex-direction: row; + overflow-x: hidden; +} + div.chatMessageContainer > img { width: auto; height: auto; @@ -178,4 +193,9 @@ div.trainOfThoughtElement ul { div.youfullHistory { max-width: 90%; } + + div.khoj div.imageWrapper img { + width: 100%; + height: auto; + } } diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 57fc0f12..798522e6 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -293,7 +293,7 @@ def save_to_conversation_log( "turnId": turn_id, "images": generated_images, "queryFiles": [file.model_dump(mode="json") for file in raw_generated_files], - "excalidrawDiagram": str(generated_excalidraw_diagram), + "excalidrawDiagram": str(generated_excalidraw_diagram) if generated_excalidraw_diagram else None, }, conversation_log=meta_log.get("chat", []), ) From a0b00ce4a13c2f27141ca5c41d1f3034b69ce265 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Fri, 29 Nov 2024 18:10:14 -0800 Subject: [PATCH 05/25] Don't include null attributes when filling in stored conversation metadata - Prompt adjustments to indicate to LLM what context it has --- src/khoj/database/models/__init__.py | 2 +- src/khoj/processor/conversation/prompts.py | 2 +- src/khoj/processor/conversation/utils.py | 41 ++++++++++++---------- 3 files changed, 25 insertions(+), 20 deletions(-) diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index ae6a729d..88b15404 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -108,7 +108,7 @@ class ChatMessage(PydanticBaseModel): created: str images: Optional[List[str]] = None queryFiles: Optional[List[Dict]] = None - excalidrawDiagram: Optional[str] = None + excalidrawDiagram: Optional[List[Dict]] = None by: str turnId: Optional[str] intent: Optional[Intent] = None diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index a943cbd0..5e28a912 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -186,7 +186,7 @@ Here is the image you generated based on my query. You can follow-up with a gene generated_diagram_attachment = PromptTemplate.from_template( f""" -The AI has successfully created a diagram based on the user's query and handled the request. Good job! +The AI has successfully created a diagram based on the user's query and handled the request. Good job! This will be shared with the user. AI can follow-up with a general response or summary. Limit to 1-2 sentences. 
""".strip() diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 4d238b3b..86c2d87e 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -276,27 +276,32 @@ def save_to_conversation_log( ): user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S") turn_id = tracer.get("mid") or str(uuid.uuid4()) + + user_message_metadata = {"created": user_message_time, "images": query_images, "turnId": turn_id} + + if raw_query_files and len(raw_query_files) > 0: + user_message_metadata["queryFiles"] = [file.model_dump(mode="json") for file in raw_query_files] + + khoj_message_metadata = { + "context": compiled_references, + "intent": {"inferred-queries": inferred_queries, "type": intent_type}, + "onlineContext": online_results, + "codeContext": code_results, + "automationId": automation_id, + "trainOfThought": train_of_thought, + "turnId": turn_id, + "images": generated_images, + "queryFiles": [file.model_dump(mode="json") for file in raw_generated_files], + } + + if generated_excalidraw_diagram: + khoj_message_metadata["excalidrawDiagram"] = generated_excalidraw_diagram + updated_conversation = message_to_log( user_message=q, chat_response=chat_response, - user_message_metadata={ - "created": user_message_time, - "images": query_images, - "turnId": turn_id, - "queryFiles": [file.model_dump(mode="json") for file in raw_query_files], - }, - khoj_message_metadata={ - "context": compiled_references, - "intent": {"inferred-queries": inferred_queries, "type": intent_type}, - "onlineContext": online_results, - "codeContext": code_results, - "automationId": automation_id, - "trainOfThought": train_of_thought, - "turnId": turn_id, - "images": generated_images, - "queryFiles": [file.model_dump(mode="json") for file in raw_generated_files], - "excalidrawDiagram": str(generated_excalidraw_diagram) if generated_excalidraw_diagram else None, - }, + user_message_metadata=user_message_metadata, + khoj_message_metadata=khoj_message_metadata, conversation_log=meta_log.get("chat", []), ) ConversationAdapters.save_conversation( From 512cf535e030ca60e683a427b4e02863764d4d71 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Fri, 29 Nov 2024 18:10:35 -0800 Subject: [PATCH 06/25] Collapse train of thought when completed during live stream --- .../web/app/components/chatHistory/chatHistory.tsx | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/interface/web/app/components/chatHistory/chatHistory.tsx b/src/interface/web/app/components/chatHistory/chatHistory.tsx index c5fd8b54..5a750e8b 100644 --- a/src/interface/web/app/components/chatHistory/chatHistory.tsx +++ b/src/interface/web/app/components/chatHistory/chatHistory.tsx @@ -54,11 +54,11 @@ function TrainOfThoughtComponent(props: TrainOfThoughtComponentProps) { const lastIndex = props.trainOfThought.length - 1; const [collapsed, setCollapsed] = useState(props.completed); - // useEffect(() => { - // if (props.completed) { - // setCollapsed(true); - // } - // }), [props.completed]; + useEffect(() => { + if (props.completed) { + setCollapsed(true); + } + }, [props.completed]); return (
Date: Fri, 29 Nov 2024 18:10:47 -0800 Subject: [PATCH 07/25] Update response handling in Obsidian to work with new format --- src/interface/obsidian/src/chat_view.ts | 154 ++++++++++++++---------- src/interface/obsidian/styles.css | 6 +- 2 files changed, 97 insertions(+), 63 deletions(-) diff --git a/src/interface/obsidian/src/chat_view.ts b/src/interface/obsidian/src/chat_view.ts index 552a54bd..9977a06f 100644 --- a/src/interface/obsidian/src/chat_view.ts +++ b/src/interface/obsidian/src/chat_view.ts @@ -1,4 +1,4 @@ -import {ItemView, MarkdownRenderer, Scope, WorkspaceLeaf, request, requestUrl, setIcon, Platform} from 'obsidian'; +import { ItemView, MarkdownRenderer, Scope, WorkspaceLeaf, request, requestUrl, setIcon, Platform } from 'obsidian'; import * as DOMPurify from 'dompurify'; import { KhojSetting } from 'src/settings'; import { KhojPaneView } from 'src/pane_view'; @@ -46,10 +46,10 @@ export class KhojChatView extends KhojPaneView { waitingForLocation: boolean; location: Location = { timezone: Intl.DateTimeFormat().resolvedOptions().timeZone }; keyPressTimeout: NodeJS.Timeout | null = null; - userMessages: string[] = []; // Store user sent messages for input history cycling - currentMessageIndex: number = -1; // Track current message index in userMessages array - private currentUserInput: string = ""; // Stores the current user input that is being typed in chat - private startingMessage: string = "Message"; + userMessages: string[] = []; // Store user sent messages for input history cycling + currentMessageIndex: number = -1; // Track current message index in userMessages array + private currentUserInput: string = ""; // Stores the current user input that is being typed in chat + private startingMessage: string = "Message"; chatMessageState: ChatMessageState; constructor(leaf: WorkspaceLeaf, setting: KhojSetting) { @@ -102,14 +102,14 @@ export class KhojChatView extends KhojPaneView { // Clear text after extracting message to send let user_message = input_el.value.trim(); - // Store the message in the array if it's not empty - if (user_message) { - this.userMessages.push(user_message); - // Update starting message after sending a new message - const modifierKey = Platform.isMacOS ? '⌘' : '^'; - this.startingMessage = `(${modifierKey}+↑/↓) for prev messages`; - input_el.placeholder = this.startingMessage; - } + // Store the message in the array if it's not empty + if (user_message) { + this.userMessages.push(user_message); + // Update starting message after sending a new message + const modifierKey = Platform.isMacOS ? '⌘' : '^'; + this.startingMessage = `(${modifierKey}+↑/↓) for prev messages`; + input_el.placeholder = this.startingMessage; + } input_el.value = ""; this.autoResize(); @@ -162,9 +162,9 @@ export class KhojChatView extends KhojPaneView { }) chatInput.addEventListener('input', (_) => { this.onChatInput() }); chatInput.addEventListener('keydown', (event) => { - this.incrementalChat(event); - this.handleArrowKeys(event); - }); + this.incrementalChat(event); + this.handleArrowKeys(event); + }); // Add event listeners for long press keybinding this.contentEl.addEventListener('keydown', this.handleKeyDown.bind(this)); @@ -199,7 +199,7 @@ export class KhojChatView extends KhojPaneView { // Get chat history from Khoj backend and set chat input state let getChatHistorySucessfully = await this.getChatHistory(chatBodyEl); - let placeholderText : string = getChatHistorySucessfully ? 
this.startingMessage : "Configure Khoj to enable chat"; + let placeholderText: string = getChatHistorySucessfully ? this.startingMessage : "Configure Khoj to enable chat"; chatInput.placeholder = placeholderText; chatInput.disabled = !getChatHistorySucessfully; @@ -214,7 +214,7 @@ export class KhojChatView extends KhojPaneView { }); } - startSpeechToText(event: KeyboardEvent | MouseEvent | TouchEvent, timeout=200) { + startSpeechToText(event: KeyboardEvent | MouseEvent | TouchEvent, timeout = 200) { if (!this.keyPressTimeout) { this.keyPressTimeout = setTimeout(async () => { // Reset auto send voice message timer, UI if running @@ -320,7 +320,7 @@ export class KhojChatView extends KhojPaneView { referenceButton.tabIndex = 0; // Add event listener to toggle full reference on click - referenceButton.addEventListener('click', function() { + referenceButton.addEventListener('click', function () { if (this.classList.contains("collapsed")) { this.classList.remove("collapsed"); this.classList.add("expanded"); @@ -375,7 +375,7 @@ export class KhojChatView extends KhojPaneView { referenceButton.tabIndex = 0; // Add event listener to toggle full reference on click - referenceButton.addEventListener('click', function() { + referenceButton.addEventListener('click', function () { if (this.classList.contains("collapsed")) { this.classList.remove("collapsed"); this.classList.add("expanded"); @@ -420,23 +420,23 @@ export class KhojChatView extends KhojPaneView { "Authorization": `Bearer ${this.setting.khojApiKey}`, }, }) - .then(response => response.arrayBuffer()) - .then(arrayBuffer => context.decodeAudioData(arrayBuffer)) - .then(audioBuffer => { - const source = context.createBufferSource(); - source.buffer = audioBuffer; - source.connect(context.destination); - source.start(0); - source.onended = function() { + .then(response => response.arrayBuffer()) + .then(arrayBuffer => context.decodeAudioData(arrayBuffer)) + .then(audioBuffer => { + const source = context.createBufferSource(); + source.buffer = audioBuffer; + source.connect(context.destination); + source.start(0); + source.onended = function () { + speechButton.removeChild(loader); + speechButton.disabled = false; + }; + }) + .catch(err => { + console.error("Error playing speech:", err); speechButton.removeChild(loader); - speechButton.disabled = false; - }; - }) - .catch(err => { - console.error("Error playing speech:", err); - speechButton.removeChild(loader); - speechButton.disabled = false; // Consider enabling the button again to allow retrying - }); + speechButton.disabled = false; // Consider enabling the button again to allow retrying + }); } formatHTMLMessage(message: string, raw = false, willReplace = true) { @@ -485,12 +485,18 @@ export class KhojChatView extends KhojPaneView { intentType?: string, inferredQueries?: string[], conversationId?: string, + images?: string[], + excalidrawDiagram?: string ) { if (!message) return; let chatMessageEl; - if (intentType?.includes("text-to-image") || intentType === "excalidraw") { - let imageMarkdown = this.generateImageMarkdown(message, intentType, inferredQueries, conversationId); + if ( + intentType?.includes("text-to-image") || + intentType === "excalidraw" || + (images && images.length > 0) || + excalidrawDiagram) { + let imageMarkdown = this.generateImageMarkdown(message, intentType ?? 
"", inferredQueries, conversationId, images, excalidrawDiagram); chatMessageEl = this.renderMessage(chatEl, imageMarkdown, sender, dt); } else { chatMessageEl = this.renderMessage(chatEl, message, sender, dt); @@ -510,7 +516,7 @@ export class KhojChatView extends KhojPaneView { chatMessageBodyEl.appendChild(this.createReferenceSection(references)); } - generateImageMarkdown(message: string, intentType: string, inferredQueries?: string[], conversationId?: string): string { + generateImageMarkdown(message: string, intentType: string, inferredQueries?: string[], conversationId?: string, images?: string[], excalidrawDiagram?: string): string { let imageMarkdown = ""; if (intentType === "text-to-image") { imageMarkdown = `![](data:image/png;base64,${message})`; @@ -518,11 +524,20 @@ export class KhojChatView extends KhojPaneView { imageMarkdown = `![](${message})`; } else if (intentType === "text-to-image-v3") { imageMarkdown = `![](data:image/webp;base64,${message})`; - } else if (intentType === "excalidraw") { + } else if (intentType === "excalidraw" || excalidrawDiagram) { const domain = this.setting.khojUrl.endsWith("/") ? this.setting.khojUrl : `${this.setting.khojUrl}/`; const redirectMessage = `Hey, I'm not ready to show you diagrams yet here. But you can view it in ${domain}chat?conversationId=${conversationId}`; imageMarkdown = redirectMessage; + } else if (images && images.length > 0) { + for (let image of images) { + if (image.startsWith("https://")) { + imageMarkdown += `![](${image})\n\n`; + } else { + imageMarkdown += `![](data:image/png;base64,${image})\n\n`; + } + } } + if (inferredQueries) { imageMarkdown += "\n\n**Inferred Query**:"; for (let inferredQuery of inferredQueries) { @@ -650,19 +665,19 @@ export class KhojChatView extends KhojPaneView { chatBodyEl.innerHTML = ""; chatBodyEl.dataset.conversationId = ""; chatBodyEl.dataset.conversationTitle = ""; - this.userMessages = []; - this.startingMessage = "Message"; + this.userMessages = []; + this.startingMessage = "Message"; - // Update the placeholder of the chat input - const chatInput = this.contentEl.querySelector('.khoj-chat-input') as HTMLTextAreaElement; - if (chatInput) { - chatInput.placeholder = this.startingMessage; - } + // Update the placeholder of the chat input + const chatInput = this.contentEl.querySelector('.khoj-chat-input') as HTMLTextAreaElement; + if (chatInput) { + chatInput.placeholder = this.startingMessage; + } this.renderMessage(chatBodyEl, "Hey 👋🏾, what's up?", "khoj"); } async toggleChatSessions(forceShow: boolean = false): Promise { - this.userMessages = []; // clear user previous message history + this.userMessages = []; // clear user previous message history let chatBodyEl = this.contentEl.getElementsByClassName("khoj-chat-body")[0] as HTMLElement; if (!forceShow && this.contentEl.getElementsByClassName("side-panel")?.length > 0) { chatBodyEl.innerHTML = ""; @@ -768,10 +783,10 @@ export class KhojChatView extends KhojPaneView { let editConversationTitleInputEl = this.contentEl.createEl('input'); editConversationTitleInputEl.classList.add("conversation-title-input"); editConversationTitleInputEl.value = conversationTitle; - editConversationTitleInputEl.addEventListener('click', function(event) { + editConversationTitleInputEl.addEventListener('click', function (event) { event.stopPropagation(); }); - editConversationTitleInputEl.addEventListener('keydown', function(event) { + editConversationTitleInputEl.addEventListener('keydown', function (event) { if (event.key === "Enter") { 
event.preventDefault(); editConversationTitleSaveButtonEl.click(); @@ -890,15 +905,17 @@ export class KhojChatView extends KhojPaneView { chatLog.intent?.type, chatLog.intent?.["inferred-queries"], chatBodyEl.dataset.conversationId ?? "", + chatLog.images, + chatLog.excalidrawDiagram, ); // push the user messages to the chat history - if(chatLog.by === "you"){ + if (chatLog.by === "you") { this.userMessages.push(chatLog.message); } }); // Update starting message after loading history - const modifierKey: string = Platform.isMacOS ? '⌘' : '^'; + const modifierKey: string = Platform.isMacOS ? '⌘' : '^'; this.startingMessage = this.userMessages.length > 0 ? `(${modifierKey}+↑/↓) for prev messages` : "Message"; @@ -922,15 +939,15 @@ export class KhojChatView extends KhojPaneView { try { let jsonChunk = JSON.parse(rawChunk); if (!jsonChunk.type) - jsonChunk = {type: 'message', data: jsonChunk}; + jsonChunk = { type: 'message', data: jsonChunk }; return jsonChunk; } catch (e) { - return {type: 'message', data: rawChunk}; + return { type: 'message', data: rawChunk }; } } else if (rawChunk.length > 0) { - return {type: 'message', data: rawChunk}; + return { type: 'message', data: rawChunk }; } - return {type: '', data: ''}; + return { type: '', data: '' }; } processMessageChunk(rawChunk: string): void { @@ -965,7 +982,7 @@ export class KhojChatView extends KhojPaneView { isVoice: false, }; } else if (chunk.type === "references") { - this.chatMessageState.references = {"notes": chunk.data.context, "online": chunk.data.onlineContext}; + this.chatMessageState.references = { "notes": chunk.data.context, "online": chunk.data.onlineContext }; } else if (chunk.type === 'message') { const chunkData = chunk.data; if (typeof chunkData === 'object' && chunkData !== null) { @@ -988,7 +1005,7 @@ export class KhojChatView extends KhojPaneView { } handleJsonResponse(jsonData: any): void { - if (jsonData.image || jsonData.detail) { + if (jsonData.image || jsonData.detail || jsonData.images || jsonData.excalidrawDiagram) { this.chatMessageState.rawResponse = this.handleImageResponse(jsonData, this.chatMessageState.rawResponse); } else if (jsonData.response) { this.chatMessageState.rawResponse = jsonData.response; @@ -1234,11 +1251,11 @@ export class KhojChatView extends KhojPaneView { const recordingConfig = { mimeType: 'audio/webm' }; this.mediaRecorder = new MediaRecorder(stream, recordingConfig); - this.mediaRecorder.addEventListener("dataavailable", function(event) { + this.mediaRecorder.addEventListener("dataavailable", function (event) { if (event.data.size > 0) audioChunks.push(event.data); }); - this.mediaRecorder.addEventListener("stop", async function() { + this.mediaRecorder.addEventListener("stop", async function () { const audioBlob = new Blob(audioChunks, { type: 'audio/webm' }); await sendToServer(audioBlob); }); @@ -1368,7 +1385,22 @@ export class KhojChatView extends KhojPaneView { if (inferredQuery) { rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`; } + } else if (imageJson.images) { + // If response has images field, response is a list of generated images. + imageJson.images.forEach((image: any) => { + + if (image.startsWith("http")) { + rawResponse += `![generated_image](${image})\n\n`; + } else { + rawResponse += `![generated_image](data:image/png;base64,${image})\n\n`; + } + }); + } else if (imageJson.excalidrawDiagram) { + const domain = this.setting.khojUrl.endsWith("/") ? 
this.setting.khojUrl : `${this.setting.khojUrl}/`; + const redirectMessage = `Hey, I'm not ready to show you diagrams yet here. But you can view it in ${domain}`; + rawResponse += redirectMessage; } + // If response has detail field, response is an error message. if (imageJson.detail) rawResponse += imageJson.detail; @@ -1407,7 +1439,7 @@ export class KhojChatView extends KhojPaneView { referenceExpandButton.classList.add("reference-expand-button"); referenceExpandButton.innerHTML = numReferences == 1 ? "1 reference" : `${numReferences} references`; - referenceExpandButton.addEventListener('click', function() { + referenceExpandButton.addEventListener('click', function () { if (referenceSection.classList.contains("collapsed")) { referenceSection.classList.remove("collapsed"); referenceSection.classList.add("expanded"); diff --git a/src/interface/obsidian/styles.css b/src/interface/obsidian/styles.css index b02b2ff3..dea8c7f2 100644 --- a/src/interface/obsidian/styles.css +++ b/src/interface/obsidian/styles.css @@ -82,7 +82,8 @@ If your plugin does not need CSS, delete this file. } /* color chat bubble by khoj blue */ .khoj-chat-message-text.khoj { - border: 1px solid var(--khoj-sun); + border-top: 1px solid var(--khoj-sun); + border-radius: 0px; margin-left: auto; white-space: pre-line; } @@ -104,8 +105,9 @@ If your plugin does not need CSS, delete this file. } /* color chat bubble by you dark grey */ .khoj-chat-message-text.you { - border: 1px solid var(--color-accent); + color: var(--text-normal); margin-right: auto; + background-color: var(--background-modifier-cover); } /* add right protrusion to you chat bubble */ .khoj-chat-message-text.you:after { From 2b32f0e80d77608ad99625ae5cdb782bbfa14610 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Fri, 29 Nov 2024 18:11:50 -0800 Subject: [PATCH 08/25] Remove commented out code blocks --- src/khoj/routers/api_chat.py | 38 ------------------------------------ 1 file changed, 38 deletions(-) diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 7d2dc3a6..156658a7 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -1128,14 +1128,6 @@ async def chat( yield result else: generated_images.append(generated_image) - # content_obj = { - # "intentType": intent_type, - # "inferredQueries": [improved_image_prompt], - # "image": generated_image, - # } - # async for result in send_llm_response(json.dumps(content_obj), tracer.get("usage")): - # yield result - # return async for result in send_event( ChatEvent.GENERATED_ASSETS, @@ -1191,36 +1183,6 @@ async def chat( async for result in send_event(ChatEvent.STATUS, error_message): yield result - # content_obj = { - # "intentType": intent_type, - # "inferredQueries": inferred_queries, - # "image": diagram_description, - # } - - # await sync_to_async(save_to_conversation_log)( - # q, - # excalidraw_diagram_description, - # user, - # meta_log, - # user_message_time, - # intent_type="excalidraw", - # inferred_queries=[better_diagram_description_prompt], - # client_application=request.user.client_app, - # conversation_id=conversation_id, - # compiled_references=compiled_references, - # online_results=online_results, - # code_results=code_results, - # query_images=uploaded_images, - # train_of_thought=train_of_thought, - # raw_query_files=raw_query_files, - # generated_images=generated_images, - # tracer=tracer, - # ) - - # async for result in send_llm_response(json.dumps(content_obj), tracer.get("usage")): - # yield result - # return - ## Generate Text Output 
async for result in send_event(ChatEvent.STATUS, f"**Generating a well-informed response**"): yield result From e3aee50cf396a346ecb16fe9b64b97df19d0dcfe Mon Sep 17 00:00:00 2001 From: sabaimran Date: Fri, 29 Nov 2024 18:41:53 -0800 Subject: [PATCH 09/25] Fix parsing of generated_asset response --- src/interface/obsidian/src/chat_view.ts | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/interface/obsidian/src/chat_view.ts b/src/interface/obsidian/src/chat_view.ts index 9977a06f..993cfacf 100644 --- a/src/interface/obsidian/src/chat_view.ts +++ b/src/interface/obsidian/src/chat_view.ts @@ -958,6 +958,10 @@ export class KhojChatView extends KhojPaneView { console.log(`status: ${chunk.data}`); const statusMessage = chunk.data; this.handleStreamResponse(this.chatMessageState.newResponseTextEl, statusMessage, this.chatMessageState.loadingEllipsis, false); + } else if (chunk.type === 'generated_assets') { + const generatedAssets = chunk.data; + const imageData = this.handleImageResponse(generatedAssets, this.chatMessageState.rawResponse); + this.handleStreamResponse(this.chatMessageState.newResponseTextEl, imageData, this.chatMessageState.loadingEllipsis, false); } else if (chunk.type === 'start_llm_response') { console.log("Started streaming", new Date()); } else if (chunk.type === 'end_llm_response') { From dc4a9ee3e18cf269369b290df6759f1b2c492335 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Sat, 30 Nov 2024 12:31:20 -0800 Subject: [PATCH 10/25] Ensure that the generated assets are maintained in the chat window after streaming is completed. --- src/interface/obsidian/src/chat_view.ts | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/interface/obsidian/src/chat_view.ts b/src/interface/obsidian/src/chat_view.ts index 993cfacf..abb0de50 100644 --- a/src/interface/obsidian/src/chat_view.ts +++ b/src/interface/obsidian/src/chat_view.ts @@ -27,6 +27,7 @@ interface ChatMessageState { newResponseEl: HTMLElement | null; loadingEllipsis: HTMLElement | null; references: any; + generatedAssets: string; rawResponse: string; rawQuery: string; isVoice: boolean; @@ -961,6 +962,7 @@ export class KhojChatView extends KhojPaneView { } else if (chunk.type === 'generated_assets') { const generatedAssets = chunk.data; const imageData = this.handleImageResponse(generatedAssets, this.chatMessageState.rawResponse); + this.chatMessageState.generatedAssets = imageData; this.handleStreamResponse(this.chatMessageState.newResponseTextEl, imageData, this.chatMessageState.loadingEllipsis, false); } else if (chunk.type === 'start_llm_response') { console.log("Started streaming", new Date()); @@ -984,6 +986,7 @@ export class KhojChatView extends KhojPaneView { rawResponse: "", rawQuery: liveQuery, isVoice: false, + generatedAssets: "", }; } else if (chunk.type === "references") { this.chatMessageState.references = { "notes": chunk.data.context, "online": chunk.data.onlineContext }; @@ -999,11 +1002,11 @@ export class KhojChatView extends KhojPaneView { this.handleJsonResponse(jsonData); } catch (e) { this.chatMessageState.rawResponse += chunkData; - this.handleStreamResponse(this.chatMessageState.newResponseTextEl, this.chatMessageState.rawResponse, this.chatMessageState.loadingEllipsis); + this.handleStreamResponse(this.chatMessageState.newResponseTextEl, this.chatMessageState.rawResponse + this.chatMessageState.generatedAssets, this.chatMessageState.loadingEllipsis); } } else { this.chatMessageState.rawResponse += chunkData; - 
this.handleStreamResponse(this.chatMessageState.newResponseTextEl, this.chatMessageState.rawResponse, this.chatMessageState.loadingEllipsis);
+            this.handleStreamResponse(this.chatMessageState.newResponseTextEl, this.chatMessageState.rawResponse + this.chatMessageState.generatedAssets, this.chatMessageState.loadingEllipsis);
             }
         }
     }

From a539761c49b34d701f4ff44796e13b124a88c306 Mon Sep 17 00:00:00 2001
From: sabaimran
Date: Sat, 30 Nov 2024 12:35:13 -0800
Subject: [PATCH 11/25] Fix processing of excalidrawDiagram in JSON response chunking

---
 src/khoj/routers/helpers.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index 92b2b1aa..d54ae4f7 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -1807,7 +1807,7 @@ class MessageProcessor:
         self.raw_response = ""
         self.generated_images = []
         self.generated_files = []
-        self.generated_excalidraw_diagrams = []
+        self.generated_excalidraw_diagram = []

     def convert_message_chunk_to_json(self, raw_chunk: str) -> Dict[str, Any]:
         if raw_chunk.startswith("{") and raw_chunk.endswith("}"):
@@ -1854,8 +1854,8 @@ class MessageProcessor:
                     self.generated_images = chunk_data[key]
                 elif key == "files":
                     self.generated_files = chunk_data[key]
-                elif key == "excalidraw_diagrams":
-                    self.generated_excalidraw_diagrams = chunk_data[key]
+                elif key == "excalidrawDiagram":
+                    self.generated_excalidraw_diagram = chunk_data[key]

     def handle_json_response(self, json_data: Dict[str, str]) -> str | Dict[str, str]:
         if "image" in json_data or "details" in json_data:
@@ -1892,7 +1892,7 @@ async def read_chat_stream(response_iterator: AsyncGenerator[str, None]) -> Dict:
         "usage": processor.usage,
         "images": processor.generated_images,
         "files": processor.generated_files,
-        "excalidraw_diagrams": processor.generated_excalidraw_diagrams,
+        "excalidrawDiagram": processor.generated_excalidraw_diagram,
     }


From 991577aa17de8cb026e53fea76b7d81d066bd6e8 Mon Sep 17 00:00:00 2001
From: sabaimran
Date: Sat, 30 Nov 2024 14:39:08 -0800
Subject: [PATCH 12/25] Allow a None turnId to accommodate historic chat messages

---
 src/khoj/database/models/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py
index 88b15404..521eee9b 100644
--- a/src/khoj/database/models/__init__.py
+++ b/src/khoj/database/models/__init__.py
@@ -110,7 +110,7 @@ class ChatMessage(PydanticBaseModel):
     queryFiles: Optional[List[Dict]] = None
     excalidrawDiagram: Optional[List[Dict]] = None
     by: str
-    turnId: Optional[str]
+    turnId: Optional[str] = None
     intent: Optional[Intent] = None
     automationId: Optional[str] = None


From 224abd14e0e15fc3f41e73ee04d3ea1ebe269f0a Mon Sep 17 00:00:00 2001
From: sabaimran
Date: Sat, 30 Nov 2024 14:39:27 -0800
Subject: [PATCH 13/25] Only add the image_url to the constructed chat message if it is a URL

---
 src/khoj/processor/conversation/utils.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index 86c2d87e..a24ef899 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -344,7 +344,8 @@ def construct_structured_message(
             constructed_messages.append({"type": "text", "text": attached_file_context})
         if vision_enabled and images:
             for image in images:
-                constructed_messages.append({"type": "image_url", "image_url": {"url": image}})
+                if
image.startswith("https://"):
+                    constructed_messages.append({"type": "image_url", "image_url": {"url": image}})
         return constructed_messages

     if not is_none_or_empty(attached_file_context):

From 00f48dc1e8e659a2b4e84dba3fd76c340020d3ca Mon Sep 17 00:00:00 2001
From: sabaimran
Date: Sat, 30 Nov 2024 14:39:51 -0800
Subject: [PATCH 14/25] If the response is in the new images format, show the response text in Obsidian instead of the inferred query

---
 src/interface/obsidian/src/chat_view.ts | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/src/interface/obsidian/src/chat_view.ts b/src/interface/obsidian/src/chat_view.ts
index abb0de50..18001f47 100644
--- a/src/interface/obsidian/src/chat_view.ts
+++ b/src/interface/obsidian/src/chat_view.ts
@@ -537,12 +537,16 @@ export class KhojChatView extends KhojPaneView {
                     imageMarkdown += `![](data:image/png;base64,${image})\n\n`;
                 }
             }
+
+            imageMarkdown += `${message}`;
         }

-        if (inferredQueries) {
-            imageMarkdown += "\n\n**Inferred Query**:";
-            for (let inferredQuery of inferredQueries) {
-                imageMarkdown += `\n\n${inferredQuery}`;
+        if (!images || images.length === 0) {
+            if (inferredQueries) {
+                imageMarkdown += "\n\n**Inferred Query**:";
+                for (let inferredQuery of inferredQueries) {
+                    imageMarkdown += `\n\n${inferredQuery}`;
+                }
             }
         }
         return imageMarkdown;

From c87fce5930f931529bb67c0e045078dcc296adfe Mon Sep 17 00:00:00 2001
From: sabaimran
Date: Sun, 1 Dec 2024 18:35:31 -0800
Subject: [PATCH 15/25] Add a migration to use the new image storage format for past conversations

- Added it to the Django migrations so that it auto-triggers when someone updates their server and starts it up again for the first time. This will require that they update their clients as well in order to view/consume image content.
- Removed server-side references in the code that parse the text-to-image intent, as they will no longer be necessary once the chat logs are migrated

---
 .../components/chatMessage/chatMessage.tsx    | 21 -----
 src/khoj/database/admin.py                    | 13 +--
 ...5_migrate_generated_assets_and_validate.py | 85 +++++++++++++++++++
 .../conversation/anthropic/anthropic_chat.py  |  2 +-
 .../conversation/google/gemini_chat.py        |  2 +-
 .../conversation/offline/chat_model.py        |  2 +-
 src/khoj/processor/conversation/utils.py      |  5 +-
 src/khoj/processor/image/generate.py          |  3 -
 8 files changed, 92 insertions(+), 41 deletions(-)
 create mode 100644 src/khoj/database/migrations/0075_migrate_generated_assets_and_validate.py

diff --git a/src/interface/web/app/components/chatMessage/chatMessage.tsx b/src/interface/web/app/components/chatMessage/chatMessage.tsx
index e4e35952..49ce4c00 100644
--- a/src/interface/web/app/components/chatMessage/chatMessage.tsx
+++ b/src/interface/web/app/components/chatMessage/chatMessage.tsx
@@ -413,27 +413,6 @@ const ChatMessage = forwardRef((props, ref) =>
             .replace(/\\\[/g, "LEFTBRACKET")
             .replace(/\\\]/g, "RIGHTBRACKET");

-        const intentTypeHandlers = {
-            "text-to-image": (msg: string) => `![generated image](data:image/png;base64,${msg})`,
-            "text-to-image2": (msg: string) => `![generated image](${msg})`,
-            "text-to-image-v3": (msg: string) =>
-                `![generated image](data:image/webp;base64,${msg})`,
-            excalidraw: (msg: string) => msg,
-        };
-
-        // Handle intent-specific rendering
-        if (props.chatMessage.intent) {
-            const { type, "inferred-queries": inferredQueries } = props.chatMessage.intent;
-
-            if (type in intentTypeHandlers) {
-                message = intentTypeHandlers[type as keyof typeof intentTypeHandlers](message);
-            }
-
-            if (type.includes("text-to-image") && inferredQueries?.length > 0) {
-                message += `\n\n${inferredQueries[0]}`;
-            }
-        }
-
         // Replace file links with base64 data
         message = renderCodeGenImageInline(message, props.chatMessage.codeContext);

diff --git a/src/khoj/database/admin.py b/src/khoj/database/admin.py
index 906f2ffe..b71f1f81 100644
--- a/src/khoj/database/admin.py
+++ b/src/khoj/database/admin.py
@@ -286,17 +286,10 @@ class ConversationAdmin(admin.ModelAdmin):
             modified_log = conversation.conversation_log
             chat_log = modified_log.get("chat", [])
             for idx, log in enumerate(chat_log):
-                if (
-                    log["by"] == "khoj"
-                    and log["intent"]
-                    and log["intent"]["type"]
-                    and (
-                        log["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE.value
-                        or log["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE_V3.value
-                    )
-                ):
-                    log["message"] = "inline image redacted for space"
+                if log["by"] == "khoj" and log.get("images"):
+                    log["images"] = ["inline image redacted for space"]
                 chat_log[idx] = log
+
             modified_log["chat"] = chat_log

             writer.writerow(
diff --git a/src/khoj/database/migrations/0075_migrate_generated_assets_and_validate.py b/src/khoj/database/migrations/0075_migrate_generated_assets_and_validate.py
new file mode 100644
index 00000000..40c74ebf
--- /dev/null
+++ b/src/khoj/database/migrations/0075_migrate_generated_assets_and_validate.py
@@ -0,0 +1,85 @@
+# Made manually by sabaimran for use by Django 5.0.9 on 2024-12-01 16:59
+
+from django.db import migrations, models
+
+# This script was written alongside the addition of Pydantic validation to the Conversation conversation_log field.
+ + +def migrate_generated_assets(apps, schema_editor): + Conversation = apps.get_model("database", "Conversation") + + # Process conversations in chunks + for conversation in Conversation.objects.iterator(): + try: + meta_log = conversation.conversation_log + modified = False + + for chat in meta_log.get("chat", []): + intent_type = chat.get("intent", {}).get("type") + + if intent_type and chat["by"] == "khoj": + if intent_type and "text-to-image" in intent_type: + # Migrate the generated image to the new format + chat["images"] = [chat.get("message")] + chat["message"] = chat["intent"]["inferred-queries"][0] + modified = True + + if intent_type and "excalidraw" in intent_type: + # Migrate the generated excalidraw to the new format + chat["excalidrawDiagram"] = chat.get("message") + chat["message"] = chat["intent"]["inferred-queries"][0] + modified = True + + # Only save if changes were made + if modified: + conversation.conversation_log = meta_log + conversation.save() + + except Exception as e: + print(f"Error processing conversation {conversation.id}: {str(e)}") + continue + + +def reverse_migration(apps, schema_editor): + Conversation = apps.get_model("database", "Conversation") + + # Process conversations in chunks + for conversation in Conversation.objects.iterator(): + try: + meta_log = conversation.conversation_log + modified = False + + for chat in meta_log.get("chat", []): + intent_type = chat.get("intent", {}).get("type") + + if intent_type and chat["by"] == "khoj": + if intent_type and "text-to-image" in intent_type: + # Migrate the generated image back to the old format + chat["message"] = chat.get("images", [])[0] + chat.pop("images", None) + modified = True + + if intent_type and "excalidraw" in intent_type: + # Migrate the generated excalidraw back to the old format + chat["message"] = chat.get("excalidrawDiagram") + chat.pop("excalidrawDiagram", None) + modified = True + + # Only save if changes were made + if modified: + conversation.conversation_log = meta_log + conversation.save() + + except Exception as e: + print(f"Error processing conversation {conversation.id}: {str(e)}") + continue + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0074_alter_conversation_title"), + ] + + operations = [ + migrations.RunPython(migrate_generated_assets, reverse_migration), + ] diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py index 15f9fa17..e72146e5 100644 --- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py +++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py @@ -55,7 +55,7 @@ def extract_questions_anthropic( [ f'User: {chat["intent"]["query"]}\nAssistant: {{"queries": {chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}}}\nA: {chat["message"]}\n\n' for chat in conversation_log.get("chat", [])[-4:] - if chat["by"] == "khoj" and "text-to-image" not in chat["intent"].get("type") + if chat["by"] == "khoj" ] ) diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py index 14f28303..fc49e35f 100644 --- a/src/khoj/processor/conversation/google/gemini_chat.py +++ b/src/khoj/processor/conversation/google/gemini_chat.py @@ -56,7 +56,7 @@ def extract_questions_gemini( [ f'User: {chat["intent"]["query"]}\nAssistant: {{"queries": {chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}}}\nA: {chat["message"]}\n\n' for chat in 
conversation_log.get("chat", [])[-4:] - if chat["by"] == "khoj" and "text-to-image" not in chat["intent"].get("type") + if chat["by"] == "khoj" ] ) diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py index 7db70bc1..d493dd30 100644 --- a/src/khoj/processor/conversation/offline/chat_model.py +++ b/src/khoj/processor/conversation/offline/chat_model.py @@ -69,7 +69,7 @@ def extract_questions_offline( if use_history: for chat in conversation_log.get("chat", [])[-4:]: - if chat["by"] == "khoj" and "text-to-image" not in chat["intent"].get("type"): + if chat["by"] == "khoj": chat_history += f"Q: {chat['intent']['query']}\n" chat_history += f"Khoj: {chat['message']}\n\n" diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index a24ef899..64d42716 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -154,9 +154,6 @@ def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="A chat_history += f'{agent_name}: {{"queries": {chat["intent"].get("inferred-queries")}}}\n' chat_history += f"{agent_name}: {chat['message']}\n\n" - elif chat["by"] == "khoj" and ("text-to-image" in chat["intent"].get("type")): - chat_history += f"User: {chat['intent']['query']}\n" - chat_history += f"{agent_name}: [generated image redacted for space]\n" elif chat["by"] == "khoj" and chat.get("images"): chat_history += f"User: {chat['intent']['query']}\n" chat_history += f"{agent_name}: [generated image redacted for space]\n" @@ -320,7 +317,7 @@ def save_to_conversation_log( Saved Conversation Turn You ({user.username}): "{q}" -Khoj: "{inferred_queries if ("text-to-image" in intent_type) else chat_response}" +Khoj: "{chat_response}" """.strip() ) diff --git a/src/khoj/processor/image/generate.py b/src/khoj/processor/image/generate.py index 1bec7f41..e543ac7d 100644 --- a/src/khoj/processor/image/generate.py +++ b/src/khoj/processor/image/generate.py @@ -49,9 +49,6 @@ async def text_to_image( if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder"]: chat_history += f"Q: {chat['intent']['query']}\n" chat_history += f"A: {chat['message']}\n" - elif chat["by"] == "khoj" and "text-to-image" in chat["intent"].get("type"): - chat_history += f"Q: Prompt: {chat['intent']['query']}\n" - chat_history += f"A: Improved Prompt: {chat['intent']['inferred-queries'][0]}\n" elif chat["by"] == "khoj" and chat.get("images"): chat_history += f"Q: {chat['intent']['query']}\n" chat_history += f"A: Improved Prompt: {chat['intent']['inferred-queries'][0]}\n" From 355203282748cefe11304f8c0c3b54ad6cbb1a3d Mon Sep 17 00:00:00 2001 From: sabaimran Date: Tue, 3 Dec 2024 21:23:15 -0800 Subject: [PATCH 16/25] Rename additional context to additional_context_for_llm_response --- src/interface/obsidian/src/chat_view.ts | 10 ++++------ .../web/app/components/chatMessage/chatMessage.tsx | 5 ----- src/khoj/database/models/__init__.py | 10 +++++----- .../processor/conversation/anthropic/anthropic_chat.py | 4 ++-- src/khoj/processor/conversation/google/gemini_chat.py | 4 ++-- src/khoj/processor/conversation/offline/chat_model.py | 2 +- src/khoj/processor/conversation/openai/gpt.py | 4 ++-- src/khoj/processor/conversation/prompts.py | 4 +--- src/khoj/processor/conversation/utils.py | 8 +++++--- src/khoj/routers/helpers.py | 8 ++++---- 10 files changed, 26 insertions(+), 33 deletions(-) diff --git a/src/interface/obsidian/src/chat_view.ts 
b/src/interface/obsidian/src/chat_view.ts index 18001f47..1dcea196 100644 --- a/src/interface/obsidian/src/chat_view.ts +++ b/src/interface/obsidian/src/chat_view.ts @@ -541,12 +541,10 @@ export class KhojChatView extends KhojPaneView { imageMarkdown += `${message}`; } - if (!images || images.length === 0) { - if (inferredQueries) { - imageMarkdown += "\n\n**Inferred Query**:"; - for (let inferredQuery of inferredQueries) { - imageMarkdown += `\n\n${inferredQuery}`; - } + if (images?.length === 0 && inferredQueries) { + imageMarkdown += "\n\n**Inferred Query**:"; + for (let inferredQuery of inferredQueries) { + imageMarkdown += `\n\n${inferredQuery}`; } } return imageMarkdown; diff --git a/src/interface/web/app/components/chatMessage/chatMessage.tsx b/src/interface/web/app/components/chatMessage/chatMessage.tsx index 49ce4c00..89c3038a 100644 --- a/src/interface/web/app/components/chatMessage/chatMessage.tsx +++ b/src/interface/web/app/components/chatMessage/chatMessage.tsx @@ -397,11 +397,6 @@ const ChatMessage = forwardRef((props, ref) => // Prepare initial message for rendering let message = props.chatMessage.message; - if (props.chatMessage.intent && props.chatMessage.intent.type == "excalidraw") { - message = props.chatMessage.intent["inferred-queries"][0]; - setExcalidrawData(props.chatMessage.message); - } - if (props.chatMessage.excalidrawDiagram) { setExcalidrawData(props.chatMessage.excalidrawDiagram); } diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 521eee9b..3f81fead 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -65,12 +65,12 @@ class PeopleAlsoAsk(PydanticBaseModel): class KnowledgeGraph(PydanticBaseModel): attributes: Dict[str, str] - description: str - descriptionLink: str - descriptionSource: str - imageUrl: str + description: Optional[str] = None + descriptionLink: Optional[str] = None + descriptionSource: Optional[str] = None + imageUrl: Optional[str] = None title: str - type: str + type: Optional[str] = None class OrganicContext(PydanticBaseModel): diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py index e72146e5..0a242318 100644 --- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py +++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py @@ -160,7 +160,7 @@ def converse_anthropic( generated_images: Optional[list[str]] = None, generated_files: List[FileAttachment] = None, generated_excalidraw_diagram: Optional[str] = None, - additional_context: Optional[List[str]] = None, + additional_context_for_llm_response: Optional[List[str]] = None, tracer: dict = {}, ): """ @@ -224,7 +224,7 @@ def converse_anthropic( generated_excalidraw_diagram=generated_excalidraw_diagram, generated_files=generated_files, generated_images=generated_images, - additional_program_context=additional_context, + additional_context_for_llm_response=additional_context_for_llm_response, ) messages, system_prompt = format_messages_for_anthropic(messages, system_prompt) diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py index fc49e35f..304511ca 100644 --- a/src/khoj/processor/conversation/google/gemini_chat.py +++ b/src/khoj/processor/conversation/google/gemini_chat.py @@ -170,7 +170,7 @@ def converse_gemini( generated_images: Optional[list[str]] = None, generated_files: List[FileAttachment] = None, generated_excalidraw_diagram: 
Optional[str] = None,
-    additional_context: List[str] = None,
+    additional_context_for_llm_response: List[str] = None,
     tracer={},
 ):
     """
@@ -235,7 +235,7 @@ def converse_gemini(
         generated_excalidraw_diagram=generated_excalidraw_diagram,
         generated_files=generated_files,
         generated_images=generated_images,
-        additional_program_context=additional_context,
+        additional_context_for_llm_response=additional_context_for_llm_response,
     )

     messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py
index d493dd30..853f95ec 100644
--- a/src/khoj/processor/conversation/offline/chat_model.py
+++ b/src/khoj/processor/conversation/offline/chat_model.py
@@ -234,7 +234,7 @@ def converse_offline(
         model_type=ChatModelOptions.ModelType.OFFLINE,
         query_files=query_files,
         generated_files=generated_files,
-        additional_program_context=additional_context,
+        additional_context_for_llm_response=additional_context,
     )

     logger.debug(f"Conversation Context for {model}: {messages_to_print(messages)}")
diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py
index 9cfb9620..518e655d 100644
--- a/src/khoj/processor/conversation/openai/gpt.py
+++ b/src/khoj/processor/conversation/openai/gpt.py
@@ -160,7 +160,7 @@ def converse(
     generated_images: Optional[list[str]] = None,
     generated_files: List[FileAttachment] = None,
     generated_excalidraw_diagram: Optional[str] = None,
-    additional_context: List[str] = None,
+    additional_context_for_llm_response: List[str] = None,
     tracer: dict = {},
 ):
     """
@@ -226,7 +226,7 @@ def converse(
         generated_excalidraw_diagram=generated_excalidraw_diagram,
         generated_files=generated_files,
         generated_images=generated_images,
-        additional_program_context=additional_context,
+        additional_context_for_llm_response=additional_context_for_llm_response,
     )

     logger.debug(f"Conversation Context for GPT: {messages_to_print(messages)}")
diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py
index 5e28a912..46dac655 100644
--- a/src/khoj/processor/conversation/prompts.py
+++ b/src/khoj/processor/conversation/prompts.py
@@ -186,9 +186,7 @@ Here is the image you generated based on my query. You can follow-up with a gene
 generated_diagram_attachment = PromptTemplate.from_template(
     f"""
-The AI has successfully created a diagram based on the user's query and handled the request. Good job! This will be shared with the user.
-
-AI can follow-up with a general response or summary. Limit to 1-2 sentences.
+I've successfully created a diagram based on the user's query. The diagram will automatically be shared with the user. I can follow up with a general response or summary. Limit to 1-2 sentences.
""".strip() ) diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 64d42716..9a7cb24b 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -383,7 +383,7 @@ def generate_chatml_messages_with_context( generated_images: Optional[list[str]] = None, generated_files: List[FileAttachment] = None, generated_excalidraw_diagram: str = None, - additional_program_context: List[str] = [], + additional_context_for_llm_response: List[str] = [], ): """Generate chat messages with appropriate context from previous conversation to send to the chat model""" # Set max prompt size from user config or based on pre-configured for model and machine specs @@ -484,10 +484,12 @@ def generate_chatml_messages_with_context( if generated_excalidraw_diagram: messages.append(ChatMessage(content=prompts.generated_diagram_attachment.format(), role="assistant")) - if additional_program_context: + if additional_context_for_llm_response: messages.append( ChatMessage( - content=prompts.additional_program_context.format(context="\n".join(additional_program_context)), + content=prompts.additional_program_context.format( + context="\n".join(additional_context_for_llm_response) + ), role="assistant", ) ) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index d54ae4f7..7d61752c 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -1188,7 +1188,7 @@ def generate_chat_response( generated_images: List[str] = None, raw_generated_files: List[FileAttachment] = [], generated_excalidraw_diagram: str = None, - additional_context: List[str] = [], + additional_context_for_llm_response: List[str] = [], tracer: dict = {}, ) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]: # Initialize Variables @@ -1280,7 +1280,7 @@ def generate_chat_response( generated_files=raw_generated_files, generated_images=generated_images, generated_excalidraw_diagram=generated_excalidraw_diagram, - additional_context=additional_context, + additional_context_for_llm_response=additional_context_for_llm_response, tracer=tracer, ) @@ -1307,7 +1307,7 @@ def generate_chat_response( generated_files=raw_generated_files, generated_images=generated_images, generated_excalidraw_diagram=generated_excalidraw_diagram, - additional_context=additional_context, + additional_context_for_llm_response=additional_context_for_llm_response, tracer=tracer, ) elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE: @@ -1333,7 +1333,7 @@ def generate_chat_response( generated_files=raw_generated_files, generated_images=generated_images, generated_excalidraw_diagram=generated_excalidraw_diagram, - additional_context=additional_context, + additional_context_for_llm_response=additional_context_for_llm_response, tracer=tracer, ) From df5e34615af350fd723ef6efde76df9670b9f25b Mon Sep 17 00:00:00 2001 From: sabaimran Date: Tue, 3 Dec 2024 21:26:55 -0800 Subject: [PATCH 17/25] Fix processing of images field when construct chat messages --- src/khoj/processor/conversation/utils.py | 27 ++++++++++++------------ 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 9a7cb24b..b28abcae 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -433,19 +433,20 @@ def generate_chatml_messages_with_context( reconstructed_context_message = ChatMessage(content=message_context, role="user") 
chatml_messages.insert(0, reconstructed_context_message) - if chat.get("images") and role == "assistant": - # Issue: the assistant role cannot accept an image as a message content, so send it in a separate user message. - file_attachment_message = construct_structured_message( - message=prompts.generated_image_attachment.format(), - images=chat.get("images"), - model_type=model_type, - vision_enabled=vision_enabled, - ) - chatml_messages.append(ChatMessage(content=file_attachment_message, role="user")) - - message_content = construct_structured_message( - chat_message, chat.get("images") if role == "user" else [], model_type, vision_enabled - ) + if chat.get("images"): + if role == "assistant": + # Issue: the assistant role cannot accept an image as a message content, so send it in a separate user message. + file_attachment_message = construct_structured_message( + message=prompts.generated_image_attachment.format(), + images=chat.get("images"), + model_type=model_type, + vision_enabled=vision_enabled, + ) + chatml_messages.append(ChatMessage(content=file_attachment_message, role="user")) + else: + message_content = construct_structured_message( + chat_message, chat.get("images"), model_type, vision_enabled + ) reconstructed_message = ChatMessage(content=message_content, role=role) chatml_messages.insert(0, reconstructed_message) From 8953ac03ec54d3e5314cc47cffc9cedb5a1ee620 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Wed, 4 Dec 2024 18:43:41 -0800 Subject: [PATCH 18/25] Rename additional context for llm response to program execution context --- .../processor/conversation/anthropic/anthropic_chat.py | 4 ++-- src/khoj/processor/conversation/google/gemini_chat.py | 4 ++-- src/khoj/processor/conversation/offline/chat_model.py | 2 +- src/khoj/processor/conversation/openai/gpt.py | 4 ++-- src/khoj/processor/conversation/prompts.py | 5 ++--- src/khoj/processor/conversation/utils.py | 8 +++----- src/khoj/routers/api_chat.py | 10 +++++----- src/khoj/routers/helpers.py | 8 ++++---- 8 files changed, 21 insertions(+), 24 deletions(-) diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py index 0a242318..688514ca 100644 --- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py +++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py @@ -160,7 +160,7 @@ def converse_anthropic( generated_images: Optional[list[str]] = None, generated_files: List[FileAttachment] = None, generated_excalidraw_diagram: Optional[str] = None, - additional_context_for_llm_response: Optional[List[str]] = None, + program_execution_context: Optional[List[str]] = None, tracer: dict = {}, ): """ @@ -224,7 +224,7 @@ def converse_anthropic( generated_excalidraw_diagram=generated_excalidraw_diagram, generated_files=generated_files, generated_images=generated_images, - additional_context_for_llm_response=additional_context_for_llm_response, + program_execution_context=program_execution_context, ) messages, system_prompt = format_messages_for_anthropic(messages, system_prompt) diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py index 304511ca..ad10acda 100644 --- a/src/khoj/processor/conversation/google/gemini_chat.py +++ b/src/khoj/processor/conversation/google/gemini_chat.py @@ -170,7 +170,7 @@ def converse_gemini( generated_images: Optional[list[str]] = None, generated_files: List[FileAttachment] = None, generated_excalidraw_diagram: Optional[str] = None, - 
additional_context_for_llm_response: List[str] = None, + program_execution_context: List[str] = None, tracer={}, ): """ @@ -235,7 +235,7 @@ def converse_gemini( generated_excalidraw_diagram=generated_excalidraw_diagram, generated_files=generated_files, generated_images=generated_images, - additional_context_for_llm_response=additional_context_for_llm_response, + program_execution_context=program_execution_context, ) messages, system_prompt = format_messages_for_gemini(messages, system_prompt) diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py index 853f95ec..d81c194c 100644 --- a/src/khoj/processor/conversation/offline/chat_model.py +++ b/src/khoj/processor/conversation/offline/chat_model.py @@ -234,7 +234,7 @@ def converse_offline( model_type=ChatModelOptions.ModelType.OFFLINE, query_files=query_files, generated_files=generated_files, - additional_context_for_llm_response=additional_context, + program_execution_context=additional_context, ) logger.debug(f"Conversation Context for {model}: {messages_to_print(messages)}") diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index 518e655d..c8faf25e 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -160,7 +160,7 @@ def converse( generated_images: Optional[list[str]] = None, generated_files: List[FileAttachment] = None, generated_excalidraw_diagram: Optional[str] = None, - additional_context_for_llm_response: List[str] = None, + program_execution_context: List[str] = None, tracer: dict = {}, ): """ @@ -226,7 +226,7 @@ def converse( generated_excalidraw_diagram=generated_excalidraw_diagram, generated_files=generated_files, generated_images=generated_images, - additional_context_for_llm_response=additional_context_for_llm_response, + program_execution_context=program_execution_context, ) logger.debug(f"Conversation Context for GPT: {messages_to_print(messages)}") diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index 46dac655..04ee8b13 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -1043,12 +1043,11 @@ A: additional_program_context = PromptTemplate.from_template( """ -Here's some additional context about what happened while I was executing this query: +Here are some additional results from the query execution: {context} - """.strip() +""".strip() ) - personality_prompt_safety_expert_lax = PromptTemplate.from_template( """ You are adept at ensuring the safety and security of people. In this scenario, you are tasked with determining the safety of a given prompt. 
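For reviewers, a minimal sketch (not part of the patch) of how the renamed program_execution_context list is expected to flow into the additional_program_context template above; the PromptTemplate import mirrors prompts.py and the sample failure string is hypothetical:

# Sketch: how program execution context reaches the chat model (assumed wiring).
from langchain.prompts import PromptTemplate

additional_program_context = PromptTemplate.from_template(
    """
Here are some additional results from the query execution:
{context}
""".strip()
)

# Collected while handling a request, e.g. after a failed tool run (hypothetical value):
program_execution_context = ["Failed to generate image with the improved prompt"]

# generate_chatml_messages_with_context joins the list, fills the template,
# and appends the result as an assistant message:
content = additional_program_context.format(context="\n".join(program_execution_context))
print(content)
# Here are some additional results from the query execution:
# Failed to generate image with the improved prompt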
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index b28abcae..d68e592e 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -383,7 +383,7 @@ def generate_chatml_messages_with_context( generated_images: Optional[list[str]] = None, generated_files: List[FileAttachment] = None, generated_excalidraw_diagram: str = None, - additional_context_for_llm_response: List[str] = [], + program_execution_context: List[str] = [], ): """Generate chat messages with appropriate context from previous conversation to send to the chat model""" # Set max prompt size from user config or based on pre-configured for model and machine specs @@ -485,12 +485,10 @@ def generate_chatml_messages_with_context( if generated_excalidraw_diagram: messages.append(ChatMessage(content=prompts.generated_diagram_attachment.format(), role="assistant")) - if additional_context_for_llm_response: + if program_execution_context: messages.append( ChatMessage( - content=prompts.additional_program_context.format( - context="\n".join(additional_context_for_llm_response) - ), + content=prompts.additional_program_context.format(context="\n".join(program_execution_context)), role="assistant", ) ) diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index b54b163d..68d77475 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -774,7 +774,7 @@ async def chat( generated_images: List[str] = [] generated_files: List[FileAttachment] = [] generated_excalidraw_diagram: str = None - additional_context_for_llm_response: List[str] = [] + program_execution_context: List[str] = [] if conversation_commands == [ConversationCommand.Default] or is_automated_task: chosen_io = await aget_data_sources_and_output_format( @@ -1080,7 +1080,7 @@ async def chat( async for result in send_event(ChatEvent.STATUS, f"**Ran code snippets**: {len(code_results)}"): yield result except ValueError as e: - additional_context_for_llm_response.append(f"Failed to run code") + program_execution_context.append(f"Failed to run code") logger.warning( f"Failed to use code tool: {e}. Attempting to respond without code results", exc_info=True, @@ -1122,7 +1122,7 @@ async def chat( inferred_queries.append(improved_image_prompt) if generated_image is None or status_code != 200: - additional_context_for_llm_response.append(f"Failed to generate image with {improved_image_prompt}") + program_execution_context.append(f"Failed to generate image with {improved_image_prompt}") async for result in send_event(ChatEvent.STATUS, f"Failed to generate image"): yield result else: @@ -1175,7 +1175,7 @@ async def chat( yield result else: error_message = "Failed to generate diagram. Please try again later." - additional_context_for_llm_response.append( + program_execution_context.append( f"AI attempted to programmatically generate a diagram but failed due to a program issue. Generally, it is able to do so, but encountered a system issue this time. AI can suggest text description or rendering of the diagram or user can try again with a simpler prompt." 
)

@@ -1208,7 +1208,7 @@ async def chat(
             generated_images,
             generated_files,
             generated_excalidraw_diagram,
-            additional_context_for_llm_response,
+            program_execution_context,
             tracer,
         )

diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index 7d61752c..29c44d94 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -1188,7 +1188,7 @@ def generate_chat_response(
     generated_images: List[str] = None,
     raw_generated_files: List[FileAttachment] = [],
     generated_excalidraw_diagram: str = None,
-    additional_context_for_llm_response: List[str] = [],
+    program_execution_context: List[str] = [],
     tracer: dict = {},
 ) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
     # Initialize Variables
@@ -1280,7 +1280,7 @@ def generate_chat_response(
                 generated_files=raw_generated_files,
                 generated_images=generated_images,
                 generated_excalidraw_diagram=generated_excalidraw_diagram,
-                additional_context_for_llm_response=additional_context_for_llm_response,
+                program_execution_context=program_execution_context,
                 tracer=tracer,
             )

@@ -1307,7 +1307,7 @@ def generate_chat_response(
                 generated_files=raw_generated_files,
                 generated_images=generated_images,
                 generated_excalidraw_diagram=generated_excalidraw_diagram,
-                additional_context_for_llm_response=additional_context_for_llm_response,
+                program_execution_context=program_execution_context,
                 tracer=tracer,
             )
         elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
@@ -1333,7 +1333,7 @@ def generate_chat_response(
                 generated_files=raw_generated_files,
                 generated_images=generated_images,
                 generated_excalidraw_diagram=generated_excalidraw_diagram,
-                additional_context_for_llm_response=additional_context_for_llm_response,
+                program_execution_context=program_execution_context,
                 tracer=tracer,
             )

From 4c4b7120c6d9fc177074e20a8a02c0ed1f560f05 Mon Sep 17 00:00:00 2001
From: sabaimran
Date: Sun, 8 Dec 2024 11:06:33 -0800
Subject: [PATCH 19/25] Use Khoj terrarium fork instead of building from the official Cohere repo

---
 .github/workflows/run_evals.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/run_evals.yml b/.github/workflows/run_evals.yml
index 15800691..a8f4a17e 100644
--- a/.github/workflows/run_evals.yml
+++ b/.github/workflows/run_evals.yml
@@ -81,7 +81,7 @@ jobs:
         # upgrade pip
         python -m ensurepip --upgrade && python -m pip install --upgrade pip
         # install terrarium for code sandbox
-        git clone https://github.com/cohere-ai/cohere-terrarium.git && cd cohere-terrarium && npm install && mkdir pyodide_cache
+        git clone https://github.com/khoj-ai/terrarium.git && cd terrarium && npm install && mkdir pyodide_cache

       - name: ⬇️ Install Application
         run: |

From 6940c6379bffd3967636c6345e856a661c2e7879 Mon Sep 17 00:00:00 2001
From: sabaimran
Date: Sun, 8 Dec 2024 11:11:13 -0800
Subject: [PATCH 20/25] Add sudo when running installations in order to install relevant packages. Add --legacy-peer-deps temporarily to see if it helps mitigate the issue.

---
 .github/workflows/run_evals.yml | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/.github/workflows/run_evals.yml b/.github/workflows/run_evals.yml
index a8f4a17e..91f0a721 100644
--- a/.github/workflows/run_evals.yml
+++ b/.github/workflows/run_evals.yml
@@ -76,12 +76,12 @@ jobs:
         DEBIAN_FRONTEND: noninteractive
       run: |
         # install postgres and other dependencies
-        apt update && apt install -y git python3-pip libegl1 sqlite3 libsqlite3-dev libsqlite3-0 ffmpeg libsm6 libxext6
-        apt install -y postgresql postgresql-client && apt
install -y postgresql-server-dev-14 + sudo apt update && sudo apt install -y git python3-pip libegl1 sqlite3 libsqlite3-dev libsqlite3-0 ffmpeg libsm6 libxext6 + sudo apt install -y postgresql postgresql-client && sudo apt install -y postgresql-server-dev-14 # upgrade pip python -m ensurepip --upgrade && python -m pip install --upgrade pip # install terrarium for code sandbox - git clone https://github.com/khoj-ai/terrarium.git && cd terrarium && npm install && mkdir pyodide_cache + git clone https://github.com/khoj-ai/terrarium.git && cd terrarium && npm install --legacy-peer-deps && mkdir pyodide_cache - name: ⬇️ Install Application run: | From efa23a8ad858dfe23b8422c5eb6bdddf450ec5e5 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Sun, 8 Dec 2024 11:30:17 -0800 Subject: [PATCH 21/25] Update validation requirements for online searches --- src/khoj/database/models/__init__.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 3f81fead..f6fb148e 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -50,16 +50,16 @@ class WebPage(PydanticBaseModel): class AnswerBox(PydanticBaseModel): - link: str - snippet: str + link: Optional[str] = None + snippet: Optional[str] = None title: str snippetHighlighted: List[str] class PeopleAlsoAsk(PydanticBaseModel): - link: str - question: str - snippet: str + link: Optional[str] = None + question: Optional[str] = None + snippet: Optional[str] = None title: str From 2af687d1c536a892201af8cc8f14c61736c7c7f7 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Sun, 8 Dec 2024 11:51:24 -0800 Subject: [PATCH 22/25] Allow snippetHighlighted to also be nullable --- src/khoj/database/models/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index f6fb148e..e98fb886 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -53,7 +53,7 @@ class AnswerBox(PydanticBaseModel): link: Optional[str] = None snippet: Optional[str] = None title: str - snippetHighlighted: List[str] + snippetHighlighted: Optional[List[str]] = None class PeopleAlsoAsk(PydanticBaseModel): From 7cd2855146f44774e74c8df031569037380c0fad Mon Sep 17 00:00:00 2001 From: sabaimran Date: Sun, 8 Dec 2024 12:23:17 -0800 Subject: [PATCH 23/25] Make attributes optional in the knowledge graph model --- src/khoj/database/models/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index e98fb886..82583bc0 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -64,7 +64,7 @@ class PeopleAlsoAsk(PydanticBaseModel): class KnowledgeGraph(PydanticBaseModel): - attributes: Dict[str, str] + attributes: Optional[Dict[str, str]] = None description: Optional[str] = None descriptionLink: Optional[str] = None descriptionSource: Optional[str] = None From 9c403d24e1ba8f506e8dd2aa83b111b12d0c1641 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Sun, 8 Dec 2024 13:03:05 -0800 Subject: [PATCH 24/25] Fix reference to directory in the eval workflow for starting terrarium --- .github/workflows/run_evals.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/run_evals.yml b/.github/workflows/run_evals.yml index 91f0a721..dc8c89b7 100644 --- a/.github/workflows/run_evals.yml +++ 
b/.github/workflows/run_evals.yml @@ -113,7 +113,7 @@ jobs: khoj --anonymous-mode --non-interactive & # Start code sandbox - npm run dev --prefix cohere-terrarium & + npm run dev --prefix terrarium & # Wait for server to be ready timeout=120 From a2251f01eb7a6d5f3bd2473f77770d478c6b1c55 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Sun, 8 Dec 2024 13:27:33 -0800 Subject: [PATCH 25/25] Make result optional for code context, relevant when code execution was unsuccessful --- src/khoj/database/models/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 82583bc0..dea678d8 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -40,7 +40,7 @@ class CodeContextResult(PydanticBaseModel): class CodeContextData(PydanticBaseModel): code: str - result: CodeContextResult + result: Optional[CodeContextResult] = None class WebPage(PydanticBaseModel):
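To close out the series, a minimal sketch (not part of any patch) of how a migrated chat turn validates against the Pydantic ChatMessage model that these last commits relax; the import path follows the models file above and the sample values are hypothetical:

# Sketch: validating a khoj turn stored in the new generated-assets format.
from khoj.database.models import ChatMessage

raw_turn = {
    "by": "khoj",
    "message": "Here is the image you asked for.",  # response text, post-migration
    "created": "2024-12-01 16:59:00",
    "images": ["https://example.com/generated.png"],  # assets now live in their own field
    # "turnId" omitted on purpose: patch 12 defaults it to None for historic messages
}

validated = ChatMessage.model_validate(raw_turn)  # raises pydantic.ValidationError on schema drift
assert validated.turnId is None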