Merge pull request #992 from khoj-ai/features/allow-multi-outputs-in-chat

Currently, Khoj treats its outputs as terminal states, limited to image, text, or Excalidraw diagram. This limitation is unnecessary and unduly constrains building more dynamic, diverse experiences. For instance, we may want the chat view to morph into a document editing or generation surface, in which case such limited output modes would be a detriment.

This change lets us emit generated assets and then continue with text generation in the final response. It enforces a text response for all messages. It adds a new stream event, GENERATED_ASSETS, which carries the content the AI emits in response to the query.
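For reference, here is a minimal sketch of how a client might fold the new event into streamed message state. The GeneratedAssetsData fields mirror the interface added in this change; the StreamChunk wrapper, the StreamedMessage shape, and the AttachedFileText field names are assumptions for illustration, not the exact wire format.

// Sketch: consuming a generated_assets chunk while streaming a chat response.
// StreamChunk, StreamedMessage, and AttachedFileText's fields are assumed shapes.
interface AttachedFileText {
    file_name: string; // assumed field names
    raw_text: string;
}

interface GeneratedAssetsData {
    images: string[]; // URLs or base64-encoded images
    excalidrawDiagram: string; // serialized diagram, rendered on the web app
    files: AttachedFileText[];
}

interface StreamChunk {
    type: "status" | "generated_assets" | "message" | "end_llm_response" | string;
    data: unknown;
}

interface StreamedMessage {
    text: string;
    generatedImages?: string[];
    generatedExcalidrawDiagram?: string;
    generatedFiles?: AttachedFileText[];
}

function handleChunk(chunk: StreamChunk, message: StreamedMessage): void {
    if (chunk.type === "generated_assets") {
        // Assets arrive mid-stream and are kept alongside the text, so the
        // model can keep generating a text response after emitting them.
        const assets = chunk.data as GeneratedAssetsData;
        if (assets.images?.length) message.generatedImages = assets.images;
        if (assets.excalidrawDiagram) message.generatedExcalidrawDiagram = assets.excalidrawDiagram;
        if (assets.files?.length) message.generatedFiles = assets.files;
    } else if (chunk.type === "message") {
        message.text += chunk.data as string; // text keeps streaming after assets
    }
}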
sabaimran 2024-12-08 14:19:05 -08:00 committed by GitHub
commit e3789aef49
21 changed files with 665 additions and 285 deletions

View file

@@ -76,12 +76,12 @@ jobs:
 DEBIAN_FRONTEND: noninteractive
 run: |
 # install postgres and other dependencies
-apt update && apt install -y git python3-pip libegl1 sqlite3 libsqlite3-dev libsqlite3-0 ffmpeg libsm6 libxext6
+sudo apt update && sudo apt install -y git python3-pip libegl1 sqlite3 libsqlite3-dev libsqlite3-0 ffmpeg libsm6 libxext6
-apt install -y postgresql postgresql-client && apt install -y postgresql-server-dev-14
+sudo apt install -y postgresql postgresql-client && sudo apt install -y postgresql-server-dev-14
 # upgrade pip
 python -m ensurepip --upgrade && python -m pip install --upgrade pip
 # install terrarium for code sandbox
-git clone https://github.com/cohere-ai/cohere-terrarium.git && cd cohere-terrarium && npm install && mkdir pyodide_cache
+git clone https://github.com/khoj-ai/terrarium.git && cd terrarium && npm install --legacy-peer-deps && mkdir pyodide_cache
 - name: ⬇️ Install Application
 run: |
@@ -113,7 +113,7 @@ jobs:
 khoj --anonymous-mode --non-interactive &
 # Start code sandbox
-npm run dev --prefix cohere-terrarium &
+npm run dev --prefix terrarium &
 # Wait for server to be ready
 timeout=120

View file

@@ -1,4 +1,4 @@
-import {ItemView, MarkdownRenderer, Scope, WorkspaceLeaf, request, requestUrl, setIcon, Platform} from 'obsidian';
+import { ItemView, MarkdownRenderer, Scope, WorkspaceLeaf, request, requestUrl, setIcon, Platform } from 'obsidian';
 import * as DOMPurify from 'dompurify';
 import { KhojSetting } from 'src/settings';
 import { KhojPaneView } from 'src/pane_view';
@@ -27,6 +27,7 @@ interface ChatMessageState {
 newResponseEl: HTMLElement | null;
 loadingEllipsis: HTMLElement | null;
 references: any;
+generatedAssets: string;
 rawResponse: string;
 rawQuery: string;
 isVoice: boolean;
@@ -46,10 +47,10 @@ export class KhojChatView extends KhojPaneView {
 waitingForLocation: boolean;
 location: Location = { timezone: Intl.DateTimeFormat().resolvedOptions().timeZone };
 keyPressTimeout: NodeJS.Timeout | null = null;
 userMessages: string[] = []; // Store user sent messages for input history cycling
 currentMessageIndex: number = -1; // Track current message index in userMessages array
 private currentUserInput: string = ""; // Stores the current user input that is being typed in chat
 private startingMessage: string = "Message";
 chatMessageState: ChatMessageState;
 constructor(leaf: WorkspaceLeaf, setting: KhojSetting) {
@@ -102,14 +103,14 @@ export class KhojChatView extends KhojPaneView {
 // Clear text after extracting message to send
 let user_message = input_el.value.trim();
 // Store the message in the array if it's not empty
 if (user_message) {
 this.userMessages.push(user_message);
 // Update starting message after sending a new message
 const modifierKey = Platform.isMacOS ? '⌘' : '^';
 this.startingMessage = `(${modifierKey}+↑/↓) for prev messages`;
 input_el.placeholder = this.startingMessage;
 }
 input_el.value = "";
 this.autoResize();
@@ -162,9 +163,9 @@ export class KhojChatView extends KhojPaneView {
 })
 chatInput.addEventListener('input', (_) => { this.onChatInput() });
 chatInput.addEventListener('keydown', (event) => {
 this.incrementalChat(event);
 this.handleArrowKeys(event);
 });
 // Add event listeners for long press keybinding
 this.contentEl.addEventListener('keydown', this.handleKeyDown.bind(this));
@@ -199,7 +200,7 @@
 // Get chat history from Khoj backend and set chat input state
 let getChatHistorySucessfully = await this.getChatHistory(chatBodyEl);
-let placeholderText : string = getChatHistorySucessfully ? this.startingMessage : "Configure Khoj to enable chat";
+let placeholderText: string = getChatHistorySucessfully ? this.startingMessage : "Configure Khoj to enable chat";
 chatInput.placeholder = placeholderText;
 chatInput.disabled = !getChatHistorySucessfully;
@@ -214,7 +215,7 @@
 });
 }
-startSpeechToText(event: KeyboardEvent | MouseEvent | TouchEvent, timeout=200) {
+startSpeechToText(event: KeyboardEvent | MouseEvent | TouchEvent, timeout = 200) {
 if (!this.keyPressTimeout) {
 this.keyPressTimeout = setTimeout(async () => {
 // Reset auto send voice message timer, UI if running
@@ -320,7 +321,7 @@
 referenceButton.tabIndex = 0;
 // Add event listener to toggle full reference on click
-referenceButton.addEventListener('click', function() {
+referenceButton.addEventListener('click', function () {
 if (this.classList.contains("collapsed")) {
 this.classList.remove("collapsed");
 this.classList.add("expanded");
@@ -375,7 +376,7 @@
 referenceButton.tabIndex = 0;
 // Add event listener to toggle full reference on click
-referenceButton.addEventListener('click', function() {
+referenceButton.addEventListener('click', function () {
 if (this.classList.contains("collapsed")) {
 this.classList.remove("collapsed");
 this.classList.add("expanded");
@@ -420,23 +421,23 @@
 "Authorization": `Bearer ${this.setting.khojApiKey}`,
 },
 })
 .then(response => response.arrayBuffer())
 .then(arrayBuffer => context.decodeAudioData(arrayBuffer))
 .then(audioBuffer => {
 const source = context.createBufferSource();
 source.buffer = audioBuffer;
 source.connect(context.destination);
 source.start(0);
-source.onended = function() {
-speechButton.removeChild(loader);
-speechButton.disabled = false;
-};
-})
-.catch(err => {
-console.error("Error playing speech:", err);
-speechButton.removeChild(loader);
-speechButton.disabled = false;
-});
+source.onended = function () {
+speechButton.removeChild(loader);
+speechButton.disabled = false; // Consider enabling the button again to allow retrying
+};
+})
+.catch(err => {
+console.error("Error playing speech:", err);
+speechButton.removeChild(loader);
+speechButton.disabled = false; // Consider enabling the button again to allow retrying
+});
 }
 formatHTMLMessage(message: string, raw = false, willReplace = true) {
@@ -485,12 +486,18 @@
 intentType?: string,
 inferredQueries?: string[],
 conversationId?: string,
+images?: string[],
+excalidrawDiagram?: string
 ) {
 if (!message) return;
 let chatMessageEl;
-if (intentType?.includes("text-to-image") || intentType === "excalidraw") {
-let imageMarkdown = this.generateImageMarkdown(message, intentType, inferredQueries, conversationId);
+if (
+intentType?.includes("text-to-image") ||
+intentType === "excalidraw" ||
+(images && images.length > 0) ||
+excalidrawDiagram) {
+let imageMarkdown = this.generateImageMarkdown(message, intentType ?? "", inferredQueries, conversationId, images, excalidrawDiagram);
 chatMessageEl = this.renderMessage(chatEl, imageMarkdown, sender, dt);
 } else {
 chatMessageEl = this.renderMessage(chatEl, message, sender, dt);
@@ -510,7 +517,7 @@
 chatMessageBodyEl.appendChild(this.createReferenceSection(references));
 }
-generateImageMarkdown(message: string, intentType: string, inferredQueries?: string[], conversationId?: string): string {
+generateImageMarkdown(message: string, intentType: string, inferredQueries?: string[], conversationId?: string, images?: string[], excalidrawDiagram?: string): string {
 let imageMarkdown = "";
 if (intentType === "text-to-image") {
 imageMarkdown = `![](data:image/png;base64,${message})`;
@@ -518,12 +525,23 @@
 imageMarkdown = `![](${message})`;
 } else if (intentType === "text-to-image-v3") {
 imageMarkdown = `![](data:image/webp;base64,${message})`;
-} else if (intentType === "excalidraw") {
+} else if (intentType === "excalidraw" || excalidrawDiagram) {
 const domain = this.setting.khojUrl.endsWith("/") ? this.setting.khojUrl : `${this.setting.khojUrl}/`;
 const redirectMessage = `Hey, I'm not ready to show you diagrams yet here. But you can view it in ${domain}chat?conversationId=${conversationId}`;
 imageMarkdown = redirectMessage;
+} else if (images && images.length > 0) {
+for (let image of images) {
+if (image.startsWith("https://")) {
+imageMarkdown += `![](${image})\n\n`;
+} else {
+imageMarkdown += `![](data:image/png;base64,${image})\n\n`;
+}
+}
+imageMarkdown += `${message}`;
 }
-if (inferredQueries) {
+if (images?.length === 0 && inferredQueries) {
 imageMarkdown += "\n\n**Inferred Query**:";
 for (let inferredQuery of inferredQueries) {
 imageMarkdown += `\n\n${inferredQuery}`;
@@ -650,19 +668,19 @@
 chatBodyEl.innerHTML = "";
 chatBodyEl.dataset.conversationId = "";
 chatBodyEl.dataset.conversationTitle = "";
 this.userMessages = [];
 this.startingMessage = "Message";
 // Update the placeholder of the chat input
 const chatInput = this.contentEl.querySelector('.khoj-chat-input') as HTMLTextAreaElement;
 if (chatInput) {
 chatInput.placeholder = this.startingMessage;
 }
 this.renderMessage(chatBodyEl, "Hey 👋🏾, what's up?", "khoj");
 }
 async toggleChatSessions(forceShow: boolean = false): Promise<boolean> {
 this.userMessages = []; // clear user previous message history
 let chatBodyEl = this.contentEl.getElementsByClassName("khoj-chat-body")[0] as HTMLElement;
 if (!forceShow && this.contentEl.getElementsByClassName("side-panel")?.length > 0) {
 chatBodyEl.innerHTML = "";
@@ -768,10 +786,10 @@
 let editConversationTitleInputEl = this.contentEl.createEl('input');
 editConversationTitleInputEl.classList.add("conversation-title-input");
 editConversationTitleInputEl.value = conversationTitle;
-editConversationTitleInputEl.addEventListener('click', function(event) {
+editConversationTitleInputEl.addEventListener('click', function (event) {
 event.stopPropagation();
 });
-editConversationTitleInputEl.addEventListener('keydown', function(event) {
+editConversationTitleInputEl.addEventListener('keydown', function (event) {
 if (event.key === "Enter") {
 event.preventDefault();
 editConversationTitleSaveButtonEl.click();
@@ -890,15 +908,17 @@
 chatLog.intent?.type,
 chatLog.intent?.["inferred-queries"],
 chatBodyEl.dataset.conversationId ?? "",
+chatLog.images,
+chatLog.excalidrawDiagram,
 );
 // push the user messages to the chat history
-if(chatLog.by === "you"){
+if (chatLog.by === "you") {
 this.userMessages.push(chatLog.message);
 }
 });
 // Update starting message after loading history
 const modifierKey: string = Platform.isMacOS ? '⌘' : '^';
 this.startingMessage = this.userMessages.length > 0
 ? `(${modifierKey}+↑/↓) for prev messages`
 : "Message";
@@ -922,15 +942,15 @@
 try {
 let jsonChunk = JSON.parse(rawChunk);
 if (!jsonChunk.type)
-jsonChunk = {type: 'message', data: jsonChunk};
+jsonChunk = { type: 'message', data: jsonChunk };
 return jsonChunk;
 } catch (e) {
-return {type: 'message', data: rawChunk};
+return { type: 'message', data: rawChunk };
 }
 } else if (rawChunk.length > 0) {
-return {type: 'message', data: rawChunk};
+return { type: 'message', data: rawChunk };
 }
-return {type: '', data: ''};
+return { type: '', data: '' };
 }
 processMessageChunk(rawChunk: string): void {
@@ -941,6 +961,11 @@
 console.log(`status: ${chunk.data}`);
 const statusMessage = chunk.data;
 this.handleStreamResponse(this.chatMessageState.newResponseTextEl, statusMessage, this.chatMessageState.loadingEllipsis, false);
+} else if (chunk.type === 'generated_assets') {
+const generatedAssets = chunk.data;
+const imageData = this.handleImageResponse(generatedAssets, this.chatMessageState.rawResponse);
+this.chatMessageState.generatedAssets = imageData;
+this.handleStreamResponse(this.chatMessageState.newResponseTextEl, imageData, this.chatMessageState.loadingEllipsis, false);
 } else if (chunk.type === 'start_llm_response') {
 console.log("Started streaming", new Date());
 } else if (chunk.type === 'end_llm_response') {
@@ -963,9 +988,10 @@
 rawResponse: "",
 rawQuery: liveQuery,
 isVoice: false,
+generatedAssets: "",
 };
 } else if (chunk.type === "references") {
-this.chatMessageState.references = {"notes": chunk.data.context, "online": chunk.data.onlineContext};
+this.chatMessageState.references = { "notes": chunk.data.context, "online": chunk.data.onlineContext };
 } else if (chunk.type === 'message') {
 const chunkData = chunk.data;
 if (typeof chunkData === 'object' && chunkData !== null) {
@@ -978,17 +1004,17 @@
 this.handleJsonResponse(jsonData);
 } catch (e) {
 this.chatMessageState.rawResponse += chunkData;
-this.handleStreamResponse(this.chatMessageState.newResponseTextEl, this.chatMessageState.rawResponse, this.chatMessageState.loadingEllipsis);
+this.handleStreamResponse(this.chatMessageState.newResponseTextEl, this.chatMessageState.rawResponse + this.chatMessageState.generatedAssets, this.chatMessageState.loadingEllipsis);
 }
 } else {
 this.chatMessageState.rawResponse += chunkData;
-this.handleStreamResponse(this.chatMessageState.newResponseTextEl, this.chatMessageState.rawResponse, this.chatMessageState.loadingEllipsis);
+this.handleStreamResponse(this.chatMessageState.newResponseTextEl, this.chatMessageState.rawResponse + this.chatMessageState.generatedAssets, this.chatMessageState.loadingEllipsis);
 }
 }
 }
 handleJsonResponse(jsonData: any): void {
-if (jsonData.image || jsonData.detail) {
+if (jsonData.image || jsonData.detail || jsonData.images || jsonData.excalidrawDiagram) {
 this.chatMessageState.rawResponse = this.handleImageResponse(jsonData, this.chatMessageState.rawResponse);
 } else if (jsonData.response) {
 this.chatMessageState.rawResponse = jsonData.response;
@@ -1234,11 +1260,11 @@
 const recordingConfig = { mimeType: 'audio/webm' };
 this.mediaRecorder = new MediaRecorder(stream, recordingConfig);
-this.mediaRecorder.addEventListener("dataavailable", function(event) {
+this.mediaRecorder.addEventListener("dataavailable", function (event) {
 if (event.data.size > 0) audioChunks.push(event.data);
 });
-this.mediaRecorder.addEventListener("stop", async function() {
+this.mediaRecorder.addEventListener("stop", async function () {
 const audioBlob = new Blob(audioChunks, { type: 'audio/webm' });
 await sendToServer(audioBlob);
 });
@@ -1368,7 +1394,22 @@
 if (inferredQuery) {
 rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
 }
+} else if (imageJson.images) {
+// If response has images field, response is a list of generated images.
+imageJson.images.forEach((image: any) => {
+if (image.startsWith("http")) {
+rawResponse += `![generated_image](${image})\n\n`;
+} else {
+rawResponse += `![generated_image](data:image/png;base64,${image})\n\n`;
+}
+});
+} else if (imageJson.excalidrawDiagram) {
+const domain = this.setting.khojUrl.endsWith("/") ? this.setting.khojUrl : `${this.setting.khojUrl}/`;
+const redirectMessage = `Hey, I'm not ready to show you diagrams yet here. But you can view it in ${domain}`;
+rawResponse += redirectMessage;
 }
 // If response has detail field, response is an error message.
 if (imageJson.detail) rawResponse += imageJson.detail;
@@ -1407,7 +1448,7 @@
 referenceExpandButton.classList.add("reference-expand-button");
 referenceExpandButton.innerHTML = numReferences == 1 ? "1 reference" : `${numReferences} references`;
-referenceExpandButton.addEventListener('click', function() {
+referenceExpandButton.addEventListener('click', function () {
 if (referenceSection.classList.contains("collapsed")) {
 referenceSection.classList.remove("collapsed");
 referenceSection.classList.add("expanded");

View file

@@ -82,7 +82,8 @@ If your plugin does not need CSS, delete this file.
 }
 /* color chat bubble by khoj blue */
 .khoj-chat-message-text.khoj {
-border: 1px solid var(--khoj-sun);
+border-top: 1px solid var(--khoj-sun);
+border-radius: 0px;
 margin-left: auto;
 white-space: pre-line;
 }
@@ -104,8 +105,9 @@ If your plugin does not need CSS, delete this file.
 }
 /* color chat bubble by you dark grey */
 .khoj-chat-message-text.you {
-border: 1px solid var(--color-accent);
+color: var(--text-normal);
 margin-right: auto;
+background-color: var(--background-modifier-cover);
 }
 /* add right protrusion to you chat bubble */
 .khoj-chat-message-text.you:after {

View file

@@ -1,3 +1,4 @@
+import { AttachedFileText } from "../components/chatInputArea/chatInputArea";
 import {
 CodeContext,
 Context,
@@ -16,6 +17,12 @@ export interface MessageMetadata {
 turnId: string;
 }
+export interface GeneratedAssetsData {
+images: string[];
+excalidrawDiagram: string;
+files: AttachedFileText[];
+}
 export interface ResponseWithIntent {
 intentType: string;
 response: string;
@@ -84,6 +91,8 @@ export function processMessageChunk(
 if (!currentMessage || !chunk || !chunk.type) return { context, onlineContext, codeContext };
+console.log(`chunk type: ${chunk.type}`);
 if (chunk.type === "status") {
 console.log(`status: ${chunk.data}`);
 const statusMessage = chunk.data as string;
@@ -98,6 +107,20 @@
 } else if (chunk.type === "metadata") {
 const messageMetadata = chunk.data as MessageMetadata;
 currentMessage.turnId = messageMetadata.turnId;
+} else if (chunk.type === "generated_assets") {
+const generatedAssets = chunk.data as GeneratedAssetsData;
+if (generatedAssets.images) {
+currentMessage.generatedImages = generatedAssets.images;
+}
+if (generatedAssets.excalidrawDiagram) {
+currentMessage.generatedExcalidrawDiagram = generatedAssets.excalidrawDiagram;
+}
+if (generatedAssets.files) {
+currentMessage.generatedFiles = generatedAssets.files;
+}
 } else if (chunk.type === "message") {
 const chunkData = chunk.data;
 // Here, handle if the response is a JSON response with an image, but the intentType is excalidraw

View file

@@ -54,6 +54,12 @@ function TrainOfThoughtComponent(props: TrainOfThoughtComponentProps) {
 const lastIndex = props.trainOfThought.length - 1;
 const [collapsed, setCollapsed] = useState(props.completed);
+useEffect(() => {
+if (props.completed) {
+setCollapsed(true);
+}
+}, [props.completed]);
 return (
 <div
 className={`${!collapsed ? styles.trainOfThought + " shadow-sm" : ""}`}
@@ -410,6 +416,9 @@ export default function ChatHistory(props: ChatHistoryProps) {
 "inferred-queries": message.inferredQueries || [],
 },
 conversationId: props.conversationId,
+images: message.generatedImages,
+queryFiles: message.generatedFiles,
+excalidrawDiagram: message.generatedExcalidrawDiagram,
 turnId: messageTurnId,
 }}
 conversationId={props.conversationId}

View file

@@ -77,6 +77,21 @@ div.imageWrapper img {
 border-radius: 8px;
 }
+div.khoj div.imageWrapper img {
+height: 512px;
+}
+div.khoj div.imageWrapper {
+flex: 1 1 auto;
+}
+div.khoj div.imagesContainer {
+display: flex;
+flex-wrap: wrap;
+flex-direction: row;
+overflow-x: hidden;
+}
 div.chatMessageContainer > img {
 width: auto;
 height: auto;
@@ -178,4 +193,9 @@ div.trainOfThoughtElement ul {
 div.youfullHistory {
 max-width: 90%;
 }
+div.khoj div.imageWrapper img {
+width: 100%;
+height: auto;
+}
 }

View file

@@ -163,6 +163,7 @@ export interface SingleChatMessage {
 conversationId: string;
 turnId?: string;
 queryFiles?: AttachedFileText[];
+excalidrawDiagram?: string;
 }
 export interface StreamMessage {
@@ -180,6 +181,10 @@ export interface StreamMessage {
 inferredQueries?: string[];
 turnId?: string;
 queryFiles?: AttachedFileText[];
+excalidrawDiagram?: string;
+generatedFiles?: AttachedFileText[];
+generatedImages?: string[];
+generatedExcalidrawDiagram?: string;
 }
 export interface ChatHistoryData {
@@ -264,6 +269,9 @@ interface ChatMessageProps {
 onDeleteMessage: (turnId?: string) => void;
 conversationId: string;
 turnId?: string;
+generatedImage?: string;
+excalidrawDiagram?: string;
+generatedFiles?: AttachedFileText[];
 }
 interface TrainOfThoughtProps {
@@ -389,9 +397,8 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
 // Prepare initial message for rendering
 let message = props.chatMessage.message;
-if (props.chatMessage.intent && props.chatMessage.intent.type == "excalidraw") {
-message = props.chatMessage.intent["inferred-queries"][0];
-setExcalidrawData(props.chatMessage.message);
+if (props.chatMessage.excalidrawDiagram) {
+setExcalidrawData(props.chatMessage.excalidrawDiagram);
 }
 // Replace LaTeX delimiters with placeholders
@@ -401,27 +408,6 @@
 .replace(/\\\[/g, "LEFTBRACKET")
 .replace(/\\\]/g, "RIGHTBRACKET");
-const intentTypeHandlers = {
-"text-to-image": (msg: string) => `![generated image](data:image/png;base64,${msg})`,
-"text-to-image2": (msg: string) => `![generated image](${msg})`,
-"text-to-image-v3": (msg: string) =>
-`![generated image](data:image/webp;base64,${msg})`,
-excalidraw: (msg: string) => msg,
-};
-// Handle intent-specific rendering
-if (props.chatMessage.intent) {
-const { type, "inferred-queries": inferredQueries } = props.chatMessage.intent;
-if (type in intentTypeHandlers) {
-message = intentTypeHandlers[type as keyof typeof intentTypeHandlers](message);
-}
-if (type.includes("text-to-image") && inferredQueries?.length > 0) {
-message += `\n\n${inferredQueries[0]}`;
-}
-}
 // Replace file links with base64 data
 message = renderCodeGenImageInline(message, props.chatMessage.codeContext);

View file

@@ -303,17 +303,10 @@ class ConversationAdmin(unfold_admin.ModelAdmin):
 modified_log = conversation.conversation_log
 chat_log = modified_log.get("chat", [])
 for idx, log in enumerate(chat_log):
-    if (
-        log["by"] == "khoj"
-        and log["intent"]
-        and log["intent"]["type"]
-        and (
-            log["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE.value
-            or log["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE_V3.value
-        )
-    ):
-        log["message"] = "inline image redacted for space"
+    if log["by"] == "khoj" and log["images"]:
+        log["images"] = ["inline image redacted for space"]
     chat_log[idx] = log
 modified_log["chat"] = chat_log
 writer.writerow(

View file

@@ -0,0 +1,85 @@
+# Made manually by sabaimran for use by Django 5.0.9 on 2024-12-01 16:59
+
+from django.db import migrations, models
+
+# This script was written alongside the change that added Pydantic validation to the Conversation conversation_log field.
+
+
+def migrate_generated_assets(apps, schema_editor):
+    Conversation = apps.get_model("database", "Conversation")
+
+    # Process conversations in chunks
+    for conversation in Conversation.objects.iterator():
+        try:
+            meta_log = conversation.conversation_log
+            modified = False
+
+            for chat in meta_log.get("chat", []):
+                intent_type = chat.get("intent", {}).get("type")
+                if intent_type and chat["by"] == "khoj":
+                    if intent_type and "text-to-image" in intent_type:
+                        # Migrate the generated image to the new format
+                        chat["images"] = [chat.get("message")]
+                        chat["message"] = chat["intent"]["inferred-queries"][0]
+                        modified = True
+                    if intent_type and "excalidraw" in intent_type:
+                        # Migrate the generated excalidraw to the new format
+                        chat["excalidrawDiagram"] = chat.get("message")
+                        chat["message"] = chat["intent"]["inferred-queries"][0]
+                        modified = True
+
+            # Only save if changes were made
+            if modified:
+                conversation.conversation_log = meta_log
+                conversation.save()
+        except Exception as e:
+            print(f"Error processing conversation {conversation.id}: {str(e)}")
+            continue
+
+
+def reverse_migration(apps, schema_editor):
+    Conversation = apps.get_model("database", "Conversation")
+
+    # Process conversations in chunks
+    for conversation in Conversation.objects.iterator():
+        try:
+            meta_log = conversation.conversation_log
+            modified = False
+
+            for chat in meta_log.get("chat", []):
+                intent_type = chat.get("intent", {}).get("type")
+                if intent_type and chat["by"] == "khoj":
+                    if intent_type and "text-to-image" in intent_type:
+                        # Migrate the generated image back to the old format
+                        chat["message"] = chat.get("images", [])[0]
+                        chat.pop("images", None)
+                        modified = True
+                    if intent_type and "excalidraw" in intent_type:
+                        # Migrate the generated excalidraw back to the old format
+                        chat["message"] = chat.get("excalidrawDiagram")
+                        chat.pop("excalidrawDiagram", None)
+                        modified = True
+
+            # Only save if changes were made
+            if modified:
+                conversation.conversation_log = meta_log
+                conversation.save()
+        except Exception as e:
+            print(f"Error processing conversation {conversation.id}: {str(e)}")
+            continue
+
+
+class Migration(migrations.Migration):
+    dependencies = [
+        ("database", "0074_alter_conversation_title"),
+    ]
+
+    operations = [
+        migrations.RunPython(migrate_generated_assets, reverse_migration),
+    ]

View file

@@ -1,7 +1,9 @@
+import logging
 import os
 import re
 import uuid
 from random import choice
+from typing import Dict, List, Optional, Union
 from django.contrib.auth.models import AbstractUser
 from django.contrib.postgres.fields import ArrayField
@@ -11,9 +13,109 @@ from django.db.models.signals import pre_save
 from django.dispatch import receiver
 from pgvector.django import VectorField
 from phonenumber_field.modelfields import PhoneNumberField
-class BaseModel(models.Model):
+from pydantic import BaseModel as PydanticBaseModel
+from pydantic import Field
+
+logger = logging.getLogger(__name__)
+
+
+# Pydantic models for type Chat Message validation
+class Context(PydanticBaseModel):
+    compiled: str
+    file: str
+
+
+class CodeContextFile(PydanticBaseModel):
+    filename: str
+    b64_data: str
+
+
+class CodeContextResult(PydanticBaseModel):
+    success: bool
+    output_files: List[CodeContextFile]
+    std_out: str
+    std_err: str
+    code_runtime: int
+
+
+class CodeContextData(PydanticBaseModel):
+    code: str
+    result: Optional[CodeContextResult] = None
+
+
+class WebPage(PydanticBaseModel):
+    link: str
+    query: Optional[str] = None
+    snippet: str
+
+
+class AnswerBox(PydanticBaseModel):
+    link: Optional[str] = None
+    snippet: Optional[str] = None
+    title: str
+    snippetHighlighted: Optional[List[str]] = None
+
+
+class PeopleAlsoAsk(PydanticBaseModel):
+    link: Optional[str] = None
+    question: Optional[str] = None
+    snippet: Optional[str] = None
+    title: str
+
+
+class KnowledgeGraph(PydanticBaseModel):
+    attributes: Optional[Dict[str, str]] = None
+    description: Optional[str] = None
+    descriptionLink: Optional[str] = None
+    descriptionSource: Optional[str] = None
+    imageUrl: Optional[str] = None
+    title: str
+    type: Optional[str] = None
+
+
+class OrganicContext(PydanticBaseModel):
+    snippet: str
+    title: str
+    link: str
+
+
+class OnlineContext(PydanticBaseModel):
+    webpages: Optional[Union[WebPage, List[WebPage]]] = None
+    answerBox: Optional[AnswerBox] = None
+    peopleAlsoAsk: Optional[List[PeopleAlsoAsk]] = None
+    knowledgeGraph: Optional[KnowledgeGraph] = None
+    organicContext: Optional[List[OrganicContext]] = None
+
+
+class Intent(PydanticBaseModel):
+    type: str
+    query: str
+    memory_type: str = Field(alias="memory-type")
+    inferred_queries: Optional[List[str]] = Field(default=None, alias="inferred-queries")
+
+
+class TrainOfThought(PydanticBaseModel):
+    type: str
+    data: str
+
+
+class ChatMessage(PydanticBaseModel):
+    message: str
+    trainOfThought: List[TrainOfThought] = []
+    context: List[Context] = []
+    onlineContext: Dict[str, OnlineContext] = {}
+    codeContext: Dict[str, CodeContextData] = {}
+    created: str
+    images: Optional[List[str]] = None
+    queryFiles: Optional[List[Dict]] = None
+    excalidrawDiagram: Optional[List[Dict]] = None
+    by: str
+    turnId: Optional[str] = None
+    intent: Optional[Intent] = None
+    automationId: Optional[str] = None
+
+
+class DbBaseModel(models.Model):
     created_at = models.DateTimeField(auto_now_add=True)
     updated_at = models.DateTimeField(auto_now=True)
@@ -21,7 +123,7 @@ class BaseModel(models.Model):
         abstract = True
-class ClientApplication(BaseModel):
+class ClientApplication(DbBaseModel):
     name = models.CharField(max_length=200)
     client_id = models.CharField(max_length=200)
     client_secret = models.CharField(max_length=200)
@@ -67,7 +169,7 @@ class KhojApiUser(models.Model):
     accessed_at = models.DateTimeField(null=True, default=None)
-class Subscription(BaseModel):
+class Subscription(DbBaseModel):
     class Type(models.TextChoices):
         TRIAL = "trial"
         STANDARD = "standard"
@@ -79,13 +181,13 @@ class Subscription(BaseModel):
     enabled_trial_at = models.DateTimeField(null=True, default=None, blank=True)
-class OpenAIProcessorConversationConfig(BaseModel):
+class OpenAIProcessorConversationConfig(DbBaseModel):
     name = models.CharField(max_length=200)
    api_key = models.CharField(max_length=200)
     api_base_url = models.URLField(max_length=200, default=None, blank=True, null=True)
-class ChatModelOptions(BaseModel):
+class ChatModelOptions(DbBaseModel):
     class ModelType(models.TextChoices):
         OPENAI = "openai"
         OFFLINE = "offline"
@@ -103,12 +205,12 @@ class ChatModelOptions(BaseModel):
     )
-class VoiceModelOption(BaseModel):
+class VoiceModelOption(DbBaseModel):
     model_id = models.CharField(max_length=200)
     name = models.CharField(max_length=200)
-class Agent(BaseModel):
+class Agent(DbBaseModel):
     class StyleColorTypes(models.TextChoices):
         BLUE = "blue"
         GREEN = "green"
@@ -208,7 +310,7 @@ class Agent(BaseModel):
         super().save(*args, **kwargs)
-class ProcessLock(BaseModel):
+class ProcessLock(DbBaseModel):
     class Operation(models.TextChoices):
         INDEX_CONTENT = "index_content"
         SCHEDULED_JOB = "scheduled_job"
@@ -231,24 +333,24 @@ def verify_agent(sender, instance, **kwargs):
         raise ValidationError(f"A private Agent with the name {instance.name} already exists.")
-class NotionConfig(BaseModel):
+class NotionConfig(DbBaseModel):
     token = models.CharField(max_length=200)
     user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
-class GithubConfig(BaseModel):
+class GithubConfig(DbBaseModel):
     pat_token = models.CharField(max_length=200)
     user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
-class GithubRepoConfig(BaseModel):
+class GithubRepoConfig(DbBaseModel):
     name = models.CharField(max_length=200)
     owner = models.CharField(max_length=200)
     branch = models.CharField(max_length=200)
     github_config = models.ForeignKey(GithubConfig, on_delete=models.CASCADE, related_name="githubrepoconfig")
-class WebScraper(BaseModel):
+class WebScraper(DbBaseModel):
     class WebScraperType(models.TextChoices):
         FIRECRAWL = "Firecrawl"
         OLOSTEP = "Olostep"
@@ -321,7 +423,7 @@ class WebScraper(BaseModel):
         super().save(*args, **kwargs)
-class ServerChatSettings(BaseModel):
+class ServerChatSettings(DbBaseModel):
     chat_default = models.ForeignKey(
         ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_default"
     )
@@ -333,35 +435,35 @@ class ServerChatSettings(BaseModel):
     )
-class LocalOrgConfig(BaseModel):
+class LocalOrgConfig(DbBaseModel):
     input_files = models.JSONField(default=list, null=True)
     input_filter = models.JSONField(default=list, null=True)
     index_heading_entries = models.BooleanField(default=False)
     user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
-class LocalMarkdownConfig(BaseModel):
+class LocalMarkdownConfig(DbBaseModel):
     input_files = models.JSONField(default=list, null=True)
     input_filter = models.JSONField(default=list, null=True)
     index_heading_entries = models.BooleanField(default=False)
     user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
-class LocalPdfConfig(BaseModel):
+class LocalPdfConfig(DbBaseModel):
     input_files = models.JSONField(default=list, null=True)
     input_filter = models.JSONField(default=list, null=True)
     index_heading_entries = models.BooleanField(default=False)
     user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
-class LocalPlaintextConfig(BaseModel):
+class LocalPlaintextConfig(DbBaseModel):
     input_files = models.JSONField(default=list, null=True)
     input_filter = models.JSONField(default=list, null=True)
     index_heading_entries = models.BooleanField(default=False)
     user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
-class SearchModelConfig(BaseModel):
+class SearchModelConfig(DbBaseModel):
     class ModelType(models.TextChoices):
         TEXT = "text"
@@ -393,7 +495,7 @@ class SearchModelConfig(BaseModel):
     bi_encoder_confidence_threshold = models.FloatField(default=0.18)
-class TextToImageModelConfig(BaseModel):
+class TextToImageModelConfig(DbBaseModel):
     class ModelType(models.TextChoices):
         OPENAI = "openai"
         STABILITYAI = "stability-ai"
@@ -430,7 +532,7 @@ class TextToImageModelConfig(BaseModel):
         super().save(*args, **kwargs)
-class SpeechToTextModelOptions(BaseModel):
+class SpeechToTextModelOptions(DbBaseModel):
     class ModelType(models.TextChoices):
         OPENAI = "openai"
         OFFLINE = "offline"
@@ -439,22 +541,22 @@ class SpeechToTextModelOptions(BaseModel):
     model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
-class UserConversationConfig(BaseModel):
+class UserConversationConfig(DbBaseModel):
     user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
     setting = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True)
-class UserVoiceModelConfig(BaseModel):
+class UserVoiceModelConfig(DbBaseModel):
     user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
     setting = models.ForeignKey(VoiceModelOption, on_delete=models.CASCADE, default=None, null=True, blank=True)
-class UserTextToImageModelConfig(BaseModel):
+class UserTextToImageModelConfig(DbBaseModel):
     user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
     setting = models.ForeignKey(TextToImageModelConfig, on_delete=models.CASCADE)
-class Conversation(BaseModel):
+class Conversation(DbBaseModel):
     user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
     conversation_log = models.JSONField(default=dict)
     client = models.ForeignKey(ClientApplication, on_delete=models.CASCADE, default=None, null=True, blank=True)
@@ -468,8 +570,39 @@ class Conversation(BaseModel):
     file_filters = models.JSONField(default=list)
     id = models.UUIDField(default=uuid.uuid4, editable=False, unique=True, primary_key=True, db_index=True)
+
+    def clean(self):
+        # Validate conversation_log structure
+        try:
+            messages = self.conversation_log.get("chat", [])
+            for msg in messages:
+                ChatMessage.model_validate(msg)
+        except Exception as e:
+            raise ValidationError(f"Invalid conversation_log format: {str(e)}")
+
+    def save(self, *args, **kwargs):
+        self.clean()
+        super().save(*args, **kwargs)
+
+    @property
+    def messages(self) -> List[ChatMessage]:
+        """Type-hinted accessor for conversation messages"""
+        validated_messages = []
+        for msg in self.conversation_log.get("chat", []):
+            try:
+                # Clean up inferred queries if they contain None
+                if msg.get("intent") and msg["intent"].get("inferred-queries"):
+                    msg["intent"]["inferred-queries"] = [
+                        q for q in msg["intent"]["inferred-queries"] if q is not None and isinstance(q, str)
+                    ]
+                msg["message"] = str(msg.get("message", ""))
+                validated_messages.append(ChatMessage.model_validate(msg))
+            except ValidationError as e:
+                logger.warning(f"Skipping invalid message in conversation: {e}")
+                continue
+        return validated_messages
+
+
-class PublicConversation(BaseModel):
+class PublicConversation(DbBaseModel):
     source_owner = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
     conversation_log = models.JSONField(default=dict)
     slug = models.CharField(max_length=200, default=None, null=True, blank=True)
@@ -499,12 +632,12 @@ def verify_public_conversation(sender, instance, **kwargs):
     instance.slug = slug
-class ReflectiveQuestion(BaseModel):
+class ReflectiveQuestion(DbBaseModel):
     question = models.CharField(max_length=500)
     user = models.ForeignKey(KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True)
-class Entry(BaseModel):
+class Entry(DbBaseModel):
     class EntryType(models.TextChoices):
         IMAGE = "image"
         PDF = "pdf"
@@ -541,7 +674,7 @@ class Entry(BaseModel):
         raise ValidationError("An Entry cannot be associated with both a user and an agent.")
-class FileObject(BaseModel):
+class FileObject(DbBaseModel):
     # Same as Entry but raw will be a much larger string
     file_name = models.CharField(max_length=400, default=None, null=True, blank=True)
     raw_text = models.TextField()
@@ -549,7 +682,7 @@ class FileObject(BaseModel):
     agent = models.ForeignKey(Agent, on_delete=models.CASCADE, default=None, null=True, blank=True)
-class EntryDates(BaseModel):
+class EntryDates(DbBaseModel):
     date = models.DateField()
     entry = models.ForeignKey(Entry, on_delete=models.CASCADE, related_name="embeddings_dates")
@@ -559,12 +692,12 @@ class EntryDates(BaseModel):
     ]
-class UserRequests(BaseModel):
+class UserRequests(DbBaseModel):
     user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
     slug = models.CharField(max_length=200)
-class DataStore(BaseModel):
+class DataStore(DbBaseModel):
     key = models.CharField(max_length=200, unique=True)
     value = models.JSONField(default=dict)
     private = models.BooleanField(default=False)

View file

@@ -1,6 +1,6 @@
 import logging
 from datetime import datetime, timedelta
-from typing import Dict, Optional
+from typing import Dict, List, Optional
 import pyjson5
 from langchain.schema import ChatMessage
@@ -23,7 +23,7 @@ from khoj.utils.helpers import (
     is_none_or_empty,
     truncate_code_context,
 )
-from khoj.utils.rawconfig import LocationData
+from khoj.utils.rawconfig import FileAttachment, LocationData
 from khoj.utils.yaml import yaml_dump
 logger = logging.getLogger(__name__)
@@ -55,7 +55,7 @@ def extract_questions_anthropic(
         [
             f'User: {chat["intent"]["query"]}\nAssistant: {{"queries": {chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}}}\nA: {chat["message"]}\n\n'
             for chat in conversation_log.get("chat", [])[-4:]
-            if chat["by"] == "khoj" and "text-to-image" not in chat["intent"].get("type")
+            if chat["by"] == "khoj"
         ]
     )
@@ -157,6 +157,10 @@ def converse_anthropic(
     query_images: Optional[list[str]] = None,
     vision_available: bool = False,
     query_files: str = None,
+    generated_images: Optional[list[str]] = None,
+    generated_files: List[FileAttachment] = None,
+    generated_excalidraw_diagram: Optional[str] = None,
+    program_execution_context: Optional[List[str]] = None,
     tracer: dict = {},
 ):
     """
@@ -217,6 +221,10 @@
         vision_enabled=vision_available,
         model_type=ChatModelOptions.ModelType.ANTHROPIC,
         query_files=query_files,
+        generated_excalidraw_diagram=generated_excalidraw_diagram,
+        generated_files=generated_files,
+        generated_images=generated_images,
+        program_execution_context=program_execution_context,
     )
     messages, system_prompt = format_messages_for_anthropic(messages, system_prompt)

View file

@@ -1,6 +1,6 @@
 import logging
 from datetime import datetime, timedelta
-from typing import Dict, Optional
+from typing import Dict, List, Optional
 import pyjson5
 from langchain.schema import ChatMessage
@@ -23,7 +23,7 @@ from khoj.utils.helpers import (
     is_none_or_empty,
     truncate_code_context,
 )
-from khoj.utils.rawconfig import LocationData
+from khoj.utils.rawconfig import FileAttachment, LocationData
 from khoj.utils.yaml import yaml_dump
 logger = logging.getLogger(__name__)
@@ -56,7 +56,7 @@ def extract_questions_gemini(
         [
             f'User: {chat["intent"]["query"]}\nAssistant: {{"queries": {chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}}}\nA: {chat["message"]}\n\n'
             for chat in conversation_log.get("chat", [])[-4:]
-            if chat["by"] == "khoj" and "text-to-image" not in chat["intent"].get("type")
+            if chat["by"] == "khoj"
         ]
     )
@@ -167,6 +167,10 @@ def converse_gemini(
     query_images: Optional[list[str]] = None,
     vision_available: bool = False,
    query_files: str = None,
+    generated_images: Optional[list[str]] = None,
+    generated_files: List[FileAttachment] = None,
+    generated_excalidraw_diagram: Optional[str] = None,
+    program_execution_context: List[str] = None,
     tracer={},
 ):
     """
@@ -228,6 +232,10 @@
         vision_enabled=vision_available,
         model_type=ChatModelOptions.ModelType.GOOGLE,
         query_files=query_files,
+        generated_excalidraw_diagram=generated_excalidraw_diagram,
+        generated_files=generated_files,
+        generated_images=generated_images,
+        program_execution_context=program_execution_context,
     )
     messages, system_prompt = format_messages_for_gemini(messages, system_prompt)

View file

@@ -28,7 +28,7 @@ from khoj.utils.helpers import (
    is_promptrace_enabled,
    truncate_code_context,
 )
-from khoj.utils.rawconfig import LocationData
+from khoj.utils.rawconfig import FileAttachment, LocationData
 from khoj.utils.yaml import yaml_dump
 logger = logging.getLogger(__name__)
@@ -69,7 +69,7 @@ def extract_questions_offline(
    if use_history:
        for chat in conversation_log.get("chat", [])[-4:]:
-           if chat["by"] == "khoj" and "text-to-image" not in chat["intent"].get("type"):
+           if chat["by"] == "khoj":
                chat_history += f"Q: {chat['intent']['query']}\n"
                chat_history += f"Khoj: {chat['message']}\n\n"
@@ -164,6 +164,8 @@ def converse_offline(
    user_name: str = None,
    agent: Agent = None,
    query_files: str = None,
+   generated_files: List[FileAttachment] = None,
+   additional_context: List[str] = None,
    tracer: dict = {},
 ) -> Union[ThreadedGenerator, Iterator[str]]:
    """
@@ -231,6 +233,8 @@ def converse_offline(
        tokenizer_name=tokenizer_name,
        model_type=ChatModelOptions.ModelType.OFFLINE,
        query_files=query_files,
+       generated_files=generated_files,
+       program_execution_context=additional_context,
    )
    logger.debug(f"Conversation Context for {model}: {messages_to_print(messages)}")

View file

@@ -1,6 +1,6 @@
 import logging
 from datetime import datetime, timedelta
-from typing import Dict, Optional
+from typing import Dict, List, Optional
 import pyjson5
 from langchain.schema import ChatMessage
@@ -22,7 +22,7 @@ from khoj.utils.helpers import (
    is_none_or_empty,
    truncate_code_context,
 )
-from khoj.utils.rawconfig import LocationData
+from khoj.utils.rawconfig import FileAttachment, LocationData
 from khoj.utils.yaml import yaml_dump
 logger = logging.getLogger(__name__)
@@ -157,6 +157,10 @@ def converse(
    query_images: Optional[list[str]] = None,
    vision_available: bool = False,
    query_files: str = None,
+   generated_images: Optional[list[str]] = None,
+   generated_files: List[FileAttachment] = None,
+   generated_excalidraw_diagram: Optional[str] = None,
+   program_execution_context: List[str] = None,
    tracer: dict = {},
 ):
    """
@@ -219,6 +223,10 @@ def converse(
        vision_enabled=vision_available,
        model_type=ChatModelOptions.ModelType.OPENAI,
        query_files=query_files,
+       generated_excalidraw_diagram=generated_excalidraw_diagram,
+       generated_files=generated_files,
+       generated_images=generated_images,
+       program_execution_context=program_execution_context,
    )
    logger.debug(f"Conversation Context for GPT: {messages_to_print(messages)}")

View file

@@ -178,6 +178,18 @@ Improved Prompt:
 """.strip()
 )
+
+generated_image_attachment = PromptTemplate.from_template(
+    f"""
+Here is the image you generated based on my query. You can follow-up with a general response to my query. Limit to 1-2 sentences.
+""".strip()
+)
+
+generated_diagram_attachment = PromptTemplate.from_template(
+    f"""
+I've successfully created a diagram based on the user's query. The diagram will automatically be shared with the user. I can follow-up with a general response or summary. Limit to 1-2 sentences.
+""".strip()
+)
 ## Diagram Generation
 ## --
@@ -1029,6 +1041,12 @@ A:
 """.strip()
 )
+additional_program_context = PromptTemplate.from_template(
+    """
+Here are some additional results from the query execution:
+{context}
+""".strip()
+)
 personality_prompt_safety_expert_lax = PromptTemplate.from_template(
    """

View file

@@ -154,7 +154,7 @@ def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="A
            chat_history += f'{agent_name}: {{"queries": {chat["intent"].get("inferred-queries")}}}\n'
            chat_history += f"{agent_name}: {chat['message']}\n\n"
-       elif chat["by"] == "khoj" and ("text-to-image" in chat["intent"].get("type")):
+       elif chat["by"] == "khoj" and chat.get("images"):
            chat_history += f"User: {chat['intent']['query']}\n"
            chat_history += f"{agent_name}: [generated image redacted for space]\n"
        elif chat["by"] == "khoj" and ("excalidraw" in chat["intent"].get("type")):
@@ -213,6 +213,7 @@ class ChatEvent(Enum):
    END_LLM_RESPONSE = "end_llm_response"
    MESSAGE = "message"
    REFERENCES = "references"
+   GENERATED_ASSETS = "generated_assets"
    STATUS = "status"
    METADATA = "metadata"
    USAGE = "usage"
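GENERATED_ASSETS is just another tagged chunk in the response stream. A hypothetical serializer sketch (the real stream framing lives in the chat API) showing the shape a client would receive:

```python
import json
from enum import Enum

class ChatEvent(Enum):
    MESSAGE = "message"
    REFERENCES = "references"
    GENERATED_ASSETS = "generated_assets"
    STATUS = "status"

def format_event(event: ChatEvent, data) -> str:
    # Hypothetical framing: each chunk carries a type tag plus a payload.
    return json.dumps({"type": event.value, "data": data})

print(format_event(ChatEvent.GENERATED_ASSETS, {"images": ["https://example.com/a.webp"]}))
```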
@@ -225,7 +226,6 @@ def message_to_log(
    user_message_metadata={},
    khoj_message_metadata={},
    conversation_log=[],
-   train_of_thought=[],
 ):
    """Create json logs from messages, metadata for conversation log"""
    default_khoj_message_metadata = {
@@ -234,6 +234,10 @@ def message_to_log(
    }
    khoj_response_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+   # Filter out any fields that are set to None
+   user_message_metadata = {k: v for k, v in user_message_metadata.items() if v is not None}
+   khoj_message_metadata = {k: v for k, v in khoj_message_metadata.items() if v is not None}
+
    # Create json log from Human's message
    human_log = merge_dicts({"message": user_message, "by": "you"}, user_message_metadata)
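The new filtering step is a plain dict comprehension; a worked example for a turn that attached no images:

```python
turn_id = "7f3c"
user_message_metadata = {"created": "2024-12-08 14:19:05", "images": None, "turnId": turn_id}

# Drop keys whose value is None so optional fields never reach the saved log.
user_message_metadata = {k: v for k, v in user_message_metadata.items() if v is not None}
assert user_message_metadata == {"created": "2024-12-08 14:19:05", "turnId": turn_id}
```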
@@ -261,31 +265,41 @@ def save_to_conversation_log(
    automation_id: str = None,
    query_images: List[str] = None,
    raw_query_files: List[FileAttachment] = [],
+   generated_images: List[str] = [],
+   raw_generated_files: List[FileAttachment] = [],
+   generated_excalidraw_diagram: str = None,
    train_of_thought: List[Any] = [],
    tracer: Dict[str, Any] = {},
 ):
    user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    turn_id = tracer.get("mid") or str(uuid.uuid4())
+
+   user_message_metadata = {"created": user_message_time, "images": query_images, "turnId": turn_id}
+
+   if raw_query_files and len(raw_query_files) > 0:
+       user_message_metadata["queryFiles"] = [file.model_dump(mode="json") for file in raw_query_files]
+
+   khoj_message_metadata = {
+       "context": compiled_references,
+       "intent": {"inferred-queries": inferred_queries, "type": intent_type},
+       "onlineContext": online_results,
+       "codeContext": code_results,
+       "automationId": automation_id,
+       "trainOfThought": train_of_thought,
+       "turnId": turn_id,
+       "images": generated_images,
+       "queryFiles": [file.model_dump(mode="json") for file in raw_generated_files],
+   }
+
+   if generated_excalidraw_diagram:
+       khoj_message_metadata["excalidrawDiagram"] = generated_excalidraw_diagram
+
    updated_conversation = message_to_log(
        user_message=q,
        chat_response=chat_response,
-       user_message_metadata={
-           "created": user_message_time,
-           "images": query_images,
-           "turnId": turn_id,
-           "queryFiles": [file.model_dump(mode="json") for file in raw_query_files],
-       },
-       khoj_message_metadata={
-           "context": compiled_references,
-           "intent": {"inferred-queries": inferred_queries, "type": intent_type},
-           "onlineContext": online_results,
-           "codeContext": code_results,
-           "automationId": automation_id,
-           "trainOfThought": train_of_thought,
-           "turnId": turn_id,
-       },
+       user_message_metadata=user_message_metadata,
+       khoj_message_metadata=khoj_message_metadata,
        conversation_log=meta_log.get("chat", []),
-       train_of_thought=train_of_thought,
    )
    ConversationAdapters.save_conversation(
        user,
@@ -303,13 +317,13 @@ def save_to_conversation_log(
 Saved Conversation Turn
 You ({user.username}): "{q}"
-Khoj: "{inferred_queries if ("text-to-image" in intent_type) else chat_response}"
+Khoj: "{chat_response}"
 """.strip()
    )
 def construct_structured_message(
-   message: str, images: list[str], model_type: str, vision_enabled: bool, attached_file_context: str
+   message: str, images: list[str], model_type: str, vision_enabled: bool, attached_file_context: str = None
 ):
    """
    Format messages into appropriate multimedia format for supported chat model types
@@ -327,7 +341,8 @@ def construct_structured_message(
            constructed_messages.append({"type": "text", "text": attached_file_context})
        if vision_enabled and images:
            for image in images:
-               constructed_messages.append({"type": "image_url", "image_url": {"url": image}})
+               if image.startswith("https://"):
+                   constructed_messages.append({"type": "image_url", "image_url": {"url": image}})
        return constructed_messages
    if not is_none_or_empty(attached_file_context):
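The new startswith guard means only hosted image URLs are replayed to vision models; base64 payloads from history are dropped. A trimmed, standalone sketch of just the vision branch (the real helper also dispatches on model_type and attached file context):

```python
def construct_structured_message_sketch(message: str, images: list, vision_enabled: bool) -> list:
    constructed_messages = [{"type": "text", "text": message}]
    if vision_enabled and images:
        for image in images:
            # Only hosted images pass through; inline base64 blobs replayed
            # from history are skipped to keep the request small and valid.
            if image.startswith("https://"):
                constructed_messages.append({"type": "image_url", "image_url": {"url": image}})
    return constructed_messages

print(construct_structured_message_sketch(
    "What is in this photo?",
    ["https://example.com/a.webp", "data:image/webp;base64,AAAA"],
    vision_enabled=True,
))
```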
@@ -365,6 +380,10 @@ def generate_chatml_messages_with_context(
    model_type="",
    context_message="",
    query_files: str = None,
+   generated_images: Optional[list[str]] = None,
+   generated_files: List[FileAttachment] = None,
+   generated_excalidraw_diagram: str = None,
+   program_execution_context: List[str] = [],
 ):
    """Generate chat messages with appropriate context from previous conversation to send to the chat model"""
    # Set max prompt size from user config or based on pre-configured for model and machine specs
@@ -384,6 +403,7 @@ def generate_chatml_messages_with_context(
        message_attached_files = ""
        chat_message = chat.get("message")
+       role = "user" if chat["by"] == "you" else "assistant"
        if chat["by"] == "khoj" and "excalidraw" in chat["intent"].get("type", ""):
            chat_message = chat["intent"].get("inferred-queries")[0]
@@ -404,7 +424,7 @@ def generate_chatml_messages_with_context(
                query_files_dict[file["name"]] = file["content"]
            message_attached_files = gather_raw_query_files(query_files_dict)
-           chatml_messages.append(ChatMessage(content=message_attached_files, role="user"))
+           chatml_messages.append(ChatMessage(content=message_attached_files, role=role))
        if not is_none_or_empty(chat.get("onlineContext")):
            message_context += f"{prompts.online_search_conversation.format(online_results=chat.get('onlineContext'))}"
@@ -413,10 +433,20 @@ def generate_chatml_messages_with_context(
            reconstructed_context_message = ChatMessage(content=message_context, role="user")
            chatml_messages.insert(0, reconstructed_context_message)
-       role = "user" if chat["by"] == "you" else "assistant"
-       message_content = construct_structured_message(
-           chat_message, chat.get("images"), model_type, vision_enabled, attached_file_context=query_files
-       )
+       if chat.get("images"):
+           if role == "assistant":
+               # Issue: the assistant role cannot accept an image as a message content, so send it in a separate user message.
+               file_attachment_message = construct_structured_message(
+                   message=prompts.generated_image_attachment.format(),
+                   images=chat.get("images"),
+                   model_type=model_type,
+                   vision_enabled=vision_enabled,
+               )
+               chatml_messages.append(ChatMessage(content=file_attachment_message, role="user"))
+       else:
+           message_content = construct_structured_message(
+               chat_message, chat.get("images"), model_type, vision_enabled
+           )
        reconstructed_message = ChatMessage(content=message_content, role=role)
        chatml_messages.insert(0, reconstructed_message)
@@ -425,6 +455,7 @@ def generate_chatml_messages_with_context(
            break
    messages = []
    if not is_none_or_empty(user_message):
        messages.append(
            ChatMessage(
@@ -437,6 +468,31 @@ def generate_chatml_messages_with_context(
    if not is_none_or_empty(context_message):
        messages.append(ChatMessage(content=context_message, role="user"))
+   if generated_images:
+       messages.append(
+           ChatMessage(
+               content=construct_structured_message(
+                   prompts.generated_image_attachment.format(), generated_images, model_type, vision_enabled
+               ),
+               role="user",
+           )
+       )
+
+   if generated_files:
+       message_attached_files = gather_raw_query_files({file.name: file.content for file in generated_files})
+       messages.append(ChatMessage(content=message_attached_files, role="assistant"))
+
+   if generated_excalidraw_diagram:
+       messages.append(ChatMessage(content=prompts.generated_diagram_attachment.format(), role="assistant"))
+
+   if program_execution_context:
+       messages.append(
+           ChatMessage(
+               content=prompts.additional_program_context.format(context="\n".join(program_execution_context)),
+               role="assistant",
+           )
+       )
+
    if len(chatml_messages) > 0:
        messages += chatml_messages
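The four trailing blocks turn generated assets into ordinary chat messages appended ahead of the history. A condensed sketch using langchain's ChatMessage (which the file already imports), with placeholder asset values:

```python
from langchain.schema import ChatMessage

messages = []
generated_excalidraw_diagram = '{"type": "excalidraw", "elements": []}'  # placeholder
program_execution_context = ["Failed to run code"]

# Assets become assistant-side context, so the final text turn can refer to
# them instead of ending the response; wording mirrors the prompts above.
if generated_excalidraw_diagram:
    messages.append(ChatMessage(content="I've successfully created a diagram based on the user's query.", role="assistant"))
if program_execution_context:
    messages.append(
        ChatMessage(
            content="Here are some additional results from the query execution:\n" + "\n".join(program_execution_context),
            role="assistant",
        )
    )
print([m.role for m in messages])
```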

View file

@@ -12,7 +12,7 @@ from khoj.database.models import Agent, KhojUser, TextToImageModelConfig
 from khoj.routers.helpers import ChatEvent, generate_better_image_prompt
 from khoj.routers.storage import upload_image
 from khoj.utils import state
-from khoj.utils.helpers import ImageIntentType, convert_image_to_webp, timer
+from khoj.utils.helpers import convert_image_to_webp, timer
 from khoj.utils.rawconfig import LocationData
 logger = logging.getLogger(__name__)
@@ -34,14 +34,13 @@ async def text_to_image(
    status_code = 200
    image = None
    image_url = None
-   intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
    text_to_image_config = await ConversationAdapters.aget_user_text_to_image_model(user)
    if not text_to_image_config:
        # If the user has not configured a text to image model, return an unsupported on server error
        status_code = 501
        message = "Failed to generate image. Setup image generation on the server."
-       yield image_url or image, status_code, message, intent_type.value
+       yield image_url or image, status_code, message
        return
    text2image_model = text_to_image_config.model_name
@@ -50,8 +49,8 @@ async def text_to_image(
        if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder"]:
            chat_history += f"Q: {chat['intent']['query']}\n"
            chat_history += f"A: {chat['message']}\n"
-       elif chat["by"] == "khoj" and "text-to-image" in chat["intent"].get("type"):
-           chat_history += f"Q: Prompt: {chat['intent']['query']}\n"
+       elif chat["by"] == "khoj" and chat.get("images"):
+           chat_history += f"Q: {chat['intent']['query']}\n"
            chat_history += f"A: Improved Prompt: {chat['intent']['inferred-queries'][0]}\n"
    if send_status_func:
@@ -92,31 +91,29 @@ async def text_to_image(
            logger.error(f"Image Generation blocked by OpenAI: {e}")
            status_code = e.status_code  # type: ignore
            message = f"Image generation blocked by OpenAI due to policy violation"  # type: ignore
-           yield image_url or image, status_code, message, intent_type.value
+           yield image_url or image, status_code, message
            return
        else:
            logger.error(f"Image Generation failed with {e}", exc_info=True)
            message = f"Image generation failed using OpenAI"  # type: ignore
            status_code = e.status_code  # type: ignore
-           yield image_url or image, status_code, message, intent_type.value
+           yield image_url or image, status_code, message
            return
    except requests.RequestException as e:
        logger.error(f"Image Generation failed with {e}", exc_info=True)
        message = f"Image generation using {text2image_model} via {text_to_image_config.model_type} failed due to a network error."
        status_code = 502
-       yield image_url or image, status_code, message, intent_type.value
+       yield image_url or image, status_code, message
        return
    # Decide how to store the generated image
    with timer("Upload image to S3", logger):
        image_url = upload_image(webp_image_bytes, user.uuid)
-   if image_url:
-       intent_type = ImageIntentType.TEXT_TO_IMAGE2
-   else:
-       intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
+
+   if not image_url:
        image = base64.b64encode(webp_image_bytes).decode("utf-8")
-   yield image_url or image, status_code, image_prompt, intent_type.value
+
+   yield image_url or image, status_code, image_prompt
 def generate_image_with_openai(
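Callers of text_to_image now unpack a three-element (image, status_code, message) tuple; the intent_type element is gone. A toy consumer against a hypothetical stub generator:

```python
import asyncio

async def text_to_image_stub(prompt: str):
    # Hypothetical stand-in for the real async generator: it yields a
    # (image, status_code, message) triple; the old fourth element, the
    # intent_type, is gone along with terminal-state image responses.
    yield None, 501, "Failed to generate image. Setup image generation on the server."

async def main():
    async for image, status_code, message in text_to_image_stub("a red bicycle"):
        if status_code != 200:
            print(f"Image generation failed ({status_code}): {message}")

asyncio.run(main())
```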

View file

@@ -77,6 +77,7 @@ from khoj.utils.helpers import (
 )
 from khoj.utils.rawconfig import (
    ChatRequestBody,
+   FileAttachment,
    FileFilterRequest,
    FilesFilterRequest,
    LocationData,
@@ -770,6 +771,11 @@ async def chat(
        file_filters = conversation.file_filters if conversation and conversation.file_filters else []
        attached_file_context = gather_raw_query_files(query_files)
+       generated_images: List[str] = []
+       generated_files: List[FileAttachment] = []
+       generated_excalidraw_diagram: str = None
+       program_execution_context: List[str] = []
+
        if conversation_commands == [ConversationCommand.Default] or is_automated_task:
            chosen_io = await aget_data_sources_and_output_format(
                q,
@@ -875,21 +881,17 @@ async def chat(
            async for result in send_llm_response(response, tracer.get("usage")):
                yield result
-           await sync_to_async(save_to_conversation_log)(
-               q,
-               response_log,
-               user,
-               meta_log,
-               user_message_time,
-               intent_type="summarize",
-               client_application=request.user.client_app,
-               conversation_id=conversation_id,
-               query_images=uploaded_images,
-               train_of_thought=train_of_thought,
-               raw_query_files=raw_query_files,
-               tracer=tracer,
+           summarized_document = FileAttachment(
+               name="Summarized Document",
+               content=response_log,
+               type="text/plain",
+               size=len(response_log.encode("utf-8")),
            )
-           return
+
+           async for result in send_event(ChatEvent.GENERATED_ASSETS, {"files": [summarized_document.model_dump()]}):
+               yield result
+
+           generated_files.append(summarized_document)
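Instead of saving the summary as a terminal response, the handler wraps it in a FileAttachment, streams it as a GENERATED_ASSETS event, and keeps generating. A standalone sketch of the attachment construction, assuming pydantic v2 (field names mirror the diff; the real model lives in khoj.utils.rawconfig):

```python
from pydantic import BaseModel

class FileAttachment(BaseModel):
    # Field names mirror the diff; this is an illustrative stand-in.
    name: str
    content: str
    type: str
    size: int

response_log = "Summary of the attached document..."
summarized_document = FileAttachment(
    name="Summarized Document",
    content=response_log,
    type="text/plain",
    size=len(response_log.encode("utf-8")),
)
print(summarized_document.model_dump())
```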
        custom_filters = []
        if conversation_commands == [ConversationCommand.Help]:
@@ -1078,6 +1080,7 @@ async def chat(
                async for result in send_event(ChatEvent.STATUS, f"**Ran code snippets**: {len(code_results)}"):
                    yield result
            except ValueError as e:
+               program_execution_context.append(f"Failed to run code")
                logger.warning(
                    f"Failed to use code tool: {e}. Attempting to respond without code results",
                    exc_info=True,
@@ -1115,51 +1118,28 @@ async def chat(
                if isinstance(result, dict) and ChatEvent.STATUS in result:
                    yield result[ChatEvent.STATUS]
                else:
-                   generated_image, status_code, improved_image_prompt, intent_type = result
+                   generated_image, status_code, improved_image_prompt = result
+                   inferred_queries.append(improved_image_prompt)
            if generated_image is None or status_code != 200:
-               content_obj = {
-                   "content-type": "application/json",
-                   "intentType": intent_type,
-                   "detail": improved_image_prompt,
-                   "image": None,
-               }
-               async for result in send_llm_response(json.dumps(content_obj), tracer.get("usage")):
+               program_execution_context.append(f"Failed to generate image with {improved_image_prompt}")
+               async for result in send_event(ChatEvent.STATUS, f"Failed to generate image"):
                    yield result
-               return
-
-           await sync_to_async(save_to_conversation_log)(
-               q,
-               generated_image,
-               user,
-               meta_log,
-               user_message_time,
-               intent_type=intent_type,
-               inferred_queries=[improved_image_prompt],
-               client_application=request.user.client_app,
-               conversation_id=conversation_id,
-               compiled_references=compiled_references,
-               online_results=online_results,
-               code_results=code_results,
-               query_images=uploaded_images,
-               train_of_thought=train_of_thought,
-               raw_query_files=raw_query_files,
-               tracer=tracer,
-           )
-           content_obj = {
-               "intentType": intent_type,
-               "inferredQueries": [improved_image_prompt],
-               "image": generated_image,
-           }
-           async for result in send_llm_response(json.dumps(content_obj), tracer.get("usage")):
-               yield result
-           return
+           else:
+               generated_images.append(generated_image)
+
+               async for result in send_event(
+                   ChatEvent.GENERATED_ASSETS,
+                   {
+                       "images": [generated_image],
+                   },
+               ):
+                   yield result
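Both image outcomes now fall through to the final text generation: failures feed program_execution_context, successes are streamed as assets. A condensed sketch of the branch, with event emission stubbed as returned tuples for illustration:

```python
generated_images: list = []
program_execution_context: list = []

def record_image_result(generated_image, status_code, improved_image_prompt):
    # Failures become context for the text model; successes are streamed as
    # GENERATED_ASSETS and kept for the log. Neither path returns early.
    events = []
    if generated_image is None or status_code != 200:
        program_execution_context.append(f"Failed to generate image with {improved_image_prompt}")
        events.append(("status", "Failed to generate image"))
    else:
        generated_images.append(generated_image)
        events.append(("generated_assets", {"images": [generated_image]}))
    return events

print(record_image_result("https://example.com/a.webp", 200, "a red bicycle at dusk"))
```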
        if ConversationCommand.Diagram in conversation_commands:
            async for result in send_event(ChatEvent.STATUS, f"Creating diagram"):
                yield result
-           intent_type = "excalidraw"
            inferred_queries = []
            diagram_description = ""
@@ -1183,62 +1163,29 @@ async def chat(
            if better_diagram_description_prompt and excalidraw_diagram_description:
                inferred_queries.append(better_diagram_description_prompt)
                diagram_description = excalidraw_diagram_description
+
+               generated_excalidraw_diagram = diagram_description
+
+               async for result in send_event(
+                   ChatEvent.GENERATED_ASSETS,
+                   {
+                       "excalidrawDiagram": excalidraw_diagram_description,
+                   },
+               ):
+                   yield result
            else:
                error_message = "Failed to generate diagram. Please try again later."
-               async for result in send_llm_response(error_message, tracer.get("usage")):
-                   yield result
-
-               await sync_to_async(save_to_conversation_log)(
-                   q,
-                   error_message,
-                   user,
-                   meta_log,
-                   user_message_time,
-                   inferred_queries=[better_diagram_description_prompt],
-                   client_application=request.user.client_app,
-                   conversation_id=conversation_id,
-                   compiled_references=compiled_references,
-                   online_results=online_results,
-                   code_results=code_results,
-                   query_images=uploaded_images,
-                   train_of_thought=train_of_thought,
-                   raw_query_files=raw_query_files,
-                   tracer=tracer,
+               program_execution_context.append(
+                   f"AI attempted to programmatically generate a diagram but failed due to a program issue. Generally, it is able to do so, but encountered a system issue this time. AI can suggest text description or rendering of the diagram or user can try again with a simpler prompt."
                )
-               return
-           content_obj = {
-               "intentType": intent_type,
-               "inferredQueries": inferred_queries,
-               "image": diagram_description,
-           }
-           await sync_to_async(save_to_conversation_log)(
-               q,
-               excalidraw_diagram_description,
-               user,
-               meta_log,
-               user_message_time,
-               intent_type="excalidraw",
-               inferred_queries=[better_diagram_description_prompt],
-               client_application=request.user.client_app,
-               conversation_id=conversation_id,
-               compiled_references=compiled_references,
-               online_results=online_results,
-               code_results=code_results,
-               query_images=uploaded_images,
-               train_of_thought=train_of_thought,
-               raw_query_files=raw_query_files,
-               tracer=tracer,
-           )
-           async for result in send_llm_response(json.dumps(content_obj), tracer.get("usage")):
-               yield result
-           return
+               async for result in send_event(ChatEvent.STATUS, error_message):
+                   yield result
        ## Generate Text Output
        async for result in send_event(ChatEvent.STATUS, f"**Generating a well-informed response**"):
            yield result
        llm_response, chat_metadata = await agenerate_chat_response(
            defiltered_query,
            meta_log,
@@ -1258,6 +1205,10 @@ async def chat(
            train_of_thought,
            attached_file_context,
            raw_query_files,
+           generated_images,
+           generated_files,
+           generated_excalidraw_diagram,
+           program_execution_context,
            tracer,
        )

View file

@@ -1185,6 +1185,10 @@ def generate_chat_response(
    train_of_thought: List[Any] = [],
    query_files: str = None,
    raw_query_files: List[FileAttachment] = None,
+   generated_images: List[str] = None,
+   raw_generated_files: List[FileAttachment] = [],
+   generated_excalidraw_diagram: str = None,
+   program_execution_context: List[str] = [],
    tracer: dict = {},
 ) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
    # Initialize Variables
@@ -1208,6 +1212,9 @@ def generate_chat_response(
            query_images=query_images,
            train_of_thought=train_of_thought,
            raw_query_files=raw_query_files,
+           generated_images=generated_images,
+           raw_generated_files=raw_generated_files,
+           generated_excalidraw_diagram=generated_excalidraw_diagram,
            tracer=tracer,
        )
@@ -1243,6 +1250,7 @@ def generate_chat_response(
                user_name=user_name,
                agent=agent,
                query_files=query_files,
+               generated_files=raw_generated_files,
                tracer=tracer,
            )
@@ -1269,6 +1277,10 @@ def generate_chat_response(
                agent=agent,
                vision_available=vision_available,
                query_files=query_files,
+               generated_files=raw_generated_files,
+               generated_images=generated_images,
+               generated_excalidraw_diagram=generated_excalidraw_diagram,
+               program_execution_context=program_execution_context,
                tracer=tracer,
            )
@@ -1292,6 +1304,10 @@ def generate_chat_response(
                agent=agent,
                vision_available=vision_available,
                query_files=query_files,
+               generated_files=raw_generated_files,
+               generated_images=generated_images,
+               generated_excalidraw_diagram=generated_excalidraw_diagram,
+               program_execution_context=program_execution_context,
                tracer=tracer,
            )
        elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
@@ -1314,6 +1330,10 @@ def generate_chat_response(
                query_images=query_images,
                vision_available=vision_available,
                query_files=query_files,
+               generated_files=raw_generated_files,
+               generated_images=generated_images,
+               generated_excalidraw_diagram=generated_excalidraw_diagram,
+               program_execution_context=program_execution_context,
                tracer=tracer,
            )
@@ -1785,6 +1805,9 @@ class MessageProcessor:
        self.references = {}
        self.usage = {}
        self.raw_response = ""
+       self.generated_images = []
+       self.generated_files = []
+       self.generated_excalidraw_diagram = []
    def convert_message_chunk_to_json(self, raw_chunk: str) -> Dict[str, Any]:
        if raw_chunk.startswith("{") and raw_chunk.endswith("}"):
@@ -1823,6 +1846,16 @@ class MessageProcessor:
                self.raw_response += chunk_data
            else:
                self.raw_response += chunk_data
+       elif chunk_type == ChatEvent.GENERATED_ASSETS:
+           chunk_data = chunk["data"]
+           if isinstance(chunk_data, dict):
+               for key in chunk_data:
+                   if key == "images":
+                       self.generated_images = chunk_data[key]
+                   elif key == "files":
+                       self.generated_files = chunk_data[key]
+                   elif key == "excalidrawDiagram":
+                       self.generated_excalidraw_diagram = chunk_data[key]
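A condensed sketch of the processor's new accumulators and the GENERATED_ASSETS branch, reshaped as a plain method for illustration:

```python
class MessageProcessorSketch:
    def __init__(self):
        self.generated_images = []
        self.generated_files = []
        self.generated_excalidraw_diagram = []

    def handle_generated_assets(self, chunk_data):
        # Each known key in the event payload overwrites the matching
        # accumulator; unknown keys are ignored.
        if not isinstance(chunk_data, dict):
            return
        if "images" in chunk_data:
            self.generated_images = chunk_data["images"]
        if "files" in chunk_data:
            self.generated_files = chunk_data["files"]
        if "excalidrawDiagram" in chunk_data:
            self.generated_excalidraw_diagram = chunk_data["excalidrawDiagram"]

processor = MessageProcessorSketch()
processor.handle_generated_assets({"images": ["https://example.com/a.webp"]})
print(processor.generated_images)
```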
    def handle_json_response(self, json_data: Dict[str, str]) -> str | Dict[str, str]:
        if "image" in json_data or "details" in json_data:
@@ -1853,7 +1886,14 @@ async def read_chat_stream(response_iterator: AsyncGenerator[str, None]) -> Dict
    if buffer:
        processor.process_message_chunk(buffer)
-   return {"response": processor.raw_response, "references": processor.references, "usage": processor.usage}
+   return {
+       "response": processor.raw_response,
+       "references": processor.references,
+       "usage": processor.usage,
+       "images": processor.generated_images,
+       "files": processor.generated_files,
+       "excalidrawDiagram": processor.generated_excalidraw_diagram,
+   }
 def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False):

View file

@@ -22,7 +22,6 @@ from khoj.processor.conversation.offline.chat_model import (
    filter_questions,
 )
 from khoj.processor.conversation.offline.utils import download_model
-from khoj.processor.conversation.utils import message_to_log
 from khoj.utils.constants import default_offline_chat_models

View file

@@ -6,7 +6,6 @@ from freezegun import freeze_time
 from khoj.database.models import Agent, Entry, KhojUser
 from khoj.processor.conversation import prompts
-from khoj.processor.conversation.utils import message_to_log
 from tests.helpers import ConversationFactory, generate_chat_history, get_chat_api_key
 # Initialize variables for tests