Merge pull request #992 from khoj-ai/features/allow-multi-outputs-in-chat

Currently, Khoj treats its outputs as terminal states: a response ends with either text, an image, or an Excalidraw diagram. This limitation is unnecessary and constrains us from building more dynamic, diverse experiences. For instance, we may want the chat view to morph for document editing or generation, in which case having limited output modes would be a detriment.

This change lets us emit generated assets and then continue with text generation for the final response, so every message now ends with a text response. It adds a new stream event, GENERATED_ASSETS, which carries the content the AI has generated in response to the query.
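For context, here is a rough sketch of how a streaming client can consume the new event. This is illustrative only: the `GeneratedAssets` and `StreamChunk` shapes mirror the fields this PR emits (images, excalidrawDiagram, files), but the helper itself is not part of the change.

```typescript
// Illustrative sketch: handle a generated_assets chunk, then keep streaming text.
// Assumes the server streams JSON chunks shaped like { type, data }.
interface GeneratedAssets {
    images?: string[];
    excalidrawDiagram?: string;
    files?: { name: string; content: string }[];
}

interface StreamChunk {
    type: string;
    data: any;
}

function handleChunk(chunk: StreamChunk, state: { rawResponse: string; assetsMarkdown: string }) {
    if (chunk.type === "generated_assets") {
        // Render generated assets as soon as they arrive...
        const assets = chunk.data as GeneratedAssets;
        if (assets.images) {
            state.assetsMarkdown += assets.images
                .map((img) =>
                    img.startsWith("http")
                        ? `![](${img})\n\n`
                        : `![](data:image/png;base64,${img})\n\n`
                )
                .join("");
        }
    } else if (chunk.type === "message") {
        // ...then keep appending the text response that always follows.
        state.rawResponse += chunk.data;
    }
}
```

The key point is that asset chunks no longer terminate the stream; `message` chunks keep arriving afterwards and are appended to the same response.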
sabaimran 2024-12-08 14:19:05 -08:00 committed by GitHub
commit e3789aef49
21 changed files with 665 additions and 285 deletions

View file

@ -76,12 +76,12 @@ jobs:
DEBIAN_FRONTEND: noninteractive
run: |
# install postgres and other dependencies
apt update && apt install -y git python3-pip libegl1 sqlite3 libsqlite3-dev libsqlite3-0 ffmpeg libsm6 libxext6
apt install -y postgresql postgresql-client && apt install -y postgresql-server-dev-14
sudo apt update && sudo apt install -y git python3-pip libegl1 sqlite3 libsqlite3-dev libsqlite3-0 ffmpeg libsm6 libxext6
sudo apt install -y postgresql postgresql-client && sudo apt install -y postgresql-server-dev-14
# upgrade pip
python -m ensurepip --upgrade && python -m pip install --upgrade pip
# install terrarium for code sandbox
git clone https://github.com/cohere-ai/cohere-terrarium.git && cd cohere-terrarium && npm install && mkdir pyodide_cache
git clone https://github.com/khoj-ai/terrarium.git && cd terrarium && npm install --legacy-peer-deps && mkdir pyodide_cache
- name: ⬇️ Install Application
run: |
@ -113,7 +113,7 @@ jobs:
khoj --anonymous-mode --non-interactive &
# Start code sandbox
npm run dev --prefix cohere-terrarium &
npm run dev --prefix terrarium &
# Wait for server to be ready
timeout=120

View file

@ -1,4 +1,4 @@
import {ItemView, MarkdownRenderer, Scope, WorkspaceLeaf, request, requestUrl, setIcon, Platform} from 'obsidian';
import { ItemView, MarkdownRenderer, Scope, WorkspaceLeaf, request, requestUrl, setIcon, Platform } from 'obsidian';
import * as DOMPurify from 'dompurify';
import { KhojSetting } from 'src/settings';
import { KhojPaneView } from 'src/pane_view';
@ -27,6 +27,7 @@ interface ChatMessageState {
newResponseEl: HTMLElement | null;
loadingEllipsis: HTMLElement | null;
references: any;
generatedAssets: string;
rawResponse: string;
rawQuery: string;
isVoice: boolean;
@ -199,7 +200,7 @@ export class KhojChatView extends KhojPaneView {
// Get chat history from Khoj backend and set chat input state
let getChatHistorySucessfully = await this.getChatHistory(chatBodyEl);
let placeholderText : string = getChatHistorySucessfully ? this.startingMessage : "Configure Khoj to enable chat";
let placeholderText: string = getChatHistorySucessfully ? this.startingMessage : "Configure Khoj to enable chat";
chatInput.placeholder = placeholderText;
chatInput.disabled = !getChatHistorySucessfully;
@ -214,7 +215,7 @@ export class KhojChatView extends KhojPaneView {
});
}
startSpeechToText(event: KeyboardEvent | MouseEvent | TouchEvent, timeout=200) {
startSpeechToText(event: KeyboardEvent | MouseEvent | TouchEvent, timeout = 200) {
if (!this.keyPressTimeout) {
this.keyPressTimeout = setTimeout(async () => {
// Reset auto send voice message timer, UI if running
@ -320,7 +321,7 @@ export class KhojChatView extends KhojPaneView {
referenceButton.tabIndex = 0;
// Add event listener to toggle full reference on click
referenceButton.addEventListener('click', function() {
referenceButton.addEventListener('click', function () {
if (this.classList.contains("collapsed")) {
this.classList.remove("collapsed");
this.classList.add("expanded");
@ -375,7 +376,7 @@ export class KhojChatView extends KhojPaneView {
referenceButton.tabIndex = 0;
// Add event listener to toggle full reference on click
referenceButton.addEventListener('click', function() {
referenceButton.addEventListener('click', function () {
if (this.classList.contains("collapsed")) {
this.classList.remove("collapsed");
this.classList.add("expanded");
@ -427,7 +428,7 @@ export class KhojChatView extends KhojPaneView {
source.buffer = audioBuffer;
source.connect(context.destination);
source.start(0);
source.onended = function() {
source.onended = function () {
speechButton.removeChild(loader);
speechButton.disabled = false;
};
@ -485,12 +486,18 @@ export class KhojChatView extends KhojPaneView {
intentType?: string,
inferredQueries?: string[],
conversationId?: string,
images?: string[],
excalidrawDiagram?: string
) {
if (!message) return;
let chatMessageEl;
if (intentType?.includes("text-to-image") || intentType === "excalidraw") {
let imageMarkdown = this.generateImageMarkdown(message, intentType, inferredQueries, conversationId);
if (
intentType?.includes("text-to-image") ||
intentType === "excalidraw" ||
(images && images.length > 0) ||
excalidrawDiagram) {
let imageMarkdown = this.generateImageMarkdown(message, intentType ?? "", inferredQueries, conversationId, images, excalidrawDiagram);
chatMessageEl = this.renderMessage(chatEl, imageMarkdown, sender, dt);
} else {
chatMessageEl = this.renderMessage(chatEl, message, sender, dt);
@ -510,7 +517,7 @@ export class KhojChatView extends KhojPaneView {
chatMessageBodyEl.appendChild(this.createReferenceSection(references));
}
generateImageMarkdown(message: string, intentType: string, inferredQueries?: string[], conversationId?: string): string {
generateImageMarkdown(message: string, intentType: string, inferredQueries?: string[], conversationId?: string, images?: string[], excalidrawDiagram?: string): string {
let imageMarkdown = "";
if (intentType === "text-to-image") {
imageMarkdown = `![](data:image/png;base64,${message})`;
@ -518,12 +525,23 @@ export class KhojChatView extends KhojPaneView {
imageMarkdown = `![](${message})`;
} else if (intentType === "text-to-image-v3") {
imageMarkdown = `![](data:image/webp;base64,${message})`;
} else if (intentType === "excalidraw") {
} else if (intentType === "excalidraw" || excalidrawDiagram) {
const domain = this.setting.khojUrl.endsWith("/") ? this.setting.khojUrl : `${this.setting.khojUrl}/`;
const redirectMessage = `Hey, I'm not ready to show you diagrams yet here. But you can view it in ${domain}chat?conversationId=${conversationId}`;
imageMarkdown = redirectMessage;
} else if (images && images.length > 0) {
for (let image of images) {
if (image.startsWith("https://")) {
imageMarkdown += `![](${image})\n\n`;
} else {
imageMarkdown += `![](data:image/png;base64,${image})\n\n`;
}
if (inferredQueries) {
}
imageMarkdown += `${message}`;
}
if (images?.length === 0 && inferredQueries) {
imageMarkdown += "\n\n**Inferred Query**:";
for (let inferredQuery of inferredQueries) {
imageMarkdown += `\n\n${inferredQuery}`;
@ -768,10 +786,10 @@ export class KhojChatView extends KhojPaneView {
let editConversationTitleInputEl = this.contentEl.createEl('input');
editConversationTitleInputEl.classList.add("conversation-title-input");
editConversationTitleInputEl.value = conversationTitle;
editConversationTitleInputEl.addEventListener('click', function(event) {
editConversationTitleInputEl.addEventListener('click', function (event) {
event.stopPropagation();
});
editConversationTitleInputEl.addEventListener('keydown', function(event) {
editConversationTitleInputEl.addEventListener('keydown', function (event) {
if (event.key === "Enter") {
event.preventDefault();
editConversationTitleSaveButtonEl.click();
@ -890,9 +908,11 @@ export class KhojChatView extends KhojPaneView {
chatLog.intent?.type,
chatLog.intent?.["inferred-queries"],
chatBodyEl.dataset.conversationId ?? "",
chatLog.images,
chatLog.excalidrawDiagram,
);
// push the user messages to the chat history
if(chatLog.by === "you"){
if (chatLog.by === "you") {
this.userMessages.push(chatLog.message);
}
});
@ -922,15 +942,15 @@ export class KhojChatView extends KhojPaneView {
try {
let jsonChunk = JSON.parse(rawChunk);
if (!jsonChunk.type)
jsonChunk = {type: 'message', data: jsonChunk};
jsonChunk = { type: 'message', data: jsonChunk };
return jsonChunk;
} catch (e) {
return {type: 'message', data: rawChunk};
return { type: 'message', data: rawChunk };
}
} else if (rawChunk.length > 0) {
return {type: 'message', data: rawChunk};
return { type: 'message', data: rawChunk };
}
return {type: '', data: ''};
return { type: '', data: '' };
}
processMessageChunk(rawChunk: string): void {
@ -941,6 +961,11 @@ export class KhojChatView extends KhojPaneView {
console.log(`status: ${chunk.data}`);
const statusMessage = chunk.data;
this.handleStreamResponse(this.chatMessageState.newResponseTextEl, statusMessage, this.chatMessageState.loadingEllipsis, false);
} else if (chunk.type === 'generated_assets') {
const generatedAssets = chunk.data;
const imageData = this.handleImageResponse(generatedAssets, this.chatMessageState.rawResponse);
this.chatMessageState.generatedAssets = imageData;
this.handleStreamResponse(this.chatMessageState.newResponseTextEl, imageData, this.chatMessageState.loadingEllipsis, false);
} else if (chunk.type === 'start_llm_response') {
console.log("Started streaming", new Date());
} else if (chunk.type === 'end_llm_response') {
@ -963,9 +988,10 @@ export class KhojChatView extends KhojPaneView {
rawResponse: "",
rawQuery: liveQuery,
isVoice: false,
generatedAssets: "",
};
} else if (chunk.type === "references") {
this.chatMessageState.references = {"notes": chunk.data.context, "online": chunk.data.onlineContext};
this.chatMessageState.references = { "notes": chunk.data.context, "online": chunk.data.onlineContext };
} else if (chunk.type === 'message') {
const chunkData = chunk.data;
if (typeof chunkData === 'object' && chunkData !== null) {
@ -978,17 +1004,17 @@ export class KhojChatView extends KhojPaneView {
this.handleJsonResponse(jsonData);
} catch (e) {
this.chatMessageState.rawResponse += chunkData;
this.handleStreamResponse(this.chatMessageState.newResponseTextEl, this.chatMessageState.rawResponse, this.chatMessageState.loadingEllipsis);
this.handleStreamResponse(this.chatMessageState.newResponseTextEl, this.chatMessageState.rawResponse + this.chatMessageState.generatedAssets, this.chatMessageState.loadingEllipsis);
}
} else {
this.chatMessageState.rawResponse += chunkData;
this.handleStreamResponse(this.chatMessageState.newResponseTextEl, this.chatMessageState.rawResponse, this.chatMessageState.loadingEllipsis);
this.handleStreamResponse(this.chatMessageState.newResponseTextEl, this.chatMessageState.rawResponse + this.chatMessageState.generatedAssets, this.chatMessageState.loadingEllipsis);
}
}
}
handleJsonResponse(jsonData: any): void {
if (jsonData.image || jsonData.detail) {
if (jsonData.image || jsonData.detail || jsonData.images || jsonData.excalidrawDiagram) {
this.chatMessageState.rawResponse = this.handleImageResponse(jsonData, this.chatMessageState.rawResponse);
} else if (jsonData.response) {
this.chatMessageState.rawResponse = jsonData.response;
@ -1234,11 +1260,11 @@ export class KhojChatView extends KhojPaneView {
const recordingConfig = { mimeType: 'audio/webm' };
this.mediaRecorder = new MediaRecorder(stream, recordingConfig);
this.mediaRecorder.addEventListener("dataavailable", function(event) {
this.mediaRecorder.addEventListener("dataavailable", function (event) {
if (event.data.size > 0) audioChunks.push(event.data);
});
this.mediaRecorder.addEventListener("stop", async function() {
this.mediaRecorder.addEventListener("stop", async function () {
const audioBlob = new Blob(audioChunks, { type: 'audio/webm' });
await sendToServer(audioBlob);
});
@ -1368,7 +1394,22 @@ export class KhojChatView extends KhojPaneView {
if (inferredQuery) {
rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
}
} else if (imageJson.images) {
// If response has images field, response is a list of generated images.
imageJson.images.forEach((image: any) => {
if (image.startsWith("http")) {
rawResponse += `![generated_image](${image})\n\n`;
} else {
rawResponse += `![generated_image](data:image/png;base64,${image})\n\n`;
}
});
} else if (imageJson.excalidrawDiagram) {
const domain = this.setting.khojUrl.endsWith("/") ? this.setting.khojUrl : `${this.setting.khojUrl}/`;
const redirectMessage = `Hey, I'm not ready to show you diagrams yet here. But you can view it in ${domain}`;
rawResponse += redirectMessage;
}
// If response has detail field, response is an error message.
if (imageJson.detail) rawResponse += imageJson.detail;
@ -1407,7 +1448,7 @@ export class KhojChatView extends KhojPaneView {
referenceExpandButton.classList.add("reference-expand-button");
referenceExpandButton.innerHTML = numReferences == 1 ? "1 reference" : `${numReferences} references`;
referenceExpandButton.addEventListener('click', function() {
referenceExpandButton.addEventListener('click', function () {
if (referenceSection.classList.contains("collapsed")) {
referenceSection.classList.remove("collapsed");
referenceSection.classList.add("expanded");

View file

@ -82,7 +82,8 @@ If your plugin does not need CSS, delete this file.
}
/* color chat bubble by khoj blue */
.khoj-chat-message-text.khoj {
border: 1px solid var(--khoj-sun);
border-top: 1px solid var(--khoj-sun);
border-radius: 0px;
margin-left: auto;
white-space: pre-line;
}
@ -104,8 +105,9 @@ If your plugin does not need CSS, delete this file.
}
/* color chat bubble by you dark grey */
.khoj-chat-message-text.you {
border: 1px solid var(--color-accent);
color: var(--text-normal);
margin-right: auto;
background-color: var(--background-modifier-cover);
}
/* add right protrusion to you chat bubble */
.khoj-chat-message-text.you:after {

View file

@ -1,3 +1,4 @@
import { AttachedFileText } from "../components/chatInputArea/chatInputArea";
import {
CodeContext,
Context,
@ -16,6 +17,12 @@ export interface MessageMetadata {
turnId: string;
}
export interface GeneratedAssetsData {
images: string[];
excalidrawDiagram: string;
files: AttachedFileText[];
}
export interface ResponseWithIntent {
intentType: string;
response: string;
@ -84,6 +91,8 @@ export function processMessageChunk(
if (!currentMessage || !chunk || !chunk.type) return { context, onlineContext, codeContext };
console.log(`chunk type: ${chunk.type}`);
if (chunk.type === "status") {
console.log(`status: ${chunk.data}`);
const statusMessage = chunk.data as string;
@ -98,6 +107,20 @@ export function processMessageChunk(
} else if (chunk.type === "metadata") {
const messageMetadata = chunk.data as MessageMetadata;
currentMessage.turnId = messageMetadata.turnId;
} else if (chunk.type === "generated_assets") {
const generatedAssets = chunk.data as GeneratedAssetsData;
if (generatedAssets.images) {
currentMessage.generatedImages = generatedAssets.images;
}
if (generatedAssets.excalidrawDiagram) {
currentMessage.generatedExcalidrawDiagram = generatedAssets.excalidrawDiagram;
}
if (generatedAssets.files) {
currentMessage.generatedFiles = generatedAssets.files;
}
} else if (chunk.type === "message") {
const chunkData = chunk.data;
// Here, handle if the response is a JSON response with an image, but the intentType is excalidraw

View file

@ -54,6 +54,12 @@ function TrainOfThoughtComponent(props: TrainOfThoughtComponentProps) {
const lastIndex = props.trainOfThought.length - 1;
const [collapsed, setCollapsed] = useState(props.completed);
useEffect(() => {
if (props.completed) {
setCollapsed(true);
}
}, [props.completed]);
return (
<div
className={`${!collapsed ? styles.trainOfThought + " shadow-sm" : ""}`}
@ -410,6 +416,9 @@ export default function ChatHistory(props: ChatHistoryProps) {
"inferred-queries": message.inferredQueries || [],
},
conversationId: props.conversationId,
images: message.generatedImages,
queryFiles: message.generatedFiles,
excalidrawDiagram: message.generatedExcalidrawDiagram,
turnId: messageTurnId,
}}
conversationId={props.conversationId}

View file

@ -77,6 +77,21 @@ div.imageWrapper img {
border-radius: 8px;
}
div.khoj div.imageWrapper img {
height: 512px;
}
div.khoj div.imageWrapper {
flex: 1 1 auto;
}
div.khoj div.imagesContainer {
display: flex;
flex-wrap: wrap;
flex-direction: row;
overflow-x: hidden;
}
div.chatMessageContainer > img {
width: auto;
height: auto;
@ -178,4 +193,9 @@ div.trainOfThoughtElement ul {
div.youfullHistory {
max-width: 90%;
}
div.khoj div.imageWrapper img {
width: 100%;
height: auto;
}
}

View file

@ -163,6 +163,7 @@ export interface SingleChatMessage {
conversationId: string;
turnId?: string;
queryFiles?: AttachedFileText[];
excalidrawDiagram?: string;
}
export interface StreamMessage {
@ -180,6 +181,10 @@ export interface StreamMessage {
inferredQueries?: string[];
turnId?: string;
queryFiles?: AttachedFileText[];
excalidrawDiagram?: string;
generatedFiles?: AttachedFileText[];
generatedImages?: string[];
generatedExcalidrawDiagram?: string;
}
export interface ChatHistoryData {
@ -264,6 +269,9 @@ interface ChatMessageProps {
onDeleteMessage: (turnId?: string) => void;
conversationId: string;
turnId?: string;
generatedImage?: string;
excalidrawDiagram?: string;
generatedFiles?: AttachedFileText[];
}
interface TrainOfThoughtProps {
@ -389,9 +397,8 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
// Prepare initial message for rendering
let message = props.chatMessage.message;
if (props.chatMessage.intent && props.chatMessage.intent.type == "excalidraw") {
message = props.chatMessage.intent["inferred-queries"][0];
setExcalidrawData(props.chatMessage.message);
if (props.chatMessage.excalidrawDiagram) {
setExcalidrawData(props.chatMessage.excalidrawDiagram);
}
// Replace LaTeX delimiters with placeholders
@ -401,27 +408,6 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
.replace(/\\\[/g, "LEFTBRACKET")
.replace(/\\\]/g, "RIGHTBRACKET");
const intentTypeHandlers = {
"text-to-image": (msg: string) => `![generated image](data:image/png;base64,${msg})`,
"text-to-image2": (msg: string) => `![generated image](${msg})`,
"text-to-image-v3": (msg: string) =>
`![generated image](data:image/webp;base64,${msg})`,
excalidraw: (msg: string) => msg,
};
// Handle intent-specific rendering
if (props.chatMessage.intent) {
const { type, "inferred-queries": inferredQueries } = props.chatMessage.intent;
if (type in intentTypeHandlers) {
message = intentTypeHandlers[type as keyof typeof intentTypeHandlers](message);
}
if (type.includes("text-to-image") && inferredQueries?.length > 0) {
message += `\n\n${inferredQueries[0]}`;
}
}
// Replace file links with base64 data
message = renderCodeGenImageInline(message, props.chatMessage.codeContext);

View file

@ -303,17 +303,10 @@ class ConversationAdmin(unfold_admin.ModelAdmin):
modified_log = conversation.conversation_log
chat_log = modified_log.get("chat", [])
for idx, log in enumerate(chat_log):
if (
log["by"] == "khoj"
and log["intent"]
and log["intent"]["type"]
and (
log["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE.value
or log["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE_V3.value
)
):
log["message"] = "inline image redacted for space"
if log["by"] == "khoj" and log["images"]:
log["images"] = ["inline image redacted for space"]
chat_log[idx] = log
modified_log["chat"] = chat_log
writer.writerow(

View file

@ -0,0 +1,85 @@
# Made manually by sabaimran for use by Django 5.0.9 on 2024-12-01 16:59
from django.db import migrations, models
# This script was written alongside the change that added Pydantic validation to the Conversation conversation_log field.
def migrate_generated_assets(apps, schema_editor):
Conversation = apps.get_model("database", "Conversation")
# Process conversations in chunks
for conversation in Conversation.objects.iterator():
try:
meta_log = conversation.conversation_log
modified = False
for chat in meta_log.get("chat", []):
intent_type = chat.get("intent", {}).get("type")
if intent_type and chat["by"] == "khoj":
if intent_type and "text-to-image" in intent_type:
# Migrate the generated image to the new format
chat["images"] = [chat.get("message")]
chat["message"] = chat["intent"]["inferred-queries"][0]
modified = True
if intent_type and "excalidraw" in intent_type:
# Migrate the generated excalidraw to the new format
chat["excalidrawDiagram"] = chat.get("message")
chat["message"] = chat["intent"]["inferred-queries"][0]
modified = True
# Only save if changes were made
if modified:
conversation.conversation_log = meta_log
conversation.save()
except Exception as e:
print(f"Error processing conversation {conversation.id}: {str(e)}")
continue
def reverse_migration(apps, schema_editor):
Conversation = apps.get_model("database", "Conversation")
# Process conversations in chunks
for conversation in Conversation.objects.iterator():
try:
meta_log = conversation.conversation_log
modified = False
for chat in meta_log.get("chat", []):
intent_type = chat.get("intent", {}).get("type")
if intent_type and chat["by"] == "khoj":
if intent_type and "text-to-image" in intent_type:
# Migrate the generated image back to the old format
chat["message"] = chat.get("images", [])[0]
chat.pop("images", None)
modified = True
if intent_type and "excalidraw" in intent_type:
# Migrate the generated excalidraw back to the old format
chat["message"] = chat.get("excalidrawDiagram")
chat.pop("excalidrawDiagram", None)
modified = True
# Only save if changes were made
if modified:
conversation.conversation_log = meta_log
conversation.save()
except Exception as e:
print(f"Error processing conversation {conversation.id}: {str(e)}")
continue
class Migration(migrations.Migration):
dependencies = [
("database", "0074_alter_conversation_title"),
]
operations = [
migrations.RunPython(migrate_generated_assets, reverse_migration),
]

View file

@ -1,7 +1,9 @@
import logging
import os
import re
import uuid
from random import choice
from typing import Dict, List, Optional, Union
from django.contrib.auth.models import AbstractUser
from django.contrib.postgres.fields import ArrayField
@ -11,9 +13,109 @@ from django.db.models.signals import pre_save
from django.dispatch import receiver
from pgvector.django import VectorField
from phonenumber_field.modelfields import PhoneNumberField
from pydantic import BaseModel as PydanticBaseModel
from pydantic import Field
logger = logging.getLogger(__name__)
class BaseModel(models.Model):
# Pydantic models for type Chat Message validation
class Context(PydanticBaseModel):
compiled: str
file: str
class CodeContextFile(PydanticBaseModel):
filename: str
b64_data: str
class CodeContextResult(PydanticBaseModel):
success: bool
output_files: List[CodeContextFile]
std_out: str
std_err: str
code_runtime: int
class CodeContextData(PydanticBaseModel):
code: str
result: Optional[CodeContextResult] = None
class WebPage(PydanticBaseModel):
link: str
query: Optional[str] = None
snippet: str
class AnswerBox(PydanticBaseModel):
link: Optional[str] = None
snippet: Optional[str] = None
title: str
snippetHighlighted: Optional[List[str]] = None
class PeopleAlsoAsk(PydanticBaseModel):
link: Optional[str] = None
question: Optional[str] = None
snippet: Optional[str] = None
title: str
class KnowledgeGraph(PydanticBaseModel):
attributes: Optional[Dict[str, str]] = None
description: Optional[str] = None
descriptionLink: Optional[str] = None
descriptionSource: Optional[str] = None
imageUrl: Optional[str] = None
title: str
type: Optional[str] = None
class OrganicContext(PydanticBaseModel):
snippet: str
title: str
link: str
class OnlineContext(PydanticBaseModel):
webpages: Optional[Union[WebPage, List[WebPage]]] = None
answerBox: Optional[AnswerBox] = None
peopleAlsoAsk: Optional[List[PeopleAlsoAsk]] = None
knowledgeGraph: Optional[KnowledgeGraph] = None
organicContext: Optional[List[OrganicContext]] = None
class Intent(PydanticBaseModel):
type: str
query: str
memory_type: str = Field(alias="memory-type")
inferred_queries: Optional[List[str]] = Field(default=None, alias="inferred-queries")
class TrainOfThought(PydanticBaseModel):
type: str
data: str
class ChatMessage(PydanticBaseModel):
message: str
trainOfThought: List[TrainOfThought] = []
context: List[Context] = []
onlineContext: Dict[str, OnlineContext] = {}
codeContext: Dict[str, CodeContextData] = {}
created: str
images: Optional[List[str]] = None
queryFiles: Optional[List[Dict]] = None
excalidrawDiagram: Optional[List[Dict]] = None
by: str
turnId: Optional[str] = None
intent: Optional[Intent] = None
automationId: Optional[str] = None
class DbBaseModel(models.Model):
created_at = models.DateTimeField(auto_now_add=True)
updated_at = models.DateTimeField(auto_now=True)
@ -21,7 +123,7 @@ class BaseModel(models.Model):
abstract = True
class ClientApplication(BaseModel):
class ClientApplication(DbBaseModel):
name = models.CharField(max_length=200)
client_id = models.CharField(max_length=200)
client_secret = models.CharField(max_length=200)
@ -67,7 +169,7 @@ class KhojApiUser(models.Model):
accessed_at = models.DateTimeField(null=True, default=None)
class Subscription(BaseModel):
class Subscription(DbBaseModel):
class Type(models.TextChoices):
TRIAL = "trial"
STANDARD = "standard"
@ -79,13 +181,13 @@ class Subscription(BaseModel):
enabled_trial_at = models.DateTimeField(null=True, default=None, blank=True)
class OpenAIProcessorConversationConfig(BaseModel):
class OpenAIProcessorConversationConfig(DbBaseModel):
name = models.CharField(max_length=200)
api_key = models.CharField(max_length=200)
api_base_url = models.URLField(max_length=200, default=None, blank=True, null=True)
class ChatModelOptions(BaseModel):
class ChatModelOptions(DbBaseModel):
class ModelType(models.TextChoices):
OPENAI = "openai"
OFFLINE = "offline"
@ -103,12 +205,12 @@ class ChatModelOptions(BaseModel):
)
class VoiceModelOption(BaseModel):
class VoiceModelOption(DbBaseModel):
model_id = models.CharField(max_length=200)
name = models.CharField(max_length=200)
class Agent(BaseModel):
class Agent(DbBaseModel):
class StyleColorTypes(models.TextChoices):
BLUE = "blue"
GREEN = "green"
@ -208,7 +310,7 @@ class Agent(BaseModel):
super().save(*args, **kwargs)
class ProcessLock(BaseModel):
class ProcessLock(DbBaseModel):
class Operation(models.TextChoices):
INDEX_CONTENT = "index_content"
SCHEDULED_JOB = "scheduled_job"
@ -231,24 +333,24 @@ def verify_agent(sender, instance, **kwargs):
raise ValidationError(f"A private Agent with the name {instance.name} already exists.")
class NotionConfig(BaseModel):
class NotionConfig(DbBaseModel):
token = models.CharField(max_length=200)
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
class GithubConfig(BaseModel):
class GithubConfig(DbBaseModel):
pat_token = models.CharField(max_length=200)
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
class GithubRepoConfig(BaseModel):
class GithubRepoConfig(DbBaseModel):
name = models.CharField(max_length=200)
owner = models.CharField(max_length=200)
branch = models.CharField(max_length=200)
github_config = models.ForeignKey(GithubConfig, on_delete=models.CASCADE, related_name="githubrepoconfig")
class WebScraper(BaseModel):
class WebScraper(DbBaseModel):
class WebScraperType(models.TextChoices):
FIRECRAWL = "Firecrawl"
OLOSTEP = "Olostep"
@ -321,7 +423,7 @@ class WebScraper(BaseModel):
super().save(*args, **kwargs)
class ServerChatSettings(BaseModel):
class ServerChatSettings(DbBaseModel):
chat_default = models.ForeignKey(
ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_default"
)
@ -333,35 +435,35 @@ class ServerChatSettings(BaseModel):
)
class LocalOrgConfig(BaseModel):
class LocalOrgConfig(DbBaseModel):
input_files = models.JSONField(default=list, null=True)
input_filter = models.JSONField(default=list, null=True)
index_heading_entries = models.BooleanField(default=False)
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
class LocalMarkdownConfig(BaseModel):
class LocalMarkdownConfig(DbBaseModel):
input_files = models.JSONField(default=list, null=True)
input_filter = models.JSONField(default=list, null=True)
index_heading_entries = models.BooleanField(default=False)
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
class LocalPdfConfig(BaseModel):
class LocalPdfConfig(DbBaseModel):
input_files = models.JSONField(default=list, null=True)
input_filter = models.JSONField(default=list, null=True)
index_heading_entries = models.BooleanField(default=False)
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
class LocalPlaintextConfig(BaseModel):
class LocalPlaintextConfig(DbBaseModel):
input_files = models.JSONField(default=list, null=True)
input_filter = models.JSONField(default=list, null=True)
index_heading_entries = models.BooleanField(default=False)
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
class SearchModelConfig(BaseModel):
class SearchModelConfig(DbBaseModel):
class ModelType(models.TextChoices):
TEXT = "text"
@ -393,7 +495,7 @@ class SearchModelConfig(BaseModel):
bi_encoder_confidence_threshold = models.FloatField(default=0.18)
class TextToImageModelConfig(BaseModel):
class TextToImageModelConfig(DbBaseModel):
class ModelType(models.TextChoices):
OPENAI = "openai"
STABILITYAI = "stability-ai"
@ -430,7 +532,7 @@ class TextToImageModelConfig(BaseModel):
super().save(*args, **kwargs)
class SpeechToTextModelOptions(BaseModel):
class SpeechToTextModelOptions(DbBaseModel):
class ModelType(models.TextChoices):
OPENAI = "openai"
OFFLINE = "offline"
@ -439,22 +541,22 @@ class SpeechToTextModelOptions(BaseModel):
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
class UserConversationConfig(BaseModel):
class UserConversationConfig(DbBaseModel):
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
setting = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True)
class UserVoiceModelConfig(BaseModel):
class UserVoiceModelConfig(DbBaseModel):
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
setting = models.ForeignKey(VoiceModelOption, on_delete=models.CASCADE, default=None, null=True, blank=True)
class UserTextToImageModelConfig(BaseModel):
class UserTextToImageModelConfig(DbBaseModel):
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
setting = models.ForeignKey(TextToImageModelConfig, on_delete=models.CASCADE)
class Conversation(BaseModel):
class Conversation(DbBaseModel):
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
conversation_log = models.JSONField(default=dict)
client = models.ForeignKey(ClientApplication, on_delete=models.CASCADE, default=None, null=True, blank=True)
@ -468,8 +570,39 @@ class Conversation(BaseModel):
file_filters = models.JSONField(default=list)
id = models.UUIDField(default=uuid.uuid4, editable=False, unique=True, primary_key=True, db_index=True)
def clean(self):
# Validate conversation_log structure
try:
messages = self.conversation_log.get("chat", [])
for msg in messages:
ChatMessage.model_validate(msg)
except Exception as e:
raise ValidationError(f"Invalid conversation_log format: {str(e)}")
class PublicConversation(BaseModel):
def save(self, *args, **kwargs):
self.clean()
super().save(*args, **kwargs)
@property
def messages(self) -> List[ChatMessage]:
"""Type-hinted accessor for conversation messages"""
validated_messages = []
for msg in self.conversation_log.get("chat", []):
try:
# Clean up inferred queries if they contain None
if msg.get("intent") and msg["intent"].get("inferred-queries"):
msg["intent"]["inferred-queries"] = [
q for q in msg["intent"]["inferred-queries"] if q is not None and isinstance(q, str)
]
msg["message"] = str(msg.get("message", ""))
validated_messages.append(ChatMessage.model_validate(msg))
except ValidationError as e:
logger.warning(f"Skipping invalid message in conversation: {e}")
continue
return validated_messages
class PublicConversation(DbBaseModel):
source_owner = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
conversation_log = models.JSONField(default=dict)
slug = models.CharField(max_length=200, default=None, null=True, blank=True)
@ -499,12 +632,12 @@ def verify_public_conversation(sender, instance, **kwargs):
instance.slug = slug
class ReflectiveQuestion(BaseModel):
class ReflectiveQuestion(DbBaseModel):
question = models.CharField(max_length=500)
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True)
class Entry(BaseModel):
class Entry(DbBaseModel):
class EntryType(models.TextChoices):
IMAGE = "image"
PDF = "pdf"
@ -541,7 +674,7 @@ class Entry(BaseModel):
raise ValidationError("An Entry cannot be associated with both a user and an agent.")
class FileObject(BaseModel):
class FileObject(DbBaseModel):
# Same as Entry but raw will be a much larger string
file_name = models.CharField(max_length=400, default=None, null=True, blank=True)
raw_text = models.TextField()
@ -549,7 +682,7 @@ class FileObject(BaseModel):
agent = models.ForeignKey(Agent, on_delete=models.CASCADE, default=None, null=True, blank=True)
class EntryDates(BaseModel):
class EntryDates(DbBaseModel):
date = models.DateField()
entry = models.ForeignKey(Entry, on_delete=models.CASCADE, related_name="embeddings_dates")
@ -559,12 +692,12 @@ class EntryDates(BaseModel):
]
class UserRequests(BaseModel):
class UserRequests(DbBaseModel):
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
slug = models.CharField(max_length=200)
class DataStore(BaseModel):
class DataStore(DbBaseModel):
key = models.CharField(max_length=200, unique=True)
value = models.JSONField(default=dict)
private = models.BooleanField(default=False)

View file

@ -1,6 +1,6 @@
import logging
from datetime import datetime, timedelta
from typing import Dict, Optional
from typing import Dict, List, Optional
import pyjson5
from langchain.schema import ChatMessage
@ -23,7 +23,7 @@ from khoj.utils.helpers import (
is_none_or_empty,
truncate_code_context,
)
from khoj.utils.rawconfig import LocationData
from khoj.utils.rawconfig import FileAttachment, LocationData
from khoj.utils.yaml import yaml_dump
logger = logging.getLogger(__name__)
@ -55,7 +55,7 @@ def extract_questions_anthropic(
[
f'User: {chat["intent"]["query"]}\nAssistant: {{"queries": {chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}}}\nA: {chat["message"]}\n\n'
for chat in conversation_log.get("chat", [])[-4:]
if chat["by"] == "khoj" and "text-to-image" not in chat["intent"].get("type")
if chat["by"] == "khoj"
]
)
@ -157,6 +157,10 @@ def converse_anthropic(
query_images: Optional[list[str]] = None,
vision_available: bool = False,
query_files: str = None,
generated_images: Optional[list[str]] = None,
generated_files: List[FileAttachment] = None,
generated_excalidraw_diagram: Optional[str] = None,
program_execution_context: Optional[List[str]] = None,
tracer: dict = {},
):
"""
@ -217,6 +221,10 @@ def converse_anthropic(
vision_enabled=vision_available,
model_type=ChatModelOptions.ModelType.ANTHROPIC,
query_files=query_files,
generated_excalidraw_diagram=generated_excalidraw_diagram,
generated_files=generated_files,
generated_images=generated_images,
program_execution_context=program_execution_context,
)
messages, system_prompt = format_messages_for_anthropic(messages, system_prompt)

View file

@ -1,6 +1,6 @@
import logging
from datetime import datetime, timedelta
from typing import Dict, Optional
from typing import Dict, List, Optional
import pyjson5
from langchain.schema import ChatMessage
@ -23,7 +23,7 @@ from khoj.utils.helpers import (
is_none_or_empty,
truncate_code_context,
)
from khoj.utils.rawconfig import LocationData
from khoj.utils.rawconfig import FileAttachment, LocationData
from khoj.utils.yaml import yaml_dump
logger = logging.getLogger(__name__)
@ -56,7 +56,7 @@ def extract_questions_gemini(
[
f'User: {chat["intent"]["query"]}\nAssistant: {{"queries": {chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}}}\nA: {chat["message"]}\n\n'
for chat in conversation_log.get("chat", [])[-4:]
if chat["by"] == "khoj" and "text-to-image" not in chat["intent"].get("type")
if chat["by"] == "khoj"
]
)
@ -167,6 +167,10 @@ def converse_gemini(
query_images: Optional[list[str]] = None,
vision_available: bool = False,
query_files: str = None,
generated_images: Optional[list[str]] = None,
generated_files: List[FileAttachment] = None,
generated_excalidraw_diagram: Optional[str] = None,
program_execution_context: List[str] = None,
tracer={},
):
"""
@ -228,6 +232,10 @@ def converse_gemini(
vision_enabled=vision_available,
model_type=ChatModelOptions.ModelType.GOOGLE,
query_files=query_files,
generated_excalidraw_diagram=generated_excalidraw_diagram,
generated_files=generated_files,
generated_images=generated_images,
program_execution_context=program_execution_context,
)
messages, system_prompt = format_messages_for_gemini(messages, system_prompt)

View file

@ -28,7 +28,7 @@ from khoj.utils.helpers import (
is_promptrace_enabled,
truncate_code_context,
)
from khoj.utils.rawconfig import LocationData
from khoj.utils.rawconfig import FileAttachment, LocationData
from khoj.utils.yaml import yaml_dump
logger = logging.getLogger(__name__)
@ -69,7 +69,7 @@ def extract_questions_offline(
if use_history:
for chat in conversation_log.get("chat", [])[-4:]:
if chat["by"] == "khoj" and "text-to-image" not in chat["intent"].get("type"):
if chat["by"] == "khoj":
chat_history += f"Q: {chat['intent']['query']}\n"
chat_history += f"Khoj: {chat['message']}\n\n"
@ -164,6 +164,8 @@ def converse_offline(
user_name: str = None,
agent: Agent = None,
query_files: str = None,
generated_files: List[FileAttachment] = None,
additional_context: List[str] = None,
tracer: dict = {},
) -> Union[ThreadedGenerator, Iterator[str]]:
"""
@ -231,6 +233,8 @@ def converse_offline(
tokenizer_name=tokenizer_name,
model_type=ChatModelOptions.ModelType.OFFLINE,
query_files=query_files,
generated_files=generated_files,
program_execution_context=additional_context,
)
logger.debug(f"Conversation Context for {model}: {messages_to_print(messages)}")

View file

@ -1,6 +1,6 @@
import logging
from datetime import datetime, timedelta
from typing import Dict, Optional
from typing import Dict, List, Optional
import pyjson5
from langchain.schema import ChatMessage
@ -22,7 +22,7 @@ from khoj.utils.helpers import (
is_none_or_empty,
truncate_code_context,
)
from khoj.utils.rawconfig import LocationData
from khoj.utils.rawconfig import FileAttachment, LocationData
from khoj.utils.yaml import yaml_dump
logger = logging.getLogger(__name__)
@ -157,6 +157,10 @@ def converse(
query_images: Optional[list[str]] = None,
vision_available: bool = False,
query_files: str = None,
generated_images: Optional[list[str]] = None,
generated_files: List[FileAttachment] = None,
generated_excalidraw_diagram: Optional[str] = None,
program_execution_context: List[str] = None,
tracer: dict = {},
):
"""
@ -219,6 +223,10 @@ def converse(
vision_enabled=vision_available,
model_type=ChatModelOptions.ModelType.OPENAI,
query_files=query_files,
generated_excalidraw_diagram=generated_excalidraw_diagram,
generated_files=generated_files,
generated_images=generated_images,
program_execution_context=program_execution_context,
)
logger.debug(f"Conversation Context for GPT: {messages_to_print(messages)}")

View file

@ -178,6 +178,18 @@ Improved Prompt:
""".strip()
)
generated_image_attachment = PromptTemplate.from_template(
f"""
Here is the image you generated based on my query. You can follow-up with a general response to my query. Limit to 1-2 sentences.
""".strip()
)
generated_diagram_attachment = PromptTemplate.from_template(
f"""
I've successfully created a diagram based on the user's query. The diagram will automatically be shared with the user. I can follow-up with a general response or summary. Limit to 1-2 sentences.
""".strip()
)
## Diagram Generation
## --
@ -1029,6 +1041,12 @@ A:
""".strip()
)
additional_program_context = PromptTemplate.from_template(
"""
Here are some additional results from the query execution:
{context}
""".strip()
)
personality_prompt_safety_expert_lax = PromptTemplate.from_template(
"""

View file

@ -154,7 +154,7 @@ def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="A
chat_history += f'{agent_name}: {{"queries": {chat["intent"].get("inferred-queries")}}}\n'
chat_history += f"{agent_name}: {chat['message']}\n\n"
elif chat["by"] == "khoj" and ("text-to-image" in chat["intent"].get("type")):
elif chat["by"] == "khoj" and chat.get("images"):
chat_history += f"User: {chat['intent']['query']}\n"
chat_history += f"{agent_name}: [generated image redacted for space]\n"
elif chat["by"] == "khoj" and ("excalidraw" in chat["intent"].get("type")):
@ -213,6 +213,7 @@ class ChatEvent(Enum):
END_LLM_RESPONSE = "end_llm_response"
MESSAGE = "message"
REFERENCES = "references"
GENERATED_ASSETS = "generated_assets"
STATUS = "status"
METADATA = "metadata"
USAGE = "usage"
@ -225,7 +226,6 @@ def message_to_log(
user_message_metadata={},
khoj_message_metadata={},
conversation_log=[],
train_of_thought=[],
):
"""Create json logs from messages, metadata for conversation log"""
default_khoj_message_metadata = {
@ -234,6 +234,10 @@ def message_to_log(
}
khoj_response_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
# Filter out any fields that are set to None
user_message_metadata = {k: v for k, v in user_message_metadata.items() if v is not None}
khoj_message_metadata = {k: v for k, v in khoj_message_metadata.items() if v is not None}
# Create json log from Human's message
human_log = merge_dicts({"message": user_message, "by": "you"}, user_message_metadata)
@ -261,21 +265,21 @@ def save_to_conversation_log(
automation_id: str = None,
query_images: List[str] = None,
raw_query_files: List[FileAttachment] = [],
generated_images: List[str] = [],
raw_generated_files: List[FileAttachment] = [],
generated_excalidraw_diagram: str = None,
train_of_thought: List[Any] = [],
tracer: Dict[str, Any] = {},
):
user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
turn_id = tracer.get("mid") or str(uuid.uuid4())
updated_conversation = message_to_log(
user_message=q,
chat_response=chat_response,
user_message_metadata={
"created": user_message_time,
"images": query_images,
"turnId": turn_id,
"queryFiles": [file.model_dump(mode="json") for file in raw_query_files],
},
khoj_message_metadata={
user_message_metadata = {"created": user_message_time, "images": query_images, "turnId": turn_id}
if raw_query_files and len(raw_query_files) > 0:
user_message_metadata["queryFiles"] = [file.model_dump(mode="json") for file in raw_query_files]
khoj_message_metadata = {
"context": compiled_references,
"intent": {"inferred-queries": inferred_queries, "type": intent_type},
"onlineContext": online_results,
@ -283,9 +287,19 @@ def save_to_conversation_log(
"automationId": automation_id,
"trainOfThought": train_of_thought,
"turnId": turn_id,
},
"images": generated_images,
"queryFiles": [file.model_dump(mode="json") for file in raw_generated_files],
}
if generated_excalidraw_diagram:
khoj_message_metadata["excalidrawDiagram"] = generated_excalidraw_diagram
updated_conversation = message_to_log(
user_message=q,
chat_response=chat_response,
user_message_metadata=user_message_metadata,
khoj_message_metadata=khoj_message_metadata,
conversation_log=meta_log.get("chat", []),
train_of_thought=train_of_thought,
)
ConversationAdapters.save_conversation(
user,
@ -303,13 +317,13 @@ def save_to_conversation_log(
Saved Conversation Turn
You ({user.username}): "{q}"
Khoj: "{inferred_queries if ("text-to-image" in intent_type) else chat_response}"
Khoj: "{chat_response}"
""".strip()
)
def construct_structured_message(
message: str, images: list[str], model_type: str, vision_enabled: bool, attached_file_context: str
message: str, images: list[str], model_type: str, vision_enabled: bool, attached_file_context: str = None
):
"""
Format messages into appropriate multimedia format for supported chat model types
@ -327,6 +341,7 @@ def construct_structured_message(
constructed_messages.append({"type": "text", "text": attached_file_context})
if vision_enabled and images:
for image in images:
if image.startswith("https://"):
constructed_messages.append({"type": "image_url", "image_url": {"url": image}})
return constructed_messages
@ -365,6 +380,10 @@ def generate_chatml_messages_with_context(
model_type="",
context_message="",
query_files: str = None,
generated_images: Optional[list[str]] = None,
generated_files: List[FileAttachment] = None,
generated_excalidraw_diagram: str = None,
program_execution_context: List[str] = [],
):
"""Generate chat messages with appropriate context from previous conversation to send to the chat model"""
# Set max prompt size from user config or based on pre-configured for model and machine specs
@ -384,6 +403,7 @@ def generate_chatml_messages_with_context(
message_attached_files = ""
chat_message = chat.get("message")
role = "user" if chat["by"] == "you" else "assistant"
if chat["by"] == "khoj" and "excalidraw" in chat["intent"].get("type", ""):
chat_message = chat["intent"].get("inferred-queries")[0]
@ -404,7 +424,7 @@ def generate_chatml_messages_with_context(
query_files_dict[file["name"]] = file["content"]
message_attached_files = gather_raw_query_files(query_files_dict)
chatml_messages.append(ChatMessage(content=message_attached_files, role="user"))
chatml_messages.append(ChatMessage(content=message_attached_files, role=role))
if not is_none_or_empty(chat.get("onlineContext")):
message_context += f"{prompts.online_search_conversation.format(online_results=chat.get('onlineContext'))}"
@ -413,9 +433,19 @@ def generate_chatml_messages_with_context(
reconstructed_context_message = ChatMessage(content=message_context, role="user")
chatml_messages.insert(0, reconstructed_context_message)
role = "user" if chat["by"] == "you" else "assistant"
if chat.get("images"):
if role == "assistant":
# Issue: the assistant role cannot accept an image as a message content, so send it in a separate user message.
file_attachment_message = construct_structured_message(
message=prompts.generated_image_attachment.format(),
images=chat.get("images"),
model_type=model_type,
vision_enabled=vision_enabled,
)
chatml_messages.append(ChatMessage(content=file_attachment_message, role="user"))
else:
message_content = construct_structured_message(
chat_message, chat.get("images"), model_type, vision_enabled, attached_file_context=query_files
chat_message, chat.get("images"), model_type, vision_enabled
)
reconstructed_message = ChatMessage(content=message_content, role=role)
@ -425,6 +455,7 @@ def generate_chatml_messages_with_context(
break
messages = []
if not is_none_or_empty(user_message):
messages.append(
ChatMessage(
@ -437,6 +468,31 @@ def generate_chatml_messages_with_context(
if not is_none_or_empty(context_message):
messages.append(ChatMessage(content=context_message, role="user"))
if generated_images:
messages.append(
ChatMessage(
content=construct_structured_message(
prompts.generated_image_attachment.format(), generated_images, model_type, vision_enabled
),
role="user",
)
)
if generated_files:
message_attached_files = gather_raw_query_files({file.name: file.content for file in generated_files})
messages.append(ChatMessage(content=message_attached_files, role="assistant"))
if generated_excalidraw_diagram:
messages.append(ChatMessage(content=prompts.generated_diagram_attachment.format(), role="assistant"))
if program_execution_context:
messages.append(
ChatMessage(
content=prompts.additional_program_context.format(context="\n".join(program_execution_context)),
role="assistant",
)
)
if len(chatml_messages) > 0:
messages += chatml_messages

View file

@ -12,7 +12,7 @@ from khoj.database.models import Agent, KhojUser, TextToImageModelConfig
from khoj.routers.helpers import ChatEvent, generate_better_image_prompt
from khoj.routers.storage import upload_image
from khoj.utils import state
from khoj.utils.helpers import ImageIntentType, convert_image_to_webp, timer
from khoj.utils.helpers import convert_image_to_webp, timer
from khoj.utils.rawconfig import LocationData
logger = logging.getLogger(__name__)
@ -34,14 +34,13 @@ async def text_to_image(
status_code = 200
image = None
image_url = None
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
text_to_image_config = await ConversationAdapters.aget_user_text_to_image_model(user)
if not text_to_image_config:
# If the user has not configured a text to image model, return an unsupported on server error
status_code = 501
message = "Failed to generate image. Setup image generation on the server."
yield image_url or image, status_code, message, intent_type.value
yield image_url or image, status_code, message
return
text2image_model = text_to_image_config.model_name
@ -50,8 +49,8 @@ async def text_to_image(
if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder"]:
chat_history += f"Q: {chat['intent']['query']}\n"
chat_history += f"A: {chat['message']}\n"
elif chat["by"] == "khoj" and "text-to-image" in chat["intent"].get("type"):
chat_history += f"Q: Prompt: {chat['intent']['query']}\n"
elif chat["by"] == "khoj" and chat.get("images"):
chat_history += f"Q: {chat['intent']['query']}\n"
chat_history += f"A: Improved Prompt: {chat['intent']['inferred-queries'][0]}\n"
if send_status_func:
@ -92,31 +91,29 @@ async def text_to_image(
logger.error(f"Image Generation blocked by OpenAI: {e}")
status_code = e.status_code # type: ignore
message = f"Image generation blocked by OpenAI due to policy violation" # type: ignore
yield image_url or image, status_code, message, intent_type.value
yield image_url or image, status_code, message
return
else:
logger.error(f"Image Generation failed with {e}", exc_info=True)
message = f"Image generation failed using OpenAI" # type: ignore
status_code = e.status_code # type: ignore
yield image_url or image, status_code, message, intent_type.value
yield image_url or image, status_code, message
return
except requests.RequestException as e:
logger.error(f"Image Generation failed with {e}", exc_info=True)
message = f"Image generation using {text2image_model} via {text_to_image_config.model_type} failed due to a network error."
status_code = 502
yield image_url or image, status_code, message, intent_type.value
yield image_url or image, status_code, message
return
# Decide how to store the generated image
with timer("Upload image to S3", logger):
image_url = upload_image(webp_image_bytes, user.uuid)
if image_url:
intent_type = ImageIntentType.TEXT_TO_IMAGE2
else:
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
if not image_url:
image = base64.b64encode(webp_image_bytes).decode("utf-8")
yield image_url or image, status_code, image_prompt, intent_type.value
yield image_url or image, status_code, image_prompt
def generate_image_with_openai(

View file

@ -77,6 +77,7 @@ from khoj.utils.helpers import (
)
from khoj.utils.rawconfig import (
ChatRequestBody,
FileAttachment,
FileFilterRequest,
FilesFilterRequest,
LocationData,
@ -770,6 +771,11 @@ async def chat(
file_filters = conversation.file_filters if conversation and conversation.file_filters else []
attached_file_context = gather_raw_query_files(query_files)
generated_images: List[str] = []
generated_files: List[FileAttachment] = []
generated_excalidraw_diagram: str = None
program_execution_context: List[str] = []
if conversation_commands == [ConversationCommand.Default] or is_automated_task:
chosen_io = await aget_data_sources_and_output_format(
q,
@ -875,21 +881,17 @@ async def chat(
async for result in send_llm_response(response, tracer.get("usage")):
yield result
await sync_to_async(save_to_conversation_log)(
q,
response_log,
user,
meta_log,
user_message_time,
intent_type="summarize",
client_application=request.user.client_app,
conversation_id=conversation_id,
query_images=uploaded_images,
train_of_thought=train_of_thought,
raw_query_files=raw_query_files,
tracer=tracer,
summarized_document = FileAttachment(
name="Summarized Document",
content=response_log,
type="text/plain",
size=len(response_log.encode("utf-8")),
)
return
async for result in send_event(ChatEvent.GENERATED_ASSETS, {"files": [summarized_document.model_dump()]}):
yield result
generated_files.append(summarized_document)
custom_filters = []
if conversation_commands == [ConversationCommand.Help]:
@ -1078,6 +1080,7 @@ async def chat(
async for result in send_event(ChatEvent.STATUS, f"**Ran code snippets**: {len(code_results)}"):
yield result
except ValueError as e:
program_execution_context.append(f"Failed to run code")
logger.warning(
f"Failed to use code tool: {e}. Attempting to respond without code results",
exc_info=True,
@ -1115,51 +1118,28 @@ async def chat(
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
generated_image, status_code, improved_image_prompt, intent_type = result
generated_image, status_code, improved_image_prompt = result
inferred_queries.append(improved_image_prompt)
if generated_image is None or status_code != 200:
content_obj = {
"content-type": "application/json",
"intentType": intent_type,
"detail": improved_image_prompt,
"image": None,
}
async for result in send_llm_response(json.dumps(content_obj), tracer.get("usage")):
program_execution_context.append(f"Failed to generate image with {improved_image_prompt}")
async for result in send_event(ChatEvent.STATUS, f"Failed to generate image"):
yield result
return
else:
generated_images.append(generated_image)
await sync_to_async(save_to_conversation_log)(
q,
generated_image,
user,
meta_log,
user_message_time,
intent_type=intent_type,
inferred_queries=[improved_image_prompt],
client_application=request.user.client_app,
conversation_id=conversation_id,
compiled_references=compiled_references,
online_results=online_results,
code_results=code_results,
query_images=uploaded_images,
train_of_thought=train_of_thought,
raw_query_files=raw_query_files,
tracer=tracer,
)
content_obj = {
"intentType": intent_type,
"inferredQueries": [improved_image_prompt],
"image": generated_image,
}
async for result in send_llm_response(json.dumps(content_obj), tracer.get("usage")):
async for result in send_event(
ChatEvent.GENERATED_ASSETS,
{
"images": [generated_image],
},
):
yield result
return
if ConversationCommand.Diagram in conversation_commands:
async for result in send_event(ChatEvent.STATUS, f"Creating diagram"):
yield result
intent_type = "excalidraw"
inferred_queries = []
diagram_description = ""
@ -1183,62 +1163,29 @@ async def chat(
if better_diagram_description_prompt and excalidraw_diagram_description:
inferred_queries.append(better_diagram_description_prompt)
diagram_description = excalidraw_diagram_description
generated_excalidraw_diagram = diagram_description
async for result in send_event(
ChatEvent.GENERATED_ASSETS,
{
"excalidrawDiagram": excalidraw_diagram_description,
},
):
yield result
else:
error_message = "Failed to generate diagram. Please try again later."
async for result in send_llm_response(error_message, tracer.get("usage")):
yield result
await sync_to_async(save_to_conversation_log)(
q,
error_message,
user,
meta_log,
user_message_time,
inferred_queries=[better_diagram_description_prompt],
client_application=request.user.client_app,
conversation_id=conversation_id,
compiled_references=compiled_references,
online_results=online_results,
code_results=code_results,
query_images=uploaded_images,
train_of_thought=train_of_thought,
raw_query_files=raw_query_files,
tracer=tracer,
)
return
content_obj = {
"intentType": intent_type,
"inferredQueries": inferred_queries,
"image": diagram_description,
}
await sync_to_async(save_to_conversation_log)(
q,
excalidraw_diagram_description,
user,
meta_log,
user_message_time,
intent_type="excalidraw",
inferred_queries=[better_diagram_description_prompt],
client_application=request.user.client_app,
conversation_id=conversation_id,
compiled_references=compiled_references,
online_results=online_results,
code_results=code_results,
query_images=uploaded_images,
train_of_thought=train_of_thought,
raw_query_files=raw_query_files,
tracer=tracer,
program_execution_context.append(
f"AI attempted to programmatically generate a diagram but failed due to a program issue. Generally, it is able to do so, but encountered a system issue this time. AI can suggest text description or rendering of the diagram or user can try again with a simpler prompt."
)
async for result in send_llm_response(json.dumps(content_obj), tracer.get("usage")):
async for result in send_event(ChatEvent.STATUS, error_message):
yield result
return
## Generate Text Output
async for result in send_event(ChatEvent.STATUS, f"**Generating a well-informed response**"):
yield result
llm_response, chat_metadata = await agenerate_chat_response(
defiltered_query,
meta_log,
@ -1258,6 +1205,10 @@ async def chat(
train_of_thought,
attached_file_context,
raw_query_files,
generated_images,
generated_files,
generated_excalidraw_diagram,
program_execution_context,
tracer,
)

View file

@ -1185,6 +1185,10 @@ def generate_chat_response(
train_of_thought: List[Any] = [],
query_files: str = None,
raw_query_files: List[FileAttachment] = None,
generated_images: List[str] = None,
raw_generated_files: List[FileAttachment] = [],
generated_excalidraw_diagram: str = None,
program_execution_context: List[str] = [],
tracer: dict = {},
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
# Initialize Variables
@ -1208,6 +1212,9 @@ def generate_chat_response(
query_images=query_images,
train_of_thought=train_of_thought,
raw_query_files=raw_query_files,
generated_images=generated_images,
raw_generated_files=raw_generated_files,
generated_excalidraw_diagram=generated_excalidraw_diagram,
tracer=tracer,
)
@ -1243,6 +1250,7 @@ def generate_chat_response(
user_name=user_name,
agent=agent,
query_files=query_files,
generated_files=raw_generated_files,
tracer=tracer,
)
@ -1269,6 +1277,10 @@ def generate_chat_response(
agent=agent,
vision_available=vision_available,
query_files=query_files,
generated_files=raw_generated_files,
generated_images=generated_images,
generated_excalidraw_diagram=generated_excalidraw_diagram,
program_execution_context=program_execution_context,
tracer=tracer,
)
@ -1292,6 +1304,10 @@ def generate_chat_response(
agent=agent,
vision_available=vision_available,
query_files=query_files,
generated_files=raw_generated_files,
generated_images=generated_images,
generated_excalidraw_diagram=generated_excalidraw_diagram,
program_execution_context=program_execution_context,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
@ -1314,6 +1330,10 @@ def generate_chat_response(
query_images=query_images,
vision_available=vision_available,
query_files=query_files,
generated_files=raw_generated_files,
generated_images=generated_images,
generated_excalidraw_diagram=generated_excalidraw_diagram,
program_execution_context=program_execution_context,
tracer=tracer,
)
@ -1785,6 +1805,9 @@ class MessageProcessor:
self.references = {}
self.usage = {}
self.raw_response = ""
self.generated_images = []
self.generated_files = []
self.generated_excalidraw_diagram = []
def convert_message_chunk_to_json(self, raw_chunk: str) -> Dict[str, Any]:
if raw_chunk.startswith("{") and raw_chunk.endswith("}"):
@ -1823,6 +1846,16 @@ class MessageProcessor:
self.raw_response += chunk_data
else:
self.raw_response += chunk_data
elif chunk_type == ChatEvent.GENERATED_ASSETS:
chunk_data = chunk["data"]
if isinstance(chunk_data, dict):
for key in chunk_data:
if key == "images":
self.generated_images = chunk_data[key]
elif key == "files":
self.generated_files = chunk_data[key]
elif key == "excalidrawDiagram":
self.generated_excalidraw_diagram = chunk_data[key]
def handle_json_response(self, json_data: Dict[str, str]) -> str | Dict[str, str]:
if "image" in json_data or "details" in json_data:
@ -1853,7 +1886,14 @@ async def read_chat_stream(response_iterator: AsyncGenerator[str, None]) -> Dict
if buffer:
processor.process_message_chunk(buffer)
return {"response": processor.raw_response, "references": processor.references, "usage": processor.usage}
return {
"response": processor.raw_response,
"references": processor.references,
"usage": processor.usage,
"images": processor.generated_images,
"files": processor.generated_files,
"excalidrawDiagram": processor.generated_excalidraw_diagram,
}
def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False):

View file

@ -22,7 +22,6 @@ from khoj.processor.conversation.offline.chat_model import (
filter_questions,
)
from khoj.processor.conversation.offline.utils import download_model
from khoj.processor.conversation.utils import message_to_log
from khoj.utils.constants import default_offline_chat_models

View file

@ -6,7 +6,6 @@ from freezegun import freeze_time
from khoj.database.models import Agent, Entry, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import message_to_log
from tests.helpers import ConversationFactory, generate_chat_history, get_chat_api_key
# Initialize variables for tests