Merge pull request #992 from khoj-ai/features/allow-multi-outputs-in-chat
Currently, Khoj treats the assets it outputs as terminal states: a response ends with an image, text, or an Excalidraw diagram. This limitation is unnecessary and unduly constrains building more dynamic, diverse experiences. For instance, we may want the chat view to morph for document editing or generation, where a fixed set of output modes would be a detriment. This change lets Khoj emit generated assets and then continue with text generation in the final response, so every message now ends in a text response. It adds a new stream event, GENERATED_ASSETS, which carries the content the AI generates in response to the query.
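To make the wire format concrete, here is a minimal Python sketch of how a generated_assets chunk could be framed in the response stream. This is an illustration only: emit_event is a hypothetical stand-in for Khoj's actual stream framing, and the payload fields mirror the GeneratedAssetsData shape (images, excalidrawDiagram, files) that the web client reads in this diff.

    import json

    GENERATED_ASSETS = "generated_assets"  # matches the new ChatEvent.GENERATED_ASSETS value below

    def emit_event(event_type: str, data) -> str:
        # Frame one chunk of the chat stream as typed JSON, like the other ChatEvent chunks.
        return json.dumps({"type": event_type, "data": data})

    # Assets are emitted first; text generation then continues as regular 'message' chunks.
    chunk = emit_event(
        GENERATED_ASSETS,
        {
            "images": ["<base64-encoded png>"],  # raw base64 strings or https:// URLs
            "excalidrawDiagram": "",             # diagram body when one was generated
            "files": [],                         # AttachedFileText entries for generated documents
        },
    )

On the client, processMessageChunk (in the web interface diff below) copies these fields onto the in-flight message state so the trailing text response renders alongside the assets.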
Commit: e3789aef49
21 changed files with 665 additions and 285 deletions
.github/workflows/run_evals.yml (vendored, 8 changes)
@@ -76,12 +76,12 @@ jobs:
           DEBIAN_FRONTEND: noninteractive
         run: |
           # install postgres and other dependencies
-          apt update && apt install -y git python3-pip libegl1 sqlite3 libsqlite3-dev libsqlite3-0 ffmpeg libsm6 libxext6
-          apt install -y postgresql postgresql-client && apt install -y postgresql-server-dev-14
+          sudo apt update && sudo apt install -y git python3-pip libegl1 sqlite3 libsqlite3-dev libsqlite3-0 ffmpeg libsm6 libxext6
+          sudo apt install -y postgresql postgresql-client && sudo apt install -y postgresql-server-dev-14
           # upgrade pip
           python -m ensurepip --upgrade && python -m pip install --upgrade pip
           # install terrarium for code sandbox
-          git clone https://github.com/cohere-ai/cohere-terrarium.git && cd cohere-terrarium && npm install && mkdir pyodide_cache
+          git clone https://github.com/khoj-ai/terrarium.git && cd terrarium && npm install --legacy-peer-deps && mkdir pyodide_cache

       - name: ⬇️ Install Application
         run: |
@@ -113,7 +113,7 @@ jobs:
           khoj --anonymous-mode --non-interactive &

           # Start code sandbox
-          npm run dev --prefix cohere-terrarium &
+          npm run dev --prefix terrarium &

           # Wait for server to be ready
           timeout=120
@@ -1,4 +1,4 @@
-import {ItemView, MarkdownRenderer, Scope, WorkspaceLeaf, request, requestUrl, setIcon, Platform} from 'obsidian';
+import { ItemView, MarkdownRenderer, Scope, WorkspaceLeaf, request, requestUrl, setIcon, Platform } from 'obsidian';
 import * as DOMPurify from 'dompurify';
 import { KhojSetting } from 'src/settings';
 import { KhojPaneView } from 'src/pane_view';
@@ -27,6 +27,7 @@ interface ChatMessageState {
     newResponseEl: HTMLElement | null;
     loadingEllipsis: HTMLElement | null;
     references: any;
+    generatedAssets: string;
     rawResponse: string;
     rawQuery: string;
     isVoice: boolean;
@@ -46,10 +47,10 @@ export class KhojChatView extends KhojPaneView {
     waitingForLocation: boolean;
     location: Location = { timezone: Intl.DateTimeFormat().resolvedOptions().timeZone };
     keyPressTimeout: NodeJS.Timeout | null = null;
     userMessages: string[] = []; // Store user sent messages for input history cycling
     currentMessageIndex: number = -1; // Track current message index in userMessages array
     private currentUserInput: string = ""; // Stores the current user input that is being typed in chat
     private startingMessage: string = "Message";
+    chatMessageState: ChatMessageState;

     constructor(leaf: WorkspaceLeaf, setting: KhojSetting) {
@@ -102,14 +103,14 @@ export class KhojChatView extends KhojPaneView {

         // Clear text after extracting message to send
         let user_message = input_el.value.trim();
         // Store the message in the array if it's not empty
         if (user_message) {
             this.userMessages.push(user_message);
             // Update starting message after sending a new message
             const modifierKey = Platform.isMacOS ? '⌘' : '^';
             this.startingMessage = `(${modifierKey}+↑/↓) for prev messages`;
             input_el.placeholder = this.startingMessage;
         }
         input_el.value = "";
         this.autoResize();
@@ -162,9 +163,9 @@ export class KhojChatView extends KhojPaneView {
         })
         chatInput.addEventListener('input', (_) => { this.onChatInput() });
         chatInput.addEventListener('keydown', (event) => {
             this.incrementalChat(event);
             this.handleArrowKeys(event);
         });

         // Add event listeners for long press keybinding
         this.contentEl.addEventListener('keydown', this.handleKeyDown.bind(this));
@@ -199,7 +200,7 @@ export class KhojChatView extends KhojPaneView {
         // Get chat history from Khoj backend and set chat input state
         let getChatHistorySucessfully = await this.getChatHistory(chatBodyEl);

-        let placeholderText : string = getChatHistorySucessfully ? this.startingMessage : "Configure Khoj to enable chat";
+        let placeholderText: string = getChatHistorySucessfully ? this.startingMessage : "Configure Khoj to enable chat";
         chatInput.placeholder = placeholderText;
         chatInput.disabled = !getChatHistorySucessfully;

@@ -214,7 +215,7 @@ export class KhojChatView extends KhojPaneView {
         });
     }

-    startSpeechToText(event: KeyboardEvent | MouseEvent | TouchEvent, timeout=200) {
+    startSpeechToText(event: KeyboardEvent | MouseEvent | TouchEvent, timeout = 200) {
         if (!this.keyPressTimeout) {
             this.keyPressTimeout = setTimeout(async () => {
                 // Reset auto send voice message timer, UI if running
@@ -320,7 +321,7 @@ export class KhojChatView extends KhojPaneView {
         referenceButton.tabIndex = 0;

         // Add event listener to toggle full reference on click
-        referenceButton.addEventListener('click', function() {
+        referenceButton.addEventListener('click', function () {
             if (this.classList.contains("collapsed")) {
                 this.classList.remove("collapsed");
                 this.classList.add("expanded");
@@ -375,7 +376,7 @@ export class KhojChatView extends KhojPaneView {
         referenceButton.tabIndex = 0;

         // Add event listener to toggle full reference on click
-        referenceButton.addEventListener('click', function() {
+        referenceButton.addEventListener('click', function () {
             if (this.classList.contains("collapsed")) {
                 this.classList.remove("collapsed");
                 this.classList.add("expanded");
@@ -420,23 +421,23 @@ export class KhojChatView extends KhojPaneView {
                 "Authorization": `Bearer ${this.setting.khojApiKey}`,
             },
         })
             .then(response => response.arrayBuffer())
             .then(arrayBuffer => context.decodeAudioData(arrayBuffer))
             .then(audioBuffer => {
                 const source = context.createBufferSource();
                 source.buffer = audioBuffer;
                 source.connect(context.destination);
                 source.start(0);
-                source.onended = function() {
+                source.onended = function () {
                     speechButton.removeChild(loader);
                     speechButton.disabled = false;
                 };
             })
             .catch(err => {
                 console.error("Error playing speech:", err);
                 speechButton.removeChild(loader);
                 speechButton.disabled = false; // Consider enabling the button again to allow retrying
             });
     }

     formatHTMLMessage(message: string, raw = false, willReplace = true) {
@@ -485,12 +486,18 @@ export class KhojChatView extends KhojPaneView {
         intentType?: string,
         inferredQueries?: string[],
         conversationId?: string,
+        images?: string[],
+        excalidrawDiagram?: string
     ) {
         if (!message) return;

         let chatMessageEl;
-        if (intentType?.includes("text-to-image") || intentType === "excalidraw") {
-            let imageMarkdown = this.generateImageMarkdown(message, intentType, inferredQueries, conversationId);
+        if (
+            intentType?.includes("text-to-image") ||
+            intentType === "excalidraw" ||
+            (images && images.length > 0) ||
+            excalidrawDiagram) {
+            let imageMarkdown = this.generateImageMarkdown(message, intentType ?? "", inferredQueries, conversationId, images, excalidrawDiagram);
             chatMessageEl = this.renderMessage(chatEl, imageMarkdown, sender, dt);
         } else {
             chatMessageEl = this.renderMessage(chatEl, message, sender, dt);
@@ -510,7 +517,7 @@ export class KhojChatView extends KhojPaneView {
             chatMessageBodyEl.appendChild(this.createReferenceSection(references));
         }

-    generateImageMarkdown(message: string, intentType: string, inferredQueries?: string[], conversationId?: string): string {
+    generateImageMarkdown(message: string, intentType: string, inferredQueries?: string[], conversationId?: string, images?: string[], excalidrawDiagram?: string): string {
         let imageMarkdown = "";
         if (intentType === "text-to-image") {
             imageMarkdown = `![](data:image/png;base64,${message})`;
@@ -518,12 +525,23 @@ export class KhojChatView extends KhojPaneView {
             imageMarkdown = `![](${message})`;
         } else if (intentType === "text-to-image-v3") {
             imageMarkdown = `![](data:image/webp;base64,${message})`;
-        } else if (intentType === "excalidraw") {
+        } else if (intentType === "excalidraw" || excalidrawDiagram) {
             const domain = this.setting.khojUrl.endsWith("/") ? this.setting.khojUrl : `${this.setting.khojUrl}/`;
             const redirectMessage = `Hey, I'm not ready to show you diagrams yet here. But you can view it in ${domain}chat?conversationId=${conversationId}`;
             imageMarkdown = redirectMessage;
+        } else if (images && images.length > 0) {
+            for (let image of images) {
+                if (image.startsWith("https://")) {
+                    imageMarkdown += `![](${image})\n\n`;
+                } else {
+                    imageMarkdown += `![](data:image/png;base64,${image})\n\n`;
+                }
+            }
+
+            imageMarkdown += `${message}`;
         }
-        if (inferredQueries) {
+
+        if (images?.length === 0 && inferredQueries) {
             imageMarkdown += "\n\n**Inferred Query**:";
             for (let inferredQuery of inferredQueries) {
                 imageMarkdown += `\n\n${inferredQuery}`;
@@ -650,19 +668,19 @@ export class KhojChatView extends KhojPaneView {
         chatBodyEl.innerHTML = "";
         chatBodyEl.dataset.conversationId = "";
         chatBodyEl.dataset.conversationTitle = "";
         this.userMessages = [];
         this.startingMessage = "Message";

         // Update the placeholder of the chat input
         const chatInput = this.contentEl.querySelector('.khoj-chat-input') as HTMLTextAreaElement;
         if (chatInput) {
             chatInput.placeholder = this.startingMessage;
         }
         this.renderMessage(chatBodyEl, "Hey 👋🏾, what's up?", "khoj");
     }

     async toggleChatSessions(forceShow: boolean = false): Promise<boolean> {
         this.userMessages = []; // clear user previous message history
         let chatBodyEl = this.contentEl.getElementsByClassName("khoj-chat-body")[0] as HTMLElement;
         if (!forceShow && this.contentEl.getElementsByClassName("side-panel")?.length > 0) {
             chatBodyEl.innerHTML = "";
@@ -768,10 +786,10 @@ export class KhojChatView extends KhojPaneView {
         let editConversationTitleInputEl = this.contentEl.createEl('input');
         editConversationTitleInputEl.classList.add("conversation-title-input");
         editConversationTitleInputEl.value = conversationTitle;
-        editConversationTitleInputEl.addEventListener('click', function(event) {
+        editConversationTitleInputEl.addEventListener('click', function (event) {
             event.stopPropagation();
         });
-        editConversationTitleInputEl.addEventListener('keydown', function(event) {
+        editConversationTitleInputEl.addEventListener('keydown', function (event) {
             if (event.key === "Enter") {
                 event.preventDefault();
                 editConversationTitleSaveButtonEl.click();
@@ -890,15 +908,17 @@ export class KhojChatView extends KhojPaneView {
                 chatLog.intent?.type,
                 chatLog.intent?.["inferred-queries"],
                 chatBodyEl.dataset.conversationId ?? "",
+                chatLog.images,
+                chatLog.excalidrawDiagram,
             );
             // push the user messages to the chat history
-            if(chatLog.by === "you"){
+            if (chatLog.by === "you") {
                 this.userMessages.push(chatLog.message);
             }
         });

         // Update starting message after loading history
         const modifierKey: string = Platform.isMacOS ? '⌘' : '^';
         this.startingMessage = this.userMessages.length > 0
             ? `(${modifierKey}+↑/↓) for prev messages`
             : "Message";
@@ -922,15 +942,15 @@ export class KhojChatView extends KhojPaneView {
             try {
                 let jsonChunk = JSON.parse(rawChunk);
                 if (!jsonChunk.type)
-                    jsonChunk = {type: 'message', data: jsonChunk};
+                    jsonChunk = { type: 'message', data: jsonChunk };
                 return jsonChunk;
             } catch (e) {
-                return {type: 'message', data: rawChunk};
+                return { type: 'message', data: rawChunk };
             }
         } else if (rawChunk.length > 0) {
-            return {type: 'message', data: rawChunk};
+            return { type: 'message', data: rawChunk };
         }
-        return {type: '', data: ''};
+        return { type: '', data: '' };
     }

     processMessageChunk(rawChunk: string): void {
@@ -941,6 +961,11 @@ export class KhojChatView extends KhojPaneView {
             console.log(`status: ${chunk.data}`);
             const statusMessage = chunk.data;
             this.handleStreamResponse(this.chatMessageState.newResponseTextEl, statusMessage, this.chatMessageState.loadingEllipsis, false);
+        } else if (chunk.type === 'generated_assets') {
+            const generatedAssets = chunk.data;
+            const imageData = this.handleImageResponse(generatedAssets, this.chatMessageState.rawResponse);
+            this.chatMessageState.generatedAssets = imageData;
+            this.handleStreamResponse(this.chatMessageState.newResponseTextEl, imageData, this.chatMessageState.loadingEllipsis, false);
         } else if (chunk.type === 'start_llm_response') {
             console.log("Started streaming", new Date());
         } else if (chunk.type === 'end_llm_response') {
@@ -963,9 +988,10 @@ export class KhojChatView extends KhojPaneView {
             rawResponse: "",
             rawQuery: liveQuery,
             isVoice: false,
+            generatedAssets: "",
         };
     } else if (chunk.type === "references") {
-        this.chatMessageState.references = {"notes": chunk.data.context, "online": chunk.data.onlineContext};
+        this.chatMessageState.references = { "notes": chunk.data.context, "online": chunk.data.onlineContext };
     } else if (chunk.type === 'message') {
         const chunkData = chunk.data;
         if (typeof chunkData === 'object' && chunkData !== null) {
@@ -978,17 +1004,17 @@ export class KhojChatView extends KhojPaneView {
                 this.handleJsonResponse(jsonData);
             } catch (e) {
                 this.chatMessageState.rawResponse += chunkData;
-                this.handleStreamResponse(this.chatMessageState.newResponseTextEl, this.chatMessageState.rawResponse, this.chatMessageState.loadingEllipsis);
+                this.handleStreamResponse(this.chatMessageState.newResponseTextEl, this.chatMessageState.rawResponse + this.chatMessageState.generatedAssets, this.chatMessageState.loadingEllipsis);
             }
         } else {
             this.chatMessageState.rawResponse += chunkData;
-            this.handleStreamResponse(this.chatMessageState.newResponseTextEl, this.chatMessageState.rawResponse, this.chatMessageState.loadingEllipsis);
+            this.handleStreamResponse(this.chatMessageState.newResponseTextEl, this.chatMessageState.rawResponse + this.chatMessageState.generatedAssets, this.chatMessageState.loadingEllipsis);
         }
     }
 }

 handleJsonResponse(jsonData: any): void {
-    if (jsonData.image || jsonData.detail) {
+    if (jsonData.image || jsonData.detail || jsonData.images || jsonData.excalidrawDiagram) {
         this.chatMessageState.rawResponse = this.handleImageResponse(jsonData, this.chatMessageState.rawResponse);
     } else if (jsonData.response) {
         this.chatMessageState.rawResponse = jsonData.response;
@@ -1234,11 +1260,11 @@ export class KhojChatView extends KhojPaneView {
         const recordingConfig = { mimeType: 'audio/webm' };
         this.mediaRecorder = new MediaRecorder(stream, recordingConfig);

-        this.mediaRecorder.addEventListener("dataavailable", function(event) {
+        this.mediaRecorder.addEventListener("dataavailable", function (event) {
             if (event.data.size > 0) audioChunks.push(event.data);
         });

-        this.mediaRecorder.addEventListener("stop", async function() {
+        this.mediaRecorder.addEventListener("stop", async function () {
             const audioBlob = new Blob(audioChunks, { type: 'audio/webm' });
             await sendToServer(audioBlob);
         });
@@ -1368,7 +1394,22 @@ export class KhojChatView extends KhojPaneView {
             if (inferredQuery) {
                 rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
             }
+        } else if (imageJson.images) {
+            // If response has images field, response is a list of generated images.
+            imageJson.images.forEach((image: any) => {
+                if (image.startsWith("http")) {
+                    rawResponse += `![generated_image](${image})\n\n`;
+                } else {
+                    rawResponse += `![generated_image](data:image/png;base64,${image})\n\n`;
+                }
+            });
+        } else if (imageJson.excalidrawDiagram) {
+            const domain = this.setting.khojUrl.endsWith("/") ? this.setting.khojUrl : `${this.setting.khojUrl}/`;
+            const redirectMessage = `Hey, I'm not ready to show you diagrams yet here. But you can view it in ${domain}`;
+            rawResponse += redirectMessage;
         }

         // If response has detail field, response is an error message.
         if (imageJson.detail) rawResponse += imageJson.detail;
@@ -1407,7 +1448,7 @@ export class KhojChatView extends KhojPaneView {
         referenceExpandButton.classList.add("reference-expand-button");
         referenceExpandButton.innerHTML = numReferences == 1 ? "1 reference" : `${numReferences} references`;

-        referenceExpandButton.addEventListener('click', function() {
+        referenceExpandButton.addEventListener('click', function () {
             if (referenceSection.classList.contains("collapsed")) {
                 referenceSection.classList.remove("collapsed");
                 referenceSection.classList.add("expanded");
@@ -82,7 +82,8 @@ If your plugin does not need CSS, delete this file.
 }
 /* color chat bubble by khoj blue */
 .khoj-chat-message-text.khoj {
-    border: 1px solid var(--khoj-sun);
+    border-top: 1px solid var(--khoj-sun);
+    border-radius: 0px;
     margin-left: auto;
     white-space: pre-line;
 }
@@ -104,8 +105,9 @@ If your plugin does not need CSS, delete this file.
 }
 /* color chat bubble by you dark grey */
 .khoj-chat-message-text.you {
+    border: 1px solid var(--color-accent);
     color: var(--text-normal);
     margin-right: auto;
     background-color: var(--background-modifier-cover);
 }
 /* add right protrusion to you chat bubble */
 .khoj-chat-message-text.you:after {
@@ -1,3 +1,4 @@
+import { AttachedFileText } from "../components/chatInputArea/chatInputArea";
 import {
     CodeContext,
     Context,
@@ -16,6 +17,12 @@ export interface MessageMetadata {
     turnId: string;
 }

+export interface GeneratedAssetsData {
+    images: string[];
+    excalidrawDiagram: string;
+    files: AttachedFileText[];
+}
+
 export interface ResponseWithIntent {
     intentType: string;
     response: string;
@@ -84,6 +91,8 @@ export function processMessageChunk(

     if (!currentMessage || !chunk || !chunk.type) return { context, onlineContext, codeContext };

+    console.log(`chunk type: ${chunk.type}`);
+
     if (chunk.type === "status") {
         console.log(`status: ${chunk.data}`);
         const statusMessage = chunk.data as string;
@@ -98,6 +107,20 @@ export function processMessageChunk(
     } else if (chunk.type === "metadata") {
         const messageMetadata = chunk.data as MessageMetadata;
         currentMessage.turnId = messageMetadata.turnId;
+    } else if (chunk.type === "generated_assets") {
+        const generatedAssets = chunk.data as GeneratedAssetsData;
+
+        if (generatedAssets.images) {
+            currentMessage.generatedImages = generatedAssets.images;
+        }
+
+        if (generatedAssets.excalidrawDiagram) {
+            currentMessage.generatedExcalidrawDiagram = generatedAssets.excalidrawDiagram;
+        }
+
+        if (generatedAssets.files) {
+            currentMessage.generatedFiles = generatedAssets.files;
+        }
     } else if (chunk.type === "message") {
         const chunkData = chunk.data;
         // Here, handle if the response is a JSON response with an image, but the intentType is excalidraw
@@ -54,6 +54,12 @@ function TrainOfThoughtComponent(props: TrainOfThoughtComponentProps) {
     const lastIndex = props.trainOfThought.length - 1;
     const [collapsed, setCollapsed] = useState(props.completed);

+    useEffect(() => {
+        if (props.completed) {
+            setCollapsed(true);
+        }
+    }, [props.completed]);
+
     return (
         <div
             className={`${!collapsed ? styles.trainOfThought + " shadow-sm" : ""}`}
@@ -410,6 +416,9 @@ export default function ChatHistory(props: ChatHistoryProps) {
                 "inferred-queries": message.inferredQueries || [],
             },
             conversationId: props.conversationId,
+            images: message.generatedImages,
+            queryFiles: message.generatedFiles,
+            excalidrawDiagram: message.generatedExcalidrawDiagram,
             turnId: messageTurnId,
         }}
         conversationId={props.conversationId}
@@ -77,6 +77,21 @@ div.imageWrapper img {
     border-radius: 8px;
 }

+div.khoj div.imageWrapper img {
+    height: 512px;
+}
+
+div.khoj div.imageWrapper {
+    flex: 1 1 auto;
+}
+
+div.khoj div.imagesContainer {
+    display: flex;
+    flex-wrap: wrap;
+    flex-direction: row;
+    overflow-x: hidden;
+}
+
 div.chatMessageContainer > img {
     width: auto;
     height: auto;
@@ -178,4 +193,9 @@ div.trainOfThoughtElement ul {
     div.youfullHistory {
         max-width: 90%;
     }
+
+    div.khoj div.imageWrapper img {
+        width: 100%;
+        height: auto;
+    }
 }
@@ -163,6 +163,7 @@ export interface SingleChatMessage {
     conversationId: string;
     turnId?: string;
     queryFiles?: AttachedFileText[];
+    excalidrawDiagram?: string;
 }

 export interface StreamMessage {
@@ -180,6 +181,10 @@ export interface StreamMessage {
     inferredQueries?: string[];
     turnId?: string;
     queryFiles?: AttachedFileText[];
+    excalidrawDiagram?: string;
+    generatedFiles?: AttachedFileText[];
+    generatedImages?: string[];
+    generatedExcalidrawDiagram?: string;
 }

 export interface ChatHistoryData {
@@ -264,6 +269,9 @@ interface ChatMessageProps {
     onDeleteMessage: (turnId?: string) => void;
     conversationId: string;
     turnId?: string;
+    generatedImage?: string;
+    excalidrawDiagram?: string;
+    generatedFiles?: AttachedFileText[];
 }

 interface TrainOfThoughtProps {
@@ -389,9 +397,8 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
     // Prepare initial message for rendering
     let message = props.chatMessage.message;

-    if (props.chatMessage.intent && props.chatMessage.intent.type == "excalidraw") {
-        message = props.chatMessage.intent["inferred-queries"][0];
-        setExcalidrawData(props.chatMessage.message);
+    if (props.chatMessage.excalidrawDiagram) {
+        setExcalidrawData(props.chatMessage.excalidrawDiagram);
     }

     // Replace LaTeX delimiters with placeholders
@@ -401,27 +408,6 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
         .replace(/\\\[/g, "LEFTBRACKET")
         .replace(/\\\]/g, "RIGHTBRACKET");

-    const intentTypeHandlers = {
-        "text-to-image": (msg: string) => `![generated image](data:image/png;base64,${msg})`,
-        "text-to-image2": (msg: string) => `![generated image](${msg})`,
-        "text-to-image-v3": (msg: string) =>
-            `![generated image](data:image/webp;base64,${msg})`,
-        excalidraw: (msg: string) => msg,
-    };
-
-    // Handle intent-specific rendering
-    if (props.chatMessage.intent) {
-        const { type, "inferred-queries": inferredQueries } = props.chatMessage.intent;
-
-        if (type in intentTypeHandlers) {
-            message = intentTypeHandlers[type as keyof typeof intentTypeHandlers](message);
-        }
-
-        if (type.includes("text-to-image") && inferredQueries?.length > 0) {
-            message += `\n\n${inferredQueries[0]}`;
-        }
-    }
-
     // Replace file links with base64 data
     message = renderCodeGenImageInline(message, props.chatMessage.codeContext);
@@ -303,17 +303,10 @@ class ConversationAdmin(unfold_admin.ModelAdmin):
             modified_log = conversation.conversation_log
             chat_log = modified_log.get("chat", [])
             for idx, log in enumerate(chat_log):
-                if (
-                    log["by"] == "khoj"
-                    and log["intent"]
-                    and log["intent"]["type"]
-                    and (
-                        log["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE.value
-                        or log["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE_V3.value
-                    )
-                ):
-                    log["message"] = "inline image redacted for space"
+                if log["by"] == "khoj" and log["images"]:
+                    log["images"] = ["inline image redacted for space"]
                     chat_log[idx] = log

             modified_log["chat"] = chat_log

             writer.writerow(
@@ -0,0 +1,85 @@
+# Made manually by sabaimran for use by Django 5.0.9 on 2024-12-01 16:59
+
+from django.db import migrations, models
+
+# This script was written alongside when Pydantic validation was added to the Conversation conversation_log field.
+
+
+def migrate_generated_assets(apps, schema_editor):
+    Conversation = apps.get_model("database", "Conversation")
+
+    # Process conversations in chunks
+    for conversation in Conversation.objects.iterator():
+        try:
+            meta_log = conversation.conversation_log
+            modified = False
+
+            for chat in meta_log.get("chat", []):
+                intent_type = chat.get("intent", {}).get("type")
+
+                if intent_type and chat["by"] == "khoj":
+                    if intent_type and "text-to-image" in intent_type:
+                        # Migrate the generated image to the new format
+                        chat["images"] = [chat.get("message")]
+                        chat["message"] = chat["intent"]["inferred-queries"][0]
+                        modified = True
+
+                    if intent_type and "excalidraw" in intent_type:
+                        # Migrate the generated excalidraw to the new format
+                        chat["excalidrawDiagram"] = chat.get("message")
+                        chat["message"] = chat["intent"]["inferred-queries"][0]
+                        modified = True
+
+            # Only save if changes were made
+            if modified:
+                conversation.conversation_log = meta_log
+                conversation.save()
+
+        except Exception as e:
+            print(f"Error processing conversation {conversation.id}: {str(e)}")
+            continue
+
+
+def reverse_migration(apps, schema_editor):
+    Conversation = apps.get_model("database", "Conversation")
+
+    # Process conversations in chunks
+    for conversation in Conversation.objects.iterator():
+        try:
+            meta_log = conversation.conversation_log
+            modified = False
+
+            for chat in meta_log.get("chat", []):
+                intent_type = chat.get("intent", {}).get("type")
+
+                if intent_type and chat["by"] == "khoj":
+                    if intent_type and "text-to-image" in intent_type:
+                        # Migrate the generated image back to the old format
+                        chat["message"] = chat.get("images", [])[0]
+                        chat.pop("images", None)
+                        modified = True
+
+                    if intent_type and "excalidraw" in intent_type:
+                        # Migrate the generated excalidraw back to the old format
+                        chat["message"] = chat.get("excalidrawDiagram")
+                        chat.pop("excalidrawDiagram", None)
+                        modified = True
+
+            # Only save if changes were made
+            if modified:
+                conversation.conversation_log = meta_log
+                conversation.save()
+
+        except Exception as e:
+            print(f"Error processing conversation {conversation.id}: {str(e)}")
+            continue
+
+
+class Migration(migrations.Migration):
+    dependencies = [
+        ("database", "0074_alter_conversation_title"),
+    ]
+
+    operations = [
+        migrations.RunPython(migrate_generated_assets, reverse_migration),
+    ]
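For reference, here is a sketch of the record reshaping this migration performs on each khoj chat entry. The field names come from the migration above; the values are invented placeholders.

    old_chat = {
        "by": "khoj",
        "message": "<base64-encoded image>",  # asset previously stored in the message body
        "intent": {"type": "text-to-image", "inferred-queries": ["paint a sunset over the sea"]},
    }

    # After migrate_generated_assets: the asset moves to a dedicated field and the
    # message body holds the human-readable inferred query instead.
    new_chat = {
        "by": "khoj",
        "message": "paint a sunset over the sea",
        "images": ["<base64-encoded image>"],
        "intent": {"type": "text-to-image", "inferred-queries": ["paint a sunset over the sea"]},
    }

reverse_migration applies the same mapping in the opposite direction, which is why the migration is registered with both functions.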
@@ -1,7 +1,9 @@
+import logging
 import os
 import re
 import uuid
 from random import choice
+from typing import Dict, List, Optional, Union

 from django.contrib.auth.models import AbstractUser
 from django.contrib.postgres.fields import ArrayField
@@ -11,9 +13,109 @@ from django.db.models.signals import pre_save
 from django.dispatch import receiver
 from pgvector.django import VectorField
 from phonenumber_field.modelfields import PhoneNumberField
+from pydantic import BaseModel as PydanticBaseModel
+from pydantic import Field
+
+logger = logging.getLogger(__name__)
+

-class BaseModel(models.Model):
+# Pydantic models for type Chat Message validation
+class Context(PydanticBaseModel):
+    compiled: str
+    file: str
+
+
+class CodeContextFile(PydanticBaseModel):
+    filename: str
+    b64_data: str
+
+
+class CodeContextResult(PydanticBaseModel):
+    success: bool
+    output_files: List[CodeContextFile]
+    std_out: str
+    std_err: str
+    code_runtime: int
+
+
+class CodeContextData(PydanticBaseModel):
+    code: str
+    result: Optional[CodeContextResult] = None
+
+
+class WebPage(PydanticBaseModel):
+    link: str
+    query: Optional[str] = None
+    snippet: str
+
+
+class AnswerBox(PydanticBaseModel):
+    link: Optional[str] = None
+    snippet: Optional[str] = None
+    title: str
+    snippetHighlighted: Optional[List[str]] = None
+
+
+class PeopleAlsoAsk(PydanticBaseModel):
+    link: Optional[str] = None
+    question: Optional[str] = None
+    snippet: Optional[str] = None
+    title: str
+
+
+class KnowledgeGraph(PydanticBaseModel):
+    attributes: Optional[Dict[str, str]] = None
+    description: Optional[str] = None
+    descriptionLink: Optional[str] = None
+    descriptionSource: Optional[str] = None
+    imageUrl: Optional[str] = None
+    title: str
+    type: Optional[str] = None
+
+
+class OrganicContext(PydanticBaseModel):
+    snippet: str
+    title: str
+    link: str
+
+
+class OnlineContext(PydanticBaseModel):
+    webpages: Optional[Union[WebPage, List[WebPage]]] = None
+    answerBox: Optional[AnswerBox] = None
+    peopleAlsoAsk: Optional[List[PeopleAlsoAsk]] = None
+    knowledgeGraph: Optional[KnowledgeGraph] = None
+    organicContext: Optional[List[OrganicContext]] = None
+
+
+class Intent(PydanticBaseModel):
+    type: str
+    query: str
+    memory_type: str = Field(alias="memory-type")
+    inferred_queries: Optional[List[str]] = Field(default=None, alias="inferred-queries")
+
+
+class TrainOfThought(PydanticBaseModel):
+    type: str
+    data: str
+
+
+class ChatMessage(PydanticBaseModel):
+    message: str
+    trainOfThought: List[TrainOfThought] = []
+    context: List[Context] = []
+    onlineContext: Dict[str, OnlineContext] = {}
+    codeContext: Dict[str, CodeContextData] = {}
+    created: str
+    images: Optional[List[str]] = None
+    queryFiles: Optional[List[Dict]] = None
+    excalidrawDiagram: Optional[List[Dict]] = None
+    by: str
+    turnId: Optional[str] = None
+    intent: Optional[Intent] = None
+    automationId: Optional[str] = None
+
+
+class DbBaseModel(models.Model):
     created_at = models.DateTimeField(auto_now_add=True)
     updated_at = models.DateTimeField(auto_now=True)
@@ -21,7 +123,7 @@ class BaseModel(models.Model):
         abstract = True


-class ClientApplication(BaseModel):
+class ClientApplication(DbBaseModel):
     name = models.CharField(max_length=200)
     client_id = models.CharField(max_length=200)
     client_secret = models.CharField(max_length=200)
@@ -67,7 +169,7 @@ class KhojApiUser(models.Model):
     accessed_at = models.DateTimeField(null=True, default=None)


-class Subscription(BaseModel):
+class Subscription(DbBaseModel):
     class Type(models.TextChoices):
         TRIAL = "trial"
         STANDARD = "standard"
@@ -79,13 +181,13 @@ class Subscription(BaseModel):
     enabled_trial_at = models.DateTimeField(null=True, default=None, blank=True)


-class OpenAIProcessorConversationConfig(BaseModel):
+class OpenAIProcessorConversationConfig(DbBaseModel):
     name = models.CharField(max_length=200)
     api_key = models.CharField(max_length=200)
     api_base_url = models.URLField(max_length=200, default=None, blank=True, null=True)


-class ChatModelOptions(BaseModel):
+class ChatModelOptions(DbBaseModel):
     class ModelType(models.TextChoices):
         OPENAI = "openai"
         OFFLINE = "offline"
@@ -103,12 +205,12 @@ class ChatModelOptions(BaseModel):
     )


-class VoiceModelOption(BaseModel):
+class VoiceModelOption(DbBaseModel):
     model_id = models.CharField(max_length=200)
     name = models.CharField(max_length=200)


-class Agent(BaseModel):
+class Agent(DbBaseModel):
     class StyleColorTypes(models.TextChoices):
         BLUE = "blue"
         GREEN = "green"
@@ -208,7 +310,7 @@ class Agent(BaseModel):
         super().save(*args, **kwargs)


-class ProcessLock(BaseModel):
+class ProcessLock(DbBaseModel):
     class Operation(models.TextChoices):
         INDEX_CONTENT = "index_content"
         SCHEDULED_JOB = "scheduled_job"
@@ -231,24 +333,24 @@ def verify_agent(sender, instance, **kwargs):
         raise ValidationError(f"A private Agent with the name {instance.name} already exists.")


-class NotionConfig(BaseModel):
+class NotionConfig(DbBaseModel):
     token = models.CharField(max_length=200)
     user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)


-class GithubConfig(BaseModel):
+class GithubConfig(DbBaseModel):
     pat_token = models.CharField(max_length=200)
     user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)


-class GithubRepoConfig(BaseModel):
+class GithubRepoConfig(DbBaseModel):
     name = models.CharField(max_length=200)
     owner = models.CharField(max_length=200)
     branch = models.CharField(max_length=200)
     github_config = models.ForeignKey(GithubConfig, on_delete=models.CASCADE, related_name="githubrepoconfig")


-class WebScraper(BaseModel):
+class WebScraper(DbBaseModel):
     class WebScraperType(models.TextChoices):
         FIRECRAWL = "Firecrawl"
         OLOSTEP = "Olostep"
@@ -321,7 +423,7 @@ class WebScraper(BaseModel):
         super().save(*args, **kwargs)


-class ServerChatSettings(BaseModel):
+class ServerChatSettings(DbBaseModel):
     chat_default = models.ForeignKey(
         ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_default"
     )
@@ -333,35 +435,35 @@ class ServerChatSettings(BaseModel):
     )


-class LocalOrgConfig(BaseModel):
+class LocalOrgConfig(DbBaseModel):
     input_files = models.JSONField(default=list, null=True)
     input_filter = models.JSONField(default=list, null=True)
     index_heading_entries = models.BooleanField(default=False)
     user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)


-class LocalMarkdownConfig(BaseModel):
+class LocalMarkdownConfig(DbBaseModel):
     input_files = models.JSONField(default=list, null=True)
     input_filter = models.JSONField(default=list, null=True)
     index_heading_entries = models.BooleanField(default=False)
     user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)


-class LocalPdfConfig(BaseModel):
+class LocalPdfConfig(DbBaseModel):
     input_files = models.JSONField(default=list, null=True)
     input_filter = models.JSONField(default=list, null=True)
     index_heading_entries = models.BooleanField(default=False)
     user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)


-class LocalPlaintextConfig(BaseModel):
+class LocalPlaintextConfig(DbBaseModel):
     input_files = models.JSONField(default=list, null=True)
     input_filter = models.JSONField(default=list, null=True)
     index_heading_entries = models.BooleanField(default=False)
     user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)


-class SearchModelConfig(BaseModel):
+class SearchModelConfig(DbBaseModel):
     class ModelType(models.TextChoices):
         TEXT = "text"

@@ -393,7 +495,7 @@ class SearchModelConfig(BaseModel):
     bi_encoder_confidence_threshold = models.FloatField(default=0.18)


-class TextToImageModelConfig(BaseModel):
+class TextToImageModelConfig(DbBaseModel):
     class ModelType(models.TextChoices):
         OPENAI = "openai"
         STABILITYAI = "stability-ai"
@@ -430,7 +532,7 @@ class TextToImageModelConfig(BaseModel):
         super().save(*args, **kwargs)


-class SpeechToTextModelOptions(BaseModel):
+class SpeechToTextModelOptions(DbBaseModel):
     class ModelType(models.TextChoices):
         OPENAI = "openai"
         OFFLINE = "offline"
@@ -439,22 +541,22 @@ class SpeechToTextModelOptions(BaseModel):
     model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)


-class UserConversationConfig(BaseModel):
+class UserConversationConfig(DbBaseModel):
     user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
     setting = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True)


-class UserVoiceModelConfig(BaseModel):
+class UserVoiceModelConfig(DbBaseModel):
     user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
     setting = models.ForeignKey(VoiceModelOption, on_delete=models.CASCADE, default=None, null=True, blank=True)


-class UserTextToImageModelConfig(BaseModel):
+class UserTextToImageModelConfig(DbBaseModel):
     user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
     setting = models.ForeignKey(TextToImageModelConfig, on_delete=models.CASCADE)


-class Conversation(BaseModel):
+class Conversation(DbBaseModel):
     user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
     conversation_log = models.JSONField(default=dict)
     client = models.ForeignKey(ClientApplication, on_delete=models.CASCADE, default=None, null=True, blank=True)
@@ -468,8 +570,39 @@ class Conversation(BaseModel):
     file_filters = models.JSONField(default=list)
     id = models.UUIDField(default=uuid.uuid4, editable=False, unique=True, primary_key=True, db_index=True)

+    def clean(self):
+        # Validate conversation_log structure
+        try:
+            messages = self.conversation_log.get("chat", [])
+            for msg in messages:
+                ChatMessage.model_validate(msg)
+        except Exception as e:
+            raise ValidationError(f"Invalid conversation_log format: {str(e)}")

-class PublicConversation(BaseModel):
+    def save(self, *args, **kwargs):
+        self.clean()
+        super().save(*args, **kwargs)
+
+    @property
+    def messages(self) -> List[ChatMessage]:
+        """Type-hinted accessor for conversation messages"""
+        validated_messages = []
+        for msg in self.conversation_log.get("chat", []):
+            try:
+                # Clean up inferred queries if they contain None
+                if msg.get("intent") and msg["intent"].get("inferred-queries"):
+                    msg["intent"]["inferred-queries"] = [
+                        q for q in msg["intent"]["inferred-queries"] if q is not None and isinstance(q, str)
+                    ]
+                msg["message"] = str(msg.get("message", ""))
+                validated_messages.append(ChatMessage.model_validate(msg))
+            except ValidationError as e:
+                logger.warning(f"Skipping invalid message in conversation: {e}")
+                continue
+        return validated_messages
+
+
+class PublicConversation(DbBaseModel):
     source_owner = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
     conversation_log = models.JSONField(default=dict)
     slug = models.CharField(max_length=200, default=None, null=True, blank=True)
@@ -499,12 +632,12 @@ def verify_public_conversation(sender, instance, **kwargs):
     instance.slug = slug


-class ReflectiveQuestion(BaseModel):
+class ReflectiveQuestion(DbBaseModel):
     question = models.CharField(max_length=500)
     user = models.ForeignKey(KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True)


-class Entry(BaseModel):
+class Entry(DbBaseModel):
     class EntryType(models.TextChoices):
         IMAGE = "image"
         PDF = "pdf"
@@ -541,7 +674,7 @@ class Entry(BaseModel):
         raise ValidationError("An Entry cannot be associated with both a user and an agent.")


-class FileObject(BaseModel):
+class FileObject(DbBaseModel):
     # Same as Entry but raw will be a much larger string
     file_name = models.CharField(max_length=400, default=None, null=True, blank=True)
     raw_text = models.TextField()
@@ -549,7 +682,7 @@ class FileObject(BaseModel):
     agent = models.ForeignKey(Agent, on_delete=models.CASCADE, default=None, null=True, blank=True)


-class EntryDates(BaseModel):
+class EntryDates(DbBaseModel):
     date = models.DateField()
     entry = models.ForeignKey(Entry, on_delete=models.CASCADE, related_name="embeddings_dates")

@@ -559,12 +692,12 @@ class EntryDates(BaseModel):
     ]


-class UserRequests(BaseModel):
+class UserRequests(DbBaseModel):
     user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
     slug = models.CharField(max_length=200)


-class DataStore(BaseModel):
+class DataStore(DbBaseModel):
     key = models.CharField(max_length=200, unique=True)
     value = models.JSONField(default=dict)
     private = models.BooleanField(default=False)
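The new Pydantic layer means each conversation_log entry can be validated in isolation, which is what Conversation.clean() does per message. Here is a small usage sketch with placeholder values; the import path is assumed from the khoj package layout.

    from pydantic import ValidationError as PydanticValidationError

    # Import path assumed; adjust if the models module lives elsewhere.
    from khoj.database.models import ChatMessage

    msg = {
        "by": "khoj",
        "message": "Here is your diagram.",
        "created": "2024-12-01 16:59:00",
        "intent": {"type": "excalidraw", "query": "draw a flowchart", "memory-type": "notes"},
    }

    try:
        # Aliased fields like "memory-type" and "inferred-queries" resolve via Field(alias=...)
        validated = ChatMessage.model_validate(msg)
        print(validated.intent.memory_type)  # -> "notes"
    except PydanticValidationError as e:
        print(f"Invalid chat message: {e}")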
@@ -1,6 +1,6 @@
 import logging
 from datetime import datetime, timedelta
-from typing import Dict, Optional
+from typing import Dict, List, Optional

 import pyjson5
 from langchain.schema import ChatMessage
@@ -23,7 +23,7 @@ from khoj.utils.helpers import (
     is_none_or_empty,
     truncate_code_context,
 )
-from khoj.utils.rawconfig import LocationData
+from khoj.utils.rawconfig import FileAttachment, LocationData
 from khoj.utils.yaml import yaml_dump

 logger = logging.getLogger(__name__)
@@ -55,7 +55,7 @@ def extract_questions_anthropic(
         [
             f'User: {chat["intent"]["query"]}\nAssistant: {{"queries": {chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}}}\nA: {chat["message"]}\n\n'
             for chat in conversation_log.get("chat", [])[-4:]
-            if chat["by"] == "khoj" and "text-to-image" not in chat["intent"].get("type")
+            if chat["by"] == "khoj"
         ]
     )

@@ -157,6 +157,10 @@ def converse_anthropic(
     query_images: Optional[list[str]] = None,
     vision_available: bool = False,
     query_files: str = None,
+    generated_images: Optional[list[str]] = None,
+    generated_files: List[FileAttachment] = None,
+    generated_excalidraw_diagram: Optional[str] = None,
+    program_execution_context: Optional[List[str]] = None,
     tracer: dict = {},
 ):
     """
@@ -217,6 +221,10 @@ def converse_anthropic(
         vision_enabled=vision_available,
         model_type=ChatModelOptions.ModelType.ANTHROPIC,
         query_files=query_files,
+        generated_excalidraw_diagram=generated_excalidraw_diagram,
+        generated_files=generated_files,
+        generated_images=generated_images,
+        program_execution_context=program_execution_context,
     )

     messages, system_prompt = format_messages_for_anthropic(messages, system_prompt)
@@ -1,6 +1,6 @@
 import logging
 from datetime import datetime, timedelta
-from typing import Dict, Optional
+from typing import Dict, List, Optional

 import pyjson5
 from langchain.schema import ChatMessage
@@ -23,7 +23,7 @@ from khoj.utils.helpers import (
     is_none_or_empty,
     truncate_code_context,
 )
-from khoj.utils.rawconfig import LocationData
+from khoj.utils.rawconfig import FileAttachment, LocationData
 from khoj.utils.yaml import yaml_dump

 logger = logging.getLogger(__name__)
@@ -56,7 +56,7 @@ def extract_questions_gemini(
         [
             f'User: {chat["intent"]["query"]}\nAssistant: {{"queries": {chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}}}\nA: {chat["message"]}\n\n'
             for chat in conversation_log.get("chat", [])[-4:]
-            if chat["by"] == "khoj" and "text-to-image" not in chat["intent"].get("type")
+            if chat["by"] == "khoj"
         ]
     )

@@ -167,6 +167,10 @@ def converse_gemini(
     query_images: Optional[list[str]] = None,
     vision_available: bool = False,
     query_files: str = None,
+    generated_images: Optional[list[str]] = None,
+    generated_files: List[FileAttachment] = None,
+    generated_excalidraw_diagram: Optional[str] = None,
+    program_execution_context: List[str] = None,
     tracer={},
 ):
     """
@@ -228,6 +232,10 @@ def converse_gemini(
         vision_enabled=vision_available,
         model_type=ChatModelOptions.ModelType.GOOGLE,
         query_files=query_files,
+        generated_excalidraw_diagram=generated_excalidraw_diagram,
+        generated_files=generated_files,
+        generated_images=generated_images,
+        program_execution_context=program_execution_context,
     )

     messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
@@ -28,7 +28,7 @@ from khoj.utils.helpers import (
     is_promptrace_enabled,
     truncate_code_context,
 )
-from khoj.utils.rawconfig import LocationData
+from khoj.utils.rawconfig import FileAttachment, LocationData
 from khoj.utils.yaml import yaml_dump

 logger = logging.getLogger(__name__)
@@ -69,7 +69,7 @@ def extract_questions_offline(

     if use_history:
         for chat in conversation_log.get("chat", [])[-4:]:
-            if chat["by"] == "khoj" and "text-to-image" not in chat["intent"].get("type"):
+            if chat["by"] == "khoj":
                 chat_history += f"Q: {chat['intent']['query']}\n"
                 chat_history += f"Khoj: {chat['message']}\n\n"

@@ -164,6 +164,8 @@ def converse_offline(
     user_name: str = None,
     agent: Agent = None,
     query_files: str = None,
+    generated_files: List[FileAttachment] = None,
+    additional_context: List[str] = None,
     tracer: dict = {},
 ) -> Union[ThreadedGenerator, Iterator[str]]:
     """
@@ -231,6 +233,8 @@ def converse_offline(
         tokenizer_name=tokenizer_name,
         model_type=ChatModelOptions.ModelType.OFFLINE,
         query_files=query_files,
+        generated_files=generated_files,
+        program_execution_context=additional_context,
     )

     logger.debug(f"Conversation Context for {model}: {messages_to_print(messages)}")
@@ -1,6 +1,6 @@
 import logging
 from datetime import datetime, timedelta
-from typing import Dict, Optional
+from typing import Dict, List, Optional

 import pyjson5
 from langchain.schema import ChatMessage
@@ -22,7 +22,7 @@ from khoj.utils.helpers import (
     is_none_or_empty,
     truncate_code_context,
 )
-from khoj.utils.rawconfig import LocationData
+from khoj.utils.rawconfig import FileAttachment, LocationData
 from khoj.utils.yaml import yaml_dump

 logger = logging.getLogger(__name__)
@@ -157,6 +157,10 @@ def converse(
     query_images: Optional[list[str]] = None,
     vision_available: bool = False,
     query_files: str = None,
+    generated_images: Optional[list[str]] = None,
+    generated_files: List[FileAttachment] = None,
+    generated_excalidraw_diagram: Optional[str] = None,
+    program_execution_context: List[str] = None,
     tracer: dict = {},
 ):
     """
@@ -219,6 +223,10 @@ def converse(
         vision_enabled=vision_available,
         model_type=ChatModelOptions.ModelType.OPENAI,
         query_files=query_files,
+        generated_excalidraw_diagram=generated_excalidraw_diagram,
+        generated_files=generated_files,
+        generated_images=generated_images,
+        program_execution_context=program_execution_context,
     )
     logger.debug(f"Conversation Context for GPT: {messages_to_print(messages)}")
@@ -178,6 +178,18 @@ Improved Prompt:
 """.strip()
 )

+generated_image_attachment = PromptTemplate.from_template(
+    f"""
+Here is the image you generated based on my query. You can follow-up with a general response to my query. Limit to 1-2 sentences.
+""".strip()
+)
+
+generated_diagram_attachment = PromptTemplate.from_template(
+    f"""
+I've successfully created a diagram based on the user's query. The diagram will automatically be shared with the user. I can follow-up with a general response or summary. Limit to 1-2 sentences.
+""".strip()
+)
+
 ## Diagram Generation
 ## --

@@ -1029,6 +1041,12 @@ A:
 """.strip()
 )

+additional_program_context = PromptTemplate.from_template(
+    """
+Here are some additional results from the query execution:
+{context}
+""".strip()
+)
+
 personality_prompt_safety_expert_lax = PromptTemplate.from_template(
     """
@ -154,7 +154,7 @@ def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="A
|
|||
chat_history += f'{agent_name}: {{"queries": {chat["intent"].get("inferred-queries")}}}\n'
|
||||
|
||||
chat_history += f"{agent_name}: {chat['message']}\n\n"
|
||||
elif chat["by"] == "khoj" and ("text-to-image" in chat["intent"].get("type")):
|
||||
elif chat["by"] == "khoj" and chat.get("images"):
|
||||
chat_history += f"User: {chat['intent']['query']}\n"
|
||||
chat_history += f"{agent_name}: [generated image redacted for space]\n"
|
||||
elif chat["by"] == "khoj" and ("excalidraw" in chat["intent"].get("type")):
|
||||
|
@ -213,6 +213,7 @@ class ChatEvent(Enum):
|
|||
END_LLM_RESPONSE = "end_llm_response"
|
||||
MESSAGE = "message"
|
||||
REFERENCES = "references"
|
||||
GENERATED_ASSETS = "generated_assets"
|
||||
STATUS = "status"
|
||||
METADATA = "metadata"
|
||||
USAGE = "usage"
|
||||
|
@ -225,7 +226,6 @@ def message_to_log(
|
|||
user_message_metadata={},
|
||||
khoj_message_metadata={},
|
||||
conversation_log=[],
|
||||
train_of_thought=[],
|
||||
):
|
||||
"""Create json logs from messages, metadata for conversation log"""
|
||||
default_khoj_message_metadata = {
|
||||
|
@ -234,6 +234,10 @@ def message_to_log(
|
|||
}
|
||||
khoj_response_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
# Filter out any fields that are set to None
|
||||
user_message_metadata = {k: v for k, v in user_message_metadata.items() if v is not None}
|
||||
khoj_message_metadata = {k: v for k, v in khoj_message_metadata.items() if v is not None}
|
||||
|
||||
# Create json log from Human's message
|
||||
human_log = merge_dicts({"message": user_message, "by": "you"}, user_message_metadata)
|
||||
|
||||
|
@@ -261,31 +265,41 @@ def save_to_conversation_log(
     automation_id: str = None,
     query_images: List[str] = None,
     raw_query_files: List[FileAttachment] = [],
+    generated_images: List[str] = [],
+    raw_generated_files: List[FileAttachment] = [],
+    generated_excalidraw_diagram: str = None,
     train_of_thought: List[Any] = [],
     tracer: Dict[str, Any] = {},
 ):
     user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
     turn_id = tracer.get("mid") or str(uuid.uuid4())

+    user_message_metadata = {"created": user_message_time, "images": query_images, "turnId": turn_id}
+
+    if raw_query_files and len(raw_query_files) > 0:
+        user_message_metadata["queryFiles"] = [file.model_dump(mode="json") for file in raw_query_files]
+
+    khoj_message_metadata = {
+        "context": compiled_references,
+        "intent": {"inferred-queries": inferred_queries, "type": intent_type},
+        "onlineContext": online_results,
+        "codeContext": code_results,
+        "automationId": automation_id,
+        "trainOfThought": train_of_thought,
+        "turnId": turn_id,
+        "images": generated_images,
+        "queryFiles": [file.model_dump(mode="json") for file in raw_generated_files],
+    }
+
+    if generated_excalidraw_diagram:
+        khoj_message_metadata["excalidrawDiagram"] = generated_excalidraw_diagram
+
     updated_conversation = message_to_log(
         user_message=q,
         chat_response=chat_response,
-        user_message_metadata={
-            "created": user_message_time,
-            "images": query_images,
-            "turnId": turn_id,
-            "queryFiles": [file.model_dump(mode="json") for file in raw_query_files],
-        },
-        khoj_message_metadata={
-            "context": compiled_references,
-            "intent": {"inferred-queries": inferred_queries, "type": intent_type},
-            "onlineContext": online_results,
-            "codeContext": code_results,
-            "automationId": automation_id,
-            "trainOfThought": train_of_thought,
-            "turnId": turn_id,
-        },
+        user_message_metadata=user_message_metadata,
+        khoj_message_metadata=khoj_message_metadata,
         conversation_log=meta_log.get("chat", []),
         train_of_thought=train_of_thought,
     )
     ConversationAdapters.save_conversation(
         user,
@@ -303,13 +317,13 @@ def save_to_conversation_log(
     Saved Conversation Turn
     You ({user.username}): "{q}"

-    Khoj: "{inferred_queries if ("text-to-image" in intent_type) else chat_response}"
+    Khoj: "{chat_response}"
     """.strip()
     )


 def construct_structured_message(
-    message: str, images: list[str], model_type: str, vision_enabled: bool, attached_file_context: str
+    message: str, images: list[str], model_type: str, vision_enabled: bool, attached_file_context: str = None
 ):
     """
     Format messages into appropriate multimedia format for supported chat model types
@@ -327,7 +341,8 @@ def construct_structured_message(
             constructed_messages.append({"type": "text", "text": attached_file_context})
         if vision_enabled and images:
             for image in images:
-                constructed_messages.append({"type": "image_url", "image_url": {"url": image}})
+                if image.startswith("https://"):
+                    constructed_messages.append({"type": "image_url", "image_url": {"url": image}})
         return constructed_messages

     if not is_none_or_empty(attached_file_context):
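For reference, the multimodal content parts this helper produces for a vision-enabled model follow the OpenAI-style shape used in the dict literals above; the sample values below are illustrative, and non-https image references are now skipped:

```python
# Illustrative output of construct_structured_message for a
# vision-enabled model: text parts plus image_url parts.
message_content = [
    {"type": "text", "text": "What is in this image?"},
    {"type": "image_url", "image_url": {"url": "https://example.com/photo.webp"}},
]
print(message_content)
```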
@@ -365,6 +380,10 @@ def generate_chatml_messages_with_context(
     model_type="",
     context_message="",
     query_files: str = None,
+    generated_images: Optional[list[str]] = None,
+    generated_files: List[FileAttachment] = None,
+    generated_excalidraw_diagram: str = None,
+    program_execution_context: List[str] = [],
 ):
     """Generate chat messages with appropriate context from previous conversation to send to the chat model"""
     # Set max prompt size from user config or based on pre-configured for model and machine specs
@@ -384,6 +403,7 @@ def generate_chatml_messages_with_context(
         message_attached_files = ""

         chat_message = chat.get("message")
+        role = "user" if chat["by"] == "you" else "assistant"

         if chat["by"] == "khoj" and "excalidraw" in chat["intent"].get("type", ""):
             chat_message = chat["intent"].get("inferred-queries")[0]
@@ -404,7 +424,7 @@ def generate_chatml_messages_with_context(
                 query_files_dict[file["name"]] = file["content"]

             message_attached_files = gather_raw_query_files(query_files_dict)
-            chatml_messages.append(ChatMessage(content=message_attached_files, role="user"))
+            chatml_messages.append(ChatMessage(content=message_attached_files, role=role))

         if not is_none_or_empty(chat.get("onlineContext")):
             message_context += f"{prompts.online_search_conversation.format(online_results=chat.get('onlineContext'))}"
@@ -413,10 +433,20 @@ def generate_chatml_messages_with_context(
             reconstructed_context_message = ChatMessage(content=message_context, role="user")
             chatml_messages.insert(0, reconstructed_context_message)

-        role = "user" if chat["by"] == "you" else "assistant"
-        message_content = construct_structured_message(
-            chat_message, chat.get("images"), model_type, vision_enabled, attached_file_context=query_files
-        )
+        if chat.get("images"):
+            if role == "assistant":
+                # Issue: the assistant role cannot accept an image as a message content, so send it in a separate user message.
+                file_attachment_message = construct_structured_message(
+                    message=prompts.generated_image_attachment.format(),
+                    images=chat.get("images"),
+                    model_type=model_type,
+                    vision_enabled=vision_enabled,
+                )
+                chatml_messages.append(ChatMessage(content=file_attachment_message, role="user"))
+            else:
+                message_content = construct_structured_message(
+                    chat_message, chat.get("images"), model_type, vision_enabled
+                )

         reconstructed_message = ChatMessage(content=message_content, role=role)
         chatml_messages.insert(0, reconstructed_message)
@@ -425,6 +455,7 @@ def generate_chatml_messages_with_context(
             break

     messages = []

     if not is_none_or_empty(user_message):
         messages.append(
             ChatMessage(
@@ -437,6 +468,31 @@ def generate_chatml_messages_with_context(
     if not is_none_or_empty(context_message):
         messages.append(ChatMessage(content=context_message, role="user"))

+    if generated_images:
+        messages.append(
+            ChatMessage(
+                content=construct_structured_message(
+                    prompts.generated_image_attachment.format(), generated_images, model_type, vision_enabled
+                ),
+                role="user",
+            )
+        )
+
+    if generated_files:
+        message_attached_files = gather_raw_query_files({file.name: file.content for file in generated_files})
+        messages.append(ChatMessage(content=message_attached_files, role="assistant"))
+
+    if generated_excalidraw_diagram:
+        messages.append(ChatMessage(content=prompts.generated_diagram_attachment.format(), role="assistant"))
+
+    if program_execution_context:
+        messages.append(
+            ChatMessage(
+                content=prompts.additional_program_context.format(context="\n".join(program_execution_context)),
+                role="assistant",
+            )
+        )
+
     if len(chatml_messages) > 0:
         messages += chatml_messages
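Taken together, these branches splice the assets generated earlier in the turn back into the prompt, so the final text response can reference them. A rough sketch of the extra messages, assuming the `ChatMessage` type from LangChain that this module uses and paraphrasing the prompt strings:

```python
from langchain.schema import ChatMessage  # assumed import, matching this module

# Paraphrased stand-ins for the prompt templates referenced above.
asset_context = [
    ChatMessage(content="[attached generated image]", role="user"),
    ChatMessage(content="[attached generated files]", role="assistant"),
    ChatMessage(content="[attached generated diagram]", role="assistant"),
    ChatMessage(
        content="Here are some additional results from the query execution:\nFailed to run code",
        role="assistant",
    ),
]
for message in asset_context:
    print(message.role, "->", message.content)
```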
@@ -12,7 +12,7 @@ from khoj.database.models import Agent, KhojUser, TextToImageModelConfig
 from khoj.routers.helpers import ChatEvent, generate_better_image_prompt
 from khoj.routers.storage import upload_image
 from khoj.utils import state
-from khoj.utils.helpers import ImageIntentType, convert_image_to_webp, timer
+from khoj.utils.helpers import convert_image_to_webp, timer
 from khoj.utils.rawconfig import LocationData

 logger = logging.getLogger(__name__)
@@ -34,14 +34,13 @@ async def text_to_image(
     status_code = 200
     image = None
     image_url = None
-    intent_type = ImageIntentType.TEXT_TO_IMAGE_V3

     text_to_image_config = await ConversationAdapters.aget_user_text_to_image_model(user)
     if not text_to_image_config:
         # If the user has not configured a text to image model, return an unsupported on server error
         status_code = 501
         message = "Failed to generate image. Setup image generation on the server."
-        yield image_url or image, status_code, message, intent_type.value
+        yield image_url or image, status_code, message
         return

     text2image_model = text_to_image_config.model_name
@@ -50,8 +49,8 @@ async def text_to_image(
         if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder"]:
             chat_history += f"Q: {chat['intent']['query']}\n"
             chat_history += f"A: {chat['message']}\n"
-        elif chat["by"] == "khoj" and "text-to-image" in chat["intent"].get("type"):
-            chat_history += f"Q: Prompt: {chat['intent']['query']}\n"
+        elif chat["by"] == "khoj" and chat.get("images"):
+            chat_history += f"Q: {chat['intent']['query']}\n"
             chat_history += f"A: Improved Prompt: {chat['intent']['inferred-queries'][0]}\n"

     if send_status_func:
@@ -92,31 +91,29 @@ async def text_to_image(
             logger.error(f"Image Generation blocked by OpenAI: {e}")
             status_code = e.status_code  # type: ignore
             message = f"Image generation blocked by OpenAI due to policy violation"  # type: ignore
-            yield image_url or image, status_code, message, intent_type.value
+            yield image_url or image, status_code, message
             return
         else:
             logger.error(f"Image Generation failed with {e}", exc_info=True)
             message = f"Image generation failed using OpenAI"  # type: ignore
             status_code = e.status_code  # type: ignore
-            yield image_url or image, status_code, message, intent_type.value
+            yield image_url or image, status_code, message
             return
     except requests.RequestException as e:
         logger.error(f"Image Generation failed with {e}", exc_info=True)
         message = f"Image generation using {text2image_model} via {text_to_image_config.model_type} failed due to a network error."
         status_code = 502
-        yield image_url or image, status_code, message, intent_type.value
+        yield image_url or image, status_code, message
         return
     # Decide how to store the generated image
     with timer("Upload image to S3", logger):
         image_url = upload_image(webp_image_bytes, user.uuid)
-    if image_url:
-        intent_type = ImageIntentType.TEXT_TO_IMAGE2
-    else:
-        intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
+
+    if not image_url:
         image = base64.b64encode(webp_image_bytes).decode("utf-8")

-    yield image_url or image, status_code, image_prompt, intent_type.value
+    yield image_url or image, status_code, image_prompt


 def generate_image_with_openai(
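With `intent_type` gone, `text_to_image` now yields a plain 3-tuple and the image is forwarded as a GENERATED_ASSETS event instead of terminating the response. A toy sketch of the new contract, with illustrative values standing in for a real generator:

```python
# Previously a 4-tuple ending in intent_type; now just (image, status, prompt).
def fake_text_to_image_result():
    # Hypothetical stand-in for one item yielded by text_to_image.
    return "https://cdn.example.com/img.webp", 200, "a watercolor fox at dawn"

generated_image, status_code, improved_image_prompt = fake_text_to_image_result()
if generated_image is None or status_code != 200:
    print("image generation failed; the turn still ends with a text response")
else:
    print("emit GENERATED_ASSETS with", generated_image)
```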
@@ -77,6 +77,7 @@ from khoj.utils.helpers import (
 )
 from khoj.utils.rawconfig import (
     ChatRequestBody,
+    FileAttachment,
     FileFilterRequest,
     FilesFilterRequest,
     LocationData,
@@ -770,6 +771,11 @@ async def chat(
         file_filters = conversation.file_filters if conversation and conversation.file_filters else []
         attached_file_context = gather_raw_query_files(query_files)

+        generated_images: List[str] = []
+        generated_files: List[FileAttachment] = []
+        generated_excalidraw_diagram: str = None
+        program_execution_context: List[str] = []
+
         if conversation_commands == [ConversationCommand.Default] or is_automated_task:
             chosen_io = await aget_data_sources_and_output_format(
                 q,
@@ -875,21 +881,17 @@ async def chat(
             async for result in send_llm_response(response, tracer.get("usage")):
                 yield result

-            await sync_to_async(save_to_conversation_log)(
-                q,
-                response_log,
-                user,
-                meta_log,
-                user_message_time,
-                intent_type="summarize",
-                client_application=request.user.client_app,
-                conversation_id=conversation_id,
-                query_images=uploaded_images,
-                train_of_thought=train_of_thought,
-                raw_query_files=raw_query_files,
-                tracer=tracer,
-            )
-            return
+            summarized_document = FileAttachment(
+                name="Summarized Document",
+                content=response_log,
+                type="text/plain",
+                size=len(response_log.encode("utf-8")),
+            )
+
+            async for result in send_event(ChatEvent.GENERATED_ASSETS, {"files": [summarized_document.model_dump()]}):
+                yield result
+
+            generated_files.append(summarized_document)

         custom_filters = []
         if conversation_commands == [ConversationCommand.Help]:
@@ -1078,6 +1080,7 @@ async def chat(
                 async for result in send_event(ChatEvent.STATUS, f"**Ran code snippets**: {len(code_results)}"):
                     yield result
             except ValueError as e:
+                program_execution_context.append(f"Failed to run code")
                 logger.warning(
                     f"Failed to use code tool: {e}. Attempting to respond without code results",
                     exc_info=True,
@@ -1115,51 +1118,28 @@ async def chat(
                 if isinstance(result, dict) and ChatEvent.STATUS in result:
                     yield result[ChatEvent.STATUS]
                 else:
-                    generated_image, status_code, improved_image_prompt, intent_type = result
+                    generated_image, status_code, improved_image_prompt = result

                     inferred_queries.append(improved_image_prompt)
                     if generated_image is None or status_code != 200:
-                        content_obj = {
-                            "content-type": "application/json",
-                            "intentType": intent_type,
-                            "detail": improved_image_prompt,
-                            "image": None,
-                        }
-                        async for result in send_llm_response(json.dumps(content_obj), tracer.get("usage")):
-                            yield result
-                        return
+                        program_execution_context.append(f"Failed to generate image with {improved_image_prompt}")
+                        async for result in send_event(ChatEvent.STATUS, f"Failed to generate image"):
+                            yield result
+                    else:
+                        generated_images.append(generated_image)

-            await sync_to_async(save_to_conversation_log)(
-                q,
-                generated_image,
-                user,
-                meta_log,
-                user_message_time,
-                intent_type=intent_type,
-                inferred_queries=[improved_image_prompt],
-                client_application=request.user.client_app,
-                conversation_id=conversation_id,
-                compiled_references=compiled_references,
-                online_results=online_results,
-                code_results=code_results,
-                query_images=uploaded_images,
-                train_of_thought=train_of_thought,
-                raw_query_files=raw_query_files,
-                tracer=tracer,
-            )
-            content_obj = {
-                "intentType": intent_type,
-                "inferredQueries": [improved_image_prompt],
-                "image": generated_image,
-            }
-            async for result in send_llm_response(json.dumps(content_obj), tracer.get("usage")):
-                yield result
-            return
+            async for result in send_event(
+                ChatEvent.GENERATED_ASSETS,
+                {
+                    "images": [generated_image],
+                },
+            ):
+                yield result

         if ConversationCommand.Diagram in conversation_commands:
             async for result in send_event(ChatEvent.STATUS, f"Creating diagram"):
                 yield result

-            intent_type = "excalidraw"
             inferred_queries = []
             diagram_description = ""
@@ -1183,62 +1163,29 @@ async def chat(
             if better_diagram_description_prompt and excalidraw_diagram_description:
                 inferred_queries.append(better_diagram_description_prompt)
                 diagram_description = excalidraw_diagram_description

+                generated_excalidraw_diagram = diagram_description
+
+                async for result in send_event(
+                    ChatEvent.GENERATED_ASSETS,
+                    {
+                        "excalidrawDiagram": excalidraw_diagram_description,
+                    },
+                ):
+                    yield result
             else:
                 error_message = "Failed to generate diagram. Please try again later."
-                async for result in send_llm_response(error_message, tracer.get("usage")):
-                    yield result
-
-                await sync_to_async(save_to_conversation_log)(
-                    q,
-                    error_message,
-                    user,
-                    meta_log,
-                    user_message_time,
-                    inferred_queries=[better_diagram_description_prompt],
-                    client_application=request.user.client_app,
-                    conversation_id=conversation_id,
-                    compiled_references=compiled_references,
-                    online_results=online_results,
-                    code_results=code_results,
-                    query_images=uploaded_images,
-                    train_of_thought=train_of_thought,
-                    raw_query_files=raw_query_files,
-                    tracer=tracer,
-                )
-                return
+                program_execution_context.append(
+                    f"AI attempted to programmatically generate a diagram but failed due to a program issue. Generally, it is able to do so, but encountered a system issue this time. AI can suggest text description or rendering of the diagram or user can try again with a simpler prompt."
+                )
+
+                async for result in send_event(ChatEvent.STATUS, error_message):
+                    yield result

-            content_obj = {
-                "intentType": intent_type,
-                "inferredQueries": inferred_queries,
-                "image": diagram_description,
-            }
-
-            await sync_to_async(save_to_conversation_log)(
-                q,
-                excalidraw_diagram_description,
-                user,
-                meta_log,
-                user_message_time,
-                intent_type="excalidraw",
-                inferred_queries=[better_diagram_description_prompt],
-                client_application=request.user.client_app,
-                conversation_id=conversation_id,
-                compiled_references=compiled_references,
-                online_results=online_results,
-                code_results=code_results,
-                query_images=uploaded_images,
-                train_of_thought=train_of_thought,
-                raw_query_files=raw_query_files,
-                tracer=tracer,
-            )
-
-            async for result in send_llm_response(json.dumps(content_obj), tracer.get("usage")):
-                yield result
-            return
         ## Generate Text Output
         async for result in send_event(ChatEvent.STATUS, f"**Generating a well-informed response**"):
             yield result

         llm_response, chat_metadata = await agenerate_chat_response(
             defiltered_query,
             meta_log,
@@ -1258,6 +1205,10 @@ async def chat(
             train_of_thought,
             attached_file_context,
             raw_query_files,
+            generated_images,
+            generated_files,
+            generated_excalidraw_diagram,
+            program_execution_context,
             tracer,
         )
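The net effect in this endpoint: assets stream out as GENERATED_ASSETS events, the handler keeps going, and every turn now ends with streamed text. A toy simulation of that event order; the framing is simplified and not the endpoint's actual wire format:

```python
import asyncio

async def chat_turn():
    # Toy stand-in for the chat endpoint's event order after this change.
    yield {"type": "generated_assets", "data": {"images": ["img-1"]}}
    yield {"type": "message", "data": "Here is the image you asked for."}
    yield {"type": "end_llm_response", "data": ""}

async def main():
    async for event in chat_turn():
        print(event["type"], "->", event["data"])

asyncio.run(main())
```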
@@ -1185,6 +1185,10 @@ def generate_chat_response(
     train_of_thought: List[Any] = [],
     query_files: str = None,
     raw_query_files: List[FileAttachment] = None,
+    generated_images: List[str] = None,
+    raw_generated_files: List[FileAttachment] = [],
+    generated_excalidraw_diagram: str = None,
+    program_execution_context: List[str] = [],
     tracer: dict = {},
 ) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
     # Initialize Variables
@@ -1208,6 +1212,9 @@ def generate_chat_response(
         query_images=query_images,
         train_of_thought=train_of_thought,
         raw_query_files=raw_query_files,
+        generated_images=generated_images,
+        raw_generated_files=raw_generated_files,
+        generated_excalidraw_diagram=generated_excalidraw_diagram,
         tracer=tracer,
     )
@@ -1243,6 +1250,7 @@ def generate_chat_response(
             user_name=user_name,
             agent=agent,
             query_files=query_files,
+            generated_files=raw_generated_files,
             tracer=tracer,
         )

@@ -1269,6 +1277,10 @@ def generate_chat_response(
             agent=agent,
             vision_available=vision_available,
             query_files=query_files,
+            generated_files=raw_generated_files,
+            generated_images=generated_images,
+            generated_excalidraw_diagram=generated_excalidraw_diagram,
+            program_execution_context=program_execution_context,
             tracer=tracer,
         )

@@ -1292,6 +1304,10 @@ def generate_chat_response(
             agent=agent,
             vision_available=vision_available,
             query_files=query_files,
+            generated_files=raw_generated_files,
+            generated_images=generated_images,
+            generated_excalidraw_diagram=generated_excalidraw_diagram,
+            program_execution_context=program_execution_context,
             tracer=tracer,
         )
     elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:

@@ -1314,6 +1330,10 @@ def generate_chat_response(
             query_images=query_images,
             vision_available=vision_available,
             query_files=query_files,
+            generated_files=raw_generated_files,
+            generated_images=generated_images,
+            generated_excalidraw_diagram=generated_excalidraw_diagram,
+            program_execution_context=program_execution_context,
             tracer=tracer,
         )
@@ -1785,6 +1805,9 @@ class MessageProcessor:
         self.references = {}
         self.usage = {}
         self.raw_response = ""
+        self.generated_images = []
+        self.generated_files = []
+        self.generated_excalidraw_diagram = []

     def convert_message_chunk_to_json(self, raw_chunk: str) -> Dict[str, Any]:
         if raw_chunk.startswith("{") and raw_chunk.endswith("}"):

@@ -1823,6 +1846,16 @@ class MessageProcessor:
                 self.raw_response += chunk_data
             else:
                 self.raw_response += chunk_data
+        elif chunk_type == ChatEvent.GENERATED_ASSETS:
+            chunk_data = chunk["data"]
+            if isinstance(chunk_data, dict):
+                for key in chunk_data:
+                    if key == "images":
+                        self.generated_images = chunk_data[key]
+                    elif key == "files":
+                        self.generated_files = chunk_data[key]
+                    elif key == "excalidrawDiagram":
+                        self.generated_excalidraw_diagram = chunk_data[key]

     def handle_json_response(self, json_data: Dict[str, str]) -> str | Dict[str, str]:
         if "image" in json_data or "details" in json_data:
@@ -1853,7 +1886,14 @@ async def read_chat_stream(response_iterator: AsyncGenerator[str, None]) -> Dict
     if buffer:
         processor.process_message_chunk(buffer)

-    return {"response": processor.raw_response, "references": processor.references, "usage": processor.usage}
+    return {
+        "response": processor.raw_response,
+        "references": processor.references,
+        "usage": processor.usage,
+        "images": processor.generated_images,
+        "files": processor.generated_files,
+        "excalidrawDiagram": processor.generated_excalidraw_diagram,
+    }


 def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False):
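A hypothetical consumer of `read_chat_stream` showing the richer return shape; `response_iterator` would come from the chat streaming endpoint, and the caller itself is a sketch, not part of this PR:

```python
from typing import Any, AsyncGenerator, Dict

async def print_chat_result(response_iterator: AsyncGenerator[str, None]) -> None:
    # read_chat_stream is the helper patched above; this caller is illustrative.
    result: Dict[str, Any] = await read_chat_stream(response_iterator)
    print(result["response"])                # final text, always present now
    for image in result.get("images", []):   # any images generated mid-turn
        print("generated image:", image[:64])
    if result.get("excalidrawDiagram"):
        print("a diagram was generated this turn")
```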
@@ -22,7 +22,6 @@ from khoj.processor.conversation.offline.chat_model import (
     filter_questions,
 )
 from khoj.processor.conversation.offline.utils import download_model
-from khoj.processor.conversation.utils import message_to_log
 from khoj.utils.constants import default_offline_chat_models

@@ -6,7 +6,6 @@ from freezegun import freeze_time

 from khoj.database.models import Agent, Entry, KhojUser
 from khoj.processor.conversation import prompts
-from khoj.processor.conversation.utils import message_to_log
 from tests.helpers import ConversationFactory, generate_chat_history, get_chat_api_key

 # Initialize variables for tests