diff --git a/src/interface/desktop/chat.html b/src/interface/desktop/chat.html index 383fc536..57657ef1 100644 --- a/src/interface/desktop/chat.html +++ b/src/interface/desktop/chat.html @@ -61,6 +61,14 @@ let city = null; let countryName = null; let timezone = null; + let chatMessageState = { + newResponseTextEl: null, + newResponseEl: null, + loadingEllipsis: null, + references: {}, + rawResponse: "", + isVoice: false, + } fetch("https://ipapi.co/json") .then(response => response.json()) @@ -75,10 +83,9 @@ return; }); - async function chat() { - // Extract required fields for search from form + async function chat(isVoice=false) { + // Extract chat message from chat input form let query = document.getElementById("chat-input").value.trim(); - let resultsCount = localStorage.getItem("khojResultsCount") || 5; console.log(`Query: ${query}`); // Short circuit on empty query @@ -106,9 +113,6 @@ await refreshChatSessionsPanel(); } - // Generate backend API URL to execute query - let chatApi = `${hostURL}/api/chat?q=${encodeURIComponent(query)}&n=${resultsCount}&client=web&stream=true&conversation_id=${conversationID}®ion=${region}&city=${city}&country=${countryName}&timezone=${timezone}`; - let newResponseEl = document.createElement("div"); newResponseEl.classList.add("chat-message", "khoj"); newResponseEl.attributes["data-meta"] = "🏮 Khoj at " + formatDate(new Date()); @@ -119,25 +123,7 @@ newResponseEl.appendChild(newResponseTextEl); // Temporary status message to indicate that Khoj is thinking - let loadingEllipsis = document.createElement("div"); - loadingEllipsis.classList.add("lds-ellipsis"); - - let firstEllipsis = document.createElement("div"); - firstEllipsis.classList.add("lds-ellipsis-item"); - - let secondEllipsis = document.createElement("div"); - secondEllipsis.classList.add("lds-ellipsis-item"); - - let thirdEllipsis = document.createElement("div"); - thirdEllipsis.classList.add("lds-ellipsis-item"); - - let fourthEllipsis = document.createElement("div"); 
- fourthEllipsis.classList.add("lds-ellipsis-item"); - - loadingEllipsis.appendChild(firstEllipsis); - loadingEllipsis.appendChild(secondEllipsis); - loadingEllipsis.appendChild(thirdEllipsis); - loadingEllipsis.appendChild(fourthEllipsis); + let loadingEllipsis = createLoadingEllipsis(); newResponseTextEl.appendChild(loadingEllipsis); document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; @@ -148,107 +134,36 @@ let chatInput = document.getElementById("chat-input"); chatInput.classList.remove("option-enabled"); + // Setup chat message state + chatMessageState = { + newResponseTextEl, + newResponseEl, + loadingEllipsis, + references: {}, + rawResponse: "", + rawQuery: query, + isVoice: isVoice, + } + // Call Khoj chat API - let response = await fetch(chatApi, { headers }); - let rawResponse = ""; - let references = null; - const contentType = response.headers.get("content-type"); + let chatApi = `${hostURL}/api/chat?q=${encodeURIComponent(query)}&conversation_id=${conversationID}&stream=true&client=desktop`; + chatApi += (!!region && !!city && !!countryName && !!timezone) + ? `&region=${region}&city=${city}&country=${countryName}&timezone=${timezone}` + : ''; - if (contentType === "application/json") { - // Handle JSON response - try { - const responseAsJson = await response.json(); - if (responseAsJson.image) { - // If response has image field, response is a generated image. 
- if (responseAsJson.intentType === "text-to-image") { - rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`; - } else if (responseAsJson.intentType === "text-to-image2") { - rawResponse += `![${query}](${responseAsJson.image})`; - } else if (responseAsJson.intentType === "text-to-image-v3") { - rawResponse += `![${query}](data:image/webp;base64,${responseAsJson.image})`; - } - const inferredQueries = responseAsJson.inferredQueries?.[0]; - if (inferredQueries) { - rawResponse += `\n\n**Inferred Query**:\n\n${inferredQueries}`; - } - } - if (responseAsJson.context) { - const rawReferenceAsJson = responseAsJson.context; - references = createReferenceSection(rawReferenceAsJson); - } - if (responseAsJson.detail) { - // If response has detail field, response is an error message. - rawResponse += responseAsJson.detail; - } - } catch (error) { - // If the chunk is not a JSON object, just display it as is - rawResponse += chunk; - } finally { - newResponseTextEl.innerHTML = ""; - newResponseTextEl.appendChild(formatHTMLMessage(rawResponse)); + const response = await fetch(chatApi, { headers }); - if (references != null) { - newResponseTextEl.appendChild(references); - } - - document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; - document.getElementById("chat-input").removeAttribute("disabled"); - } - } else { - // Handle streamed response of type text/event-stream or text/plain - const reader = response.body.getReader(); - const decoder = new TextDecoder(); - let references = {}; - - readStream(); - - function readStream() { - reader.read().then(({ done, value }) => { - if (done) { - // Append any references after all the data has been streamed - if (references != {}) { - newResponseTextEl.appendChild(createReferenceSection(references)); - } - document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; - 
document.getElementById("chat-input").removeAttribute("disabled"); - return; - } - - // Decode message chunk from stream - const chunk = decoder.decode(value, { stream: true }); - - if (chunk.includes("### compiled references:")) { - const additionalResponse = chunk.split("### compiled references:")[0]; - rawResponse += additionalResponse; - newResponseTextEl.innerHTML = ""; - newResponseTextEl.appendChild(formatHTMLMessage(rawResponse)); - - const rawReference = chunk.split("### compiled references:")[1]; - const rawReferenceAsJson = JSON.parse(rawReference); - if (rawReferenceAsJson instanceof Array) { - references["notes"] = rawReferenceAsJson; - } else if (typeof rawReferenceAsJson === "object" && rawReferenceAsJson !== null) { - references["online"] = rawReferenceAsJson; - } - readStream(); - } else { - // Display response from Khoj - if (newResponseTextEl.getElementsByClassName("lds-ellipsis").length > 0) { - newResponseTextEl.removeChild(loadingEllipsis); - } - - // If the chunk is not a JSON object, just display it as is - rawResponse += chunk; - newResponseTextEl.innerHTML = ""; - newResponseTextEl.appendChild(formatHTMLMessage(rawResponse)); - - readStream(); - } - - // Scroll to bottom of chat window as chat response is streamed - document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; - }); - } + try { + if (!response.ok) throw new Error(response.statusText); + if (!response.body) throw new Error("Response body is empty"); + // Stream and render chat response + await readChatStream(response); + } catch (err) { + console.error(`Khoj chat response failed with\n${err}`); + if (chatMessageState.newResponseEl.getElementsByClassName("lds-ellipsis").length > 0 && chatMessageState.loadingEllipsis) + chatMessageState.newResponseTextEl.removeChild(chatMessageState.loadingEllipsis); + let errorMsg = "Sorry, unable to get response from Khoj backend ❤️‍🩹. 
Retry or contact developers for help at team@khoj.dev or on Discord"; + newResponseTextEl.textContent = errorMsg; } } diff --git a/src/interface/desktop/chatutils.js b/src/interface/desktop/chatutils.js index 42cfa986..5213979f 100644 --- a/src/interface/desktop/chatutils.js +++ b/src/interface/desktop/chatutils.js @@ -364,3 +364,194 @@ function createReferenceSection(references, createLinkerSection=false) { return referencesDiv; } + +function createLoadingEllipsis() { + let loadingEllipsis = document.createElement("div"); + loadingEllipsis.classList.add("lds-ellipsis"); + + let firstEllipsis = document.createElement("div"); + firstEllipsis.classList.add("lds-ellipsis-item"); + + let secondEllipsis = document.createElement("div"); + secondEllipsis.classList.add("lds-ellipsis-item"); + + let thirdEllipsis = document.createElement("div"); + thirdEllipsis.classList.add("lds-ellipsis-item"); + + let fourthEllipsis = document.createElement("div"); + fourthEllipsis.classList.add("lds-ellipsis-item"); + + loadingEllipsis.appendChild(firstEllipsis); + loadingEllipsis.appendChild(secondEllipsis); + loadingEllipsis.appendChild(thirdEllipsis); + loadingEllipsis.appendChild(fourthEllipsis); + + return loadingEllipsis; +} + +function handleStreamResponse(newResponseElement, rawResponse, rawQuery, loadingEllipsis, replace=true) { + if (!newResponseElement) return; + // Remove loading ellipsis if it exists + if (newResponseElement.getElementsByClassName("lds-ellipsis").length > 0 && loadingEllipsis) + newResponseElement.removeChild(loadingEllipsis); + // Clear the response element if replace is true + if (replace) newResponseElement.innerHTML = ""; + + // Append response to the response element + newResponseElement.appendChild(formatHTMLMessage(rawResponse, false, replace, rawQuery)); + + // Append loading ellipsis if it exists + if (!replace && loadingEllipsis) newResponseElement.appendChild(loadingEllipsis); + // Scroll to bottom of chat view + 
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; +} + +function handleImageResponse(imageJson, rawResponse) { + if (imageJson.image) { + const inferredQuery = imageJson.inferredQueries?.[0] ?? "generated image"; + + // If response has image field, response is a generated image. + if (imageJson.intentType === "text-to-image") { + rawResponse += `![generated_image](data:image/png;base64,${imageJson.image})`; + } else if (imageJson.intentType === "text-to-image2") { + rawResponse += `![generated_image](${imageJson.image})`; + } else if (imageJson.intentType === "text-to-image-v3") { + rawResponse = `![](data:image/webp;base64,${imageJson.image})`; + } + if (inferredQuery) { + rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`; + } + } + + // If response has detail field, response is an error message. + if (imageJson.detail) rawResponse += imageJson.detail; + + return rawResponse; +} + +function finalizeChatBodyResponse(references, newResponseElement) { + if (!!newResponseElement && references != null && Object.keys(references).length > 0) { + newResponseElement.appendChild(createReferenceSection(references)); + } + document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; + document.getElementById("chat-input")?.removeAttribute("disabled"); +} + +function convertMessageChunkToJson(rawChunk) { + // Split the chunk into lines + if (rawChunk?.startsWith("{") && rawChunk?.endsWith("}")) { + try { + let jsonChunk = JSON.parse(rawChunk); + if (!jsonChunk.type) + jsonChunk = {type: 'message', data: jsonChunk}; + return jsonChunk; + } catch (e) { + return {type: 'message', data: rawChunk}; + } + } else if (rawChunk.length > 0) { + return {type: 'message', data: rawChunk}; + } +} + +function processMessageChunk(rawChunk) { + const chunk = convertMessageChunkToJson(rawChunk); + console.debug("Chunk:", chunk); + if (!chunk || !chunk.type) return; + if (chunk.type ==='status') { 
+ console.log(`status: ${chunk.data}`); + const statusMessage = chunk.data; + handleStreamResponse(chatMessageState.newResponseTextEl, statusMessage, chatMessageState.rawQuery, chatMessageState.loadingEllipsis, false); + } else if (chunk.type === 'start_llm_response') { + console.log("Started streaming", new Date()); + } else if (chunk.type === 'end_llm_response') { + console.log("Stopped streaming", new Date()); + + // Automatically respond with voice if the subscribed user has sent voice message + if (chatMessageState.isVoice && "{{ is_active }}" == "True") + textToSpeech(chatMessageState.rawResponse); + + // Append any references after all the data has been streamed + finalizeChatBodyResponse(chatMessageState.references, chatMessageState.newResponseTextEl); + + const liveQuery = chatMessageState.rawQuery; + // Reset variables + chatMessageState = { + newResponseTextEl: null, + newResponseEl: null, + loadingEllipsis: null, + references: {}, + rawResponse: "", + rawQuery: liveQuery, + isVoice: false, + } + } else if (chunk.type === "references") { + chatMessageState.references = {"notes": chunk.data.context, "online": chunk.data.onlineContext}; + } else if (chunk.type === 'message') { + const chunkData = chunk.data; + if (typeof chunkData === 'object' && chunkData !== null) { + // If chunkData is already a JSON object + handleJsonResponse(chunkData); + } else if (typeof chunkData === 'string' && chunkData.trim()?.startsWith("{") && chunkData.trim()?.endsWith("}")) { + // Try process chunk data as if it is a JSON object + try { + const jsonData = JSON.parse(chunkData.trim()); + handleJsonResponse(jsonData); + } catch (e) { + chatMessageState.rawResponse += chunkData; + handleStreamResponse(chatMessageState.newResponseTextEl, chatMessageState.rawResponse, chatMessageState.rawQuery, chatMessageState.loadingEllipsis); + } + } else { + chatMessageState.rawResponse += chunkData; + handleStreamResponse(chatMessageState.newResponseTextEl, chatMessageState.rawResponse, 
chatMessageState.rawQuery, chatMessageState.loadingEllipsis); + } + } +} + +function handleJsonResponse(jsonData) { + if (jsonData.image || jsonData.detail) { + chatMessageState.rawResponse = handleImageResponse(jsonData, chatMessageState.rawResponse); + } else if (jsonData.response) { + chatMessageState.rawResponse = jsonData.response; + } + + if (chatMessageState.newResponseTextEl) { + chatMessageState.newResponseTextEl.innerHTML = ""; + chatMessageState.newResponseTextEl.appendChild(formatHTMLMessage(chatMessageState.rawResponse)); + } +} + +async function readChatStream(response) { + if (!response.body) return; + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + const eventDelimiter = '␃🔚␗'; + let buffer = ''; + + while (true) { + const { value, done } = await reader.read(); + // If the stream is done + if (done) { + // Process the last chunk + processMessageChunk(buffer); + buffer = ''; + break; + } + + // Read chunk from stream and append it to the buffer + const chunk = decoder.decode(value, { stream: true }); + console.debug("Raw Chunk:", chunk) + // Start buffering chunks until complete event is received + buffer += chunk; + + // Once the buffer contains a complete event + let newEventIndex; + while ((newEventIndex = buffer.indexOf(eventDelimiter)) !== -1) { + // Extract the event from the buffer + const event = buffer.slice(0, newEventIndex); + buffer = buffer.slice(newEventIndex + eventDelimiter.length); + + // Process the event + if (event) processMessageChunk(event); + } + } +} diff --git a/src/interface/desktop/shortcut.html b/src/interface/desktop/shortcut.html index 4af26f0d..52207f20 100644 --- a/src/interface/desktop/shortcut.html +++ b/src/interface/desktop/shortcut.html @@ -346,7 +346,7 @@ inp.focus(); } - async function chat() { + async function chat(isVoice=false) { //set chat body to empty let chatBody = document.getElementById("chat-body"); chatBody.innerHTML = ""; @@ -375,9 +375,6 @@ 
chat_body.dataset.conversationId = conversationID; } - // Generate backend API URL to execute query - let chatApi = `${hostURL}/api/chat?q=${encodeURIComponent(query)}&n=${resultsCount}&client=web&stream=true&conversation_id=${conversationID}®ion=${region}&city=${city}&country=${countryName}&timezone=${timezone}`; - let newResponseEl = document.createElement("div"); newResponseEl.classList.add("chat-message", "khoj"); newResponseEl.attributes["data-meta"] = "🏮 Khoj at " + formatDate(new Date()); @@ -388,128 +385,41 @@ newResponseEl.appendChild(newResponseTextEl); // Temporary status message to indicate that Khoj is thinking - let loadingEllipsis = document.createElement("div"); - loadingEllipsis.classList.add("lds-ellipsis"); - - let firstEllipsis = document.createElement("div"); - firstEllipsis.classList.add("lds-ellipsis-item"); - - let secondEllipsis = document.createElement("div"); - secondEllipsis.classList.add("lds-ellipsis-item"); - - let thirdEllipsis = document.createElement("div"); - thirdEllipsis.classList.add("lds-ellipsis-item"); - - let fourthEllipsis = document.createElement("div"); - fourthEllipsis.classList.add("lds-ellipsis-item"); - - loadingEllipsis.appendChild(firstEllipsis); - loadingEllipsis.appendChild(secondEllipsis); - loadingEllipsis.appendChild(thirdEllipsis); - loadingEllipsis.appendChild(fourthEllipsis); - - newResponseTextEl.appendChild(loadingEllipsis); + let loadingEllipsis = createLoadingEllipsis(); document.body.scrollTop = document.getElementById("chat-body").scrollHeight; - // Call Khoj chat API - let response = await fetch(chatApi, { headers }); - let rawResponse = ""; - let references = null; - const contentType = response.headers.get("content-type"); toggleLoading(); - if (contentType === "application/json") { - // Handle JSON response - try { - const responseAsJson = await response.json(); - if (responseAsJson.image) { - // If response has image field, response is a generated image. 
- if (responseAsJson.intentType === "text-to-image") { - rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`; - } else if (responseAsJson.intentType === "text-to-image2") { - rawResponse += `![${query}](${responseAsJson.image})`; - } else if (responseAsJson.intentType === "text-to-image-v3") { - rawResponse += `![${query}](data:image/webp;base64,${responseAsJson.image})`; - } - const inferredQueries = responseAsJson.inferredQueries?.[0]; - if (inferredQueries) { - rawResponse += `\n\n**Inferred Query**:\n\n${inferredQueries}`; - } - } - if (responseAsJson.context) { - const rawReferenceAsJson = responseAsJson.context; - references = createReferenceSection(rawReferenceAsJson, createLinkerSection=true); - } - if (responseAsJson.detail) { - // If response has detail field, response is an error message. - rawResponse += responseAsJson.detail; - } - } catch (error) { - // If the chunk is not a JSON object, just display it as is - rawResponse += chunk; - } finally { - newResponseTextEl.innerHTML = ""; - newResponseTextEl.appendChild(formatHTMLMessage(rawResponse)); - if (references != null) { - newResponseTextEl.appendChild(references); - } + // Setup chat message state + chatMessageState = { + newResponseTextEl, + newResponseEl, + loadingEllipsis, + references: {}, + rawResponse: "", + rawQuery: query, + isVoice: isVoice, + } - document.body.scrollTop = document.getElementById("chat-body").scrollHeight; - } - } else { - // Handle streamed response of type text/event-stream or text/plain - const reader = response.body.getReader(); - const decoder = new TextDecoder(); - let references = {}; + // Construct API URL to execute chat query + let chatApi = `${hostURL}/api/chat?q=${encodeURIComponent(query)}&conversation_id=${conversationID}&stream=true&client=desktop`; + chatApi += (!!region && !!city && !!countryName && !!timezone) + ? 
`&region=${region}&city=${city}&country=${countryName}&timezone=${timezone}` + : ''; - readStream(); + const response = await fetch(chatApi, { headers }); - function readStream() { - reader.read().then(({ done, value }) => { - if (done) { - // Append any references after all the data has been streamed - if (references != {}) { - newResponseTextEl.appendChild(createReferenceSection(references, createLinkerSection=true)); - } - document.body.scrollTop = document.getElementById("chat-body").scrollHeight; - return; - } - - // Decode message chunk from stream - const chunk = decoder.decode(value, { stream: true }); - - if (chunk.includes("### compiled references:")) { - const additionalResponse = chunk.split("### compiled references:")[0]; - rawResponse += additionalResponse; - newResponseTextEl.innerHTML = ""; - newResponseTextEl.appendChild(formatHTMLMessage(rawResponse)); - - const rawReference = chunk.split("### compiled references:")[1]; - const rawReferenceAsJson = JSON.parse(rawReference); - if (rawReferenceAsJson instanceof Array) { - references["notes"] = rawReferenceAsJson; - } else if (typeof rawReferenceAsJson === "object" && rawReferenceAsJson !== null) { - references["online"] = rawReferenceAsJson; - } - readStream(); - } else { - // Display response from Khoj - if (newResponseTextEl.getElementsByClassName("lds-ellipsis").length > 0) { - newResponseTextEl.removeChild(loadingEllipsis); - } - - // If the chunk is not a JSON object, just display it as is - rawResponse += chunk; - newResponseTextEl.innerHTML = ""; - newResponseTextEl.appendChild(formatHTMLMessage(rawResponse)); - - readStream(); - } - - // Scroll to bottom of chat window as chat response is streamed - document.body.scrollTop = document.getElementById("chat-body").scrollHeight; - }); - } + try { + if (!response.ok) throw new Error(response.statusText); + if (!response.body) throw new Error("Response body is empty"); + // Stream and render chat response + await readChatStream(response); + } catch 
(err) { + console.error(`Khoj chat response failed with\n${err}`); + if (chatMessageState.newResponseEl.getElementsByClassName("lds-ellipsis").length > 0 && chatMessageState.loadingEllipsis) + chatMessageState.newResponseTextEl.removeChild(chatMessageState.loadingEllipsis); + let errorMsg = "Sorry, unable to get response from Khoj backend ❤️‍🩹. Retry or contact developers for help at team@khoj.dev or on Discord"; + newResponseTextEl.textContent = errorMsg; } document.body.scrollTop = document.getElementById("chat-body").scrollHeight; } diff --git a/src/interface/desktop/utils.js b/src/interface/desktop/utils.js index c880a7cd..af0234ea 100644 --- a/src/interface/desktop/utils.js +++ b/src/interface/desktop/utils.js @@ -34,8 +34,8 @@ function toggleNavMenu() { document.addEventListener('click', function(event) { let menu = document.getElementById("khoj-nav-menu"); let menuContainer = document.getElementById("khoj-nav-menu-container"); - let isClickOnMenu = menuContainer.contains(event.target) || menuContainer === event.target; - if (isClickOnMenu === false && menu.classList.contains("show")) { + let isClickOnMenu = menuContainer?.contains(event.target) || menuContainer === event.target; + if (menu && isClickOnMenu === false && menu.classList.contains("show")) { menu.classList.remove("show"); } }); diff --git a/src/interface/obsidian/src/chat_view.ts b/src/interface/obsidian/src/chat_view.ts index b8d95d6b..cbd0f7bf 100644 --- a/src/interface/obsidian/src/chat_view.ts +++ b/src/interface/obsidian/src/chat_view.ts @@ -12,6 +12,25 @@ export interface ChatJsonResult { inferredQueries?: string[]; } +interface ChunkResult { + objects: string[]; + remainder: string; +} + +interface MessageChunk { + type: string; + data: any; +} + +interface ChatMessageState { + newResponseTextEl: HTMLElement | null; + newResponseEl: HTMLElement | null; + loadingEllipsis: HTMLElement | null; + references: any; + rawResponse: string; + rawQuery: string; + isVoice: boolean; +} interface 
Location { region: string; @@ -26,6 +45,7 @@ export class KhojChatView extends KhojPaneView { waitingForLocation: boolean; location: Location; keyPressTimeout: NodeJS.Timeout | null = null; + chatMessageState: ChatMessageState; constructor(leaf: WorkspaceLeaf, setting: KhojSetting) { super(leaf, setting); @@ -409,16 +429,15 @@ export class KhojChatView extends KhojPaneView { message = DOMPurify.sanitize(message); // Convert the message to html, sanitize the message html and render it to the real DOM - let chat_message_body_text_el = this.contentEl.createDiv(); - chat_message_body_text_el.className = "chat-message-text-response"; - chat_message_body_text_el.innerHTML = this.markdownTextToSanitizedHtml(message, this); + let chatMessageBodyTextEl = this.contentEl.createDiv(); + chatMessageBodyTextEl.innerHTML = this.markdownTextToSanitizedHtml(message, this); // Add a copy button to each chat message, if it doesn't already exist if (willReplace === true) { - this.renderActionButtons(message, chat_message_body_text_el); + this.renderActionButtons(message, chatMessageBodyTextEl); } - return chat_message_body_text_el; + return chatMessageBodyTextEl; } markdownTextToSanitizedHtml(markdownText: string, component: ItemView): string { @@ -502,23 +521,23 @@ export class KhojChatView extends KhojPaneView { class: `khoj-chat-message ${sender}` }, }) - let chat_message_body_el = chatMessageEl.createDiv(); - chat_message_body_el.addClasses(["khoj-chat-message-text", sender]); - let chat_message_body_text_el = chat_message_body_el.createDiv(); + let chatMessageBodyEl = chatMessageEl.createDiv(); + chatMessageBodyEl.addClasses(["khoj-chat-message-text", sender]); + let chatMessageBodyTextEl = chatMessageBodyEl.createDiv(); // Sanitize the markdown to render message = DOMPurify.sanitize(message); if (raw) { - chat_message_body_text_el.innerHTML = message; + chatMessageBodyTextEl.innerHTML = message; } else { // @ts-ignore - chat_message_body_text_el.innerHTML = 
this.markdownTextToSanitizedHtml(message, this); + chatMessageBodyTextEl.innerHTML = this.markdownTextToSanitizedHtml(message, this); } // Add action buttons to each chat message element if (willReplace === true) { - this.renderActionButtons(message, chat_message_body_text_el); + this.renderActionButtons(message, chatMessageBodyTextEl); } // Remove user-select: none property to make text selectable @@ -531,42 +550,38 @@ export class KhojChatView extends KhojPaneView { } createKhojResponseDiv(dt?: Date): HTMLDivElement { - let message_time = this.formatDate(dt ?? new Date()); + let messageTime = this.formatDate(dt ?? new Date()); // Append message to conversation history HTML element. // The chat logs should display above the message input box to follow standard UI semantics - let chat_body_el = this.contentEl.getElementsByClassName("khoj-chat-body")[0]; - let chat_message_el = chat_body_el.createDiv({ + let chatBodyEl = this.contentEl.getElementsByClassName("khoj-chat-body")[0]; + let chatMessageEl = chatBodyEl.createDiv({ attr: { - "data-meta": `🏮 Khoj at ${message_time}`, + "data-meta": `🏮 Khoj at ${messageTime}`, class: `khoj-chat-message khoj` }, - }).createDiv({ - attr: { - class: `khoj-chat-message-text khoj` - }, - }).createDiv(); + }) // Scroll to bottom after inserting chat messages this.scrollChatToBottom(); - return chat_message_el; + return chatMessageEl; } async renderIncrementalMessage(htmlElement: HTMLDivElement, additionalMessage: string) { - this.result += additionalMessage; + this.chatMessageState.rawResponse += additionalMessage; htmlElement.innerHTML = ""; // Sanitize the markdown to render - this.result = DOMPurify.sanitize(this.result); + this.chatMessageState.rawResponse = DOMPurify.sanitize(this.chatMessageState.rawResponse); // @ts-ignore - htmlElement.innerHTML = this.markdownTextToSanitizedHtml(this.result, this); + htmlElement.innerHTML = this.markdownTextToSanitizedHtml(this.chatMessageState.rawResponse, this); // Render action buttons 
for the message - this.renderActionButtons(this.result, htmlElement); + this.renderActionButtons(this.chatMessageState.rawResponse, htmlElement); // Scroll to bottom of modal, till the send message input box this.scrollChatToBottom(); } - renderActionButtons(message: string, chat_message_body_text_el: HTMLElement) { + renderActionButtons(message: string, chatMessageBodyTextEl: HTMLElement) { let copyButton = this.contentEl.createEl('button'); copyButton.classList.add("chat-action-button"); copyButton.title = "Copy Message to Clipboard"; @@ -593,10 +608,10 @@ export class KhojChatView extends KhojPaneView { } // Append buttons to parent element - chat_message_body_text_el.append(copyButton, pasteToFile); + chatMessageBodyTextEl.append(copyButton, pasteToFile); if (speechButton) { - chat_message_body_text_el.append(speechButton); + chatMessageBodyTextEl.append(speechButton); } } @@ -854,35 +869,122 @@ export class KhojChatView extends KhojPaneView { return true; } - async readChatStream(response: Response, responseElement: HTMLDivElement, isVoice: boolean = false): Promise { + convertMessageChunkToJson(rawChunk: string): MessageChunk { + if (rawChunk?.startsWith("{") && rawChunk?.endsWith("}")) { + try { + let jsonChunk = JSON.parse(rawChunk); + if (!jsonChunk.type) + jsonChunk = {type: 'message', data: jsonChunk}; + return jsonChunk; + } catch (e) { + return {type: 'message', data: rawChunk}; + } + } else if (rawChunk.length > 0) { + return {type: 'message', data: rawChunk}; + } + return {type: '', data: ''}; + } + + processMessageChunk(rawChunk: string): void { + const chunk = this.convertMessageChunkToJson(rawChunk); + console.debug("Chunk:", chunk); + if (!chunk || !chunk.type) return; + if (chunk.type === 'status') { + console.log(`status: ${chunk.data}`); + const statusMessage = chunk.data; + this.handleStreamResponse(this.chatMessageState.newResponseTextEl, statusMessage, this.chatMessageState.loadingEllipsis, false); + } else if (chunk.type === 
'start_llm_response') { + console.log("Started streaming", new Date()); + } else if (chunk.type === 'end_llm_response') { + console.log("Stopped streaming", new Date()); + + // Automatically respond with voice if the subscribed user has sent voice message + if (this.chatMessageState.isVoice && this.setting.userInfo?.is_active) + this.textToSpeech(this.chatMessageState.rawResponse); + + // Append any references after all the data has been streamed + this.finalizeChatBodyResponse(this.chatMessageState.references, this.chatMessageState.newResponseTextEl); + + const liveQuery = this.chatMessageState.rawQuery; + // Reset variables + this.chatMessageState = { + newResponseTextEl: null, + newResponseEl: null, + loadingEllipsis: null, + references: {}, + rawResponse: "", + rawQuery: liveQuery, + isVoice: false, + }; + } else if (chunk.type === "references") { + this.chatMessageState.references = {"notes": chunk.data.context, "online": chunk.data.onlineContext}; + } else if (chunk.type === 'message') { + const chunkData = chunk.data; + if (typeof chunkData === 'object' && chunkData !== null) { + // If chunkData is already a JSON object + this.handleJsonResponse(chunkData); + } else if (typeof chunkData === 'string' && chunkData.trim()?.startsWith("{") && chunkData.trim()?.endsWith("}")) { + // Try process chunk data as if it is a JSON object + try { + const jsonData = JSON.parse(chunkData.trim()); + this.handleJsonResponse(jsonData); + } catch (e) { + this.chatMessageState.rawResponse += chunkData; + this.handleStreamResponse(this.chatMessageState.newResponseTextEl, this.chatMessageState.rawResponse, this.chatMessageState.loadingEllipsis); + } + } else { + this.chatMessageState.rawResponse += chunkData; + this.handleStreamResponse(this.chatMessageState.newResponseTextEl, this.chatMessageState.rawResponse, this.chatMessageState.loadingEllipsis); + } + } + } + + handleJsonResponse(jsonData: any): void { + if (jsonData.image || jsonData.detail) { + 
this.chatMessageState.rawResponse = this.handleImageResponse(jsonData, this.chatMessageState.rawResponse); + } else if (jsonData.response) { + this.chatMessageState.rawResponse = jsonData.response; + } + + if (this.chatMessageState.newResponseTextEl) { + this.chatMessageState.newResponseTextEl.innerHTML = ""; + this.chatMessageState.newResponseTextEl.appendChild(this.formatHTMLMessage(this.chatMessageState.rawResponse)); + } + } + + async readChatStream(response: Response): Promise { // Exit if response body is empty if (response.body == null) return; const reader = response.body.getReader(); const decoder = new TextDecoder(); + const eventDelimiter = '␃🔚␗'; + let buffer = ''; while (true) { const { value, done } = await reader.read(); if (done) { - // Automatically respond with voice if the subscribed user has sent voice message - if (isVoice && this.setting.userInfo?.is_active) this.textToSpeech(this.result); + this.processMessageChunk(buffer); + buffer = ''; // Break if the stream is done break; } - let responseText = decoder.decode(value); - if (responseText.includes("### compiled references:")) { - // Render any references used to generate the response - const [additionalResponse, rawReference] = responseText.split("### compiled references:", 2); - await this.renderIncrementalMessage(responseElement, additionalResponse); + const chunk = decoder.decode(value, { stream: true }); + console.debug("Raw Chunk:", chunk) + // Start buffering chunks until complete event is received + buffer += chunk; - const rawReferenceAsJson = JSON.parse(rawReference); - let references = this.extractReferences(rawReferenceAsJson); - responseElement.appendChild(this.createReferenceSection(references)); - } else { - // Render incremental chat response - await this.renderIncrementalMessage(responseElement, responseText); + // Once the buffer contains a complete event + let newEventIndex; + while ((newEventIndex = buffer.indexOf(eventDelimiter)) !== -1) { + // Extract the event from the 
buffer + const event = buffer.slice(0, newEventIndex); + buffer = buffer.slice(newEventIndex + eventDelimiter.length); + + // Process the event + if (event) this.processMessageChunk(event); } } } @@ -895,83 +997,59 @@ export class KhojChatView extends KhojPaneView { let chatBodyEl = this.contentEl.getElementsByClassName("khoj-chat-body")[0] as HTMLElement; this.renderMessage(chatBodyEl, query, "you"); - let conversationID = chatBodyEl.dataset.conversationId; - if (!conversationID) { + let conversationId = chatBodyEl.dataset.conversationId; + if (!conversationId) { let chatUrl = `${this.setting.khojUrl}/api/chat/sessions?client=obsidian`; let response = await fetch(chatUrl, { method: "POST", headers: { "Authorization": `Bearer ${this.setting.khojApiKey}` }, }); let data = await response.json(); - conversationID = data.conversation_id; - chatBodyEl.dataset.conversationId = conversationID; + conversationId = data.conversation_id; + chatBodyEl.dataset.conversationId = conversationId; } // Get chat response from Khoj backend let encodedQuery = encodeURIComponent(query); - let chatUrl = `${this.setting.khojUrl}/api/chat?q=${encodedQuery}&n=${this.setting.resultsCount}&client=obsidian&stream=true&region=${this.location.region}&city=${this.location.city}&country=${this.location.countryName}&timezone=${this.location.timezone}`; - let responseElement = this.createKhojResponseDiv(); + let chatUrl = `${this.setting.khojUrl}/api/chat?q=${encodedQuery}&conversation_id=${conversationId}&n=${this.setting.resultsCount}&stream=true&client=obsidian`; + if (!!this.location) chatUrl += `&region=${this.location.region}&city=${this.location.city}&country=${this.location.countryName}&timezone=${this.location.timezone}`; + + let newResponseEl = this.createKhojResponseDiv(); + let newResponseTextEl = newResponseEl.createDiv(); + newResponseTextEl.classList.add("khoj-chat-message-text", "khoj"); // Temporary status message to indicate that Khoj is thinking - this.result = ""; let loadingEllipsis = 
this.createLoadingEllipse(); - responseElement.appendChild(loadingEllipsis); + newResponseTextEl.appendChild(loadingEllipsis); + + // Set chat message state + this.chatMessageState = { + newResponseEl: newResponseEl, + newResponseTextEl: newResponseTextEl, + loadingEllipsis: loadingEllipsis, + references: {}, + rawQuery: query, + rawResponse: "", + isVoice: isVoice, + }; let response = await fetch(chatUrl, { method: "GET", headers: { - "Content-Type": "text/event-stream", + "Content-Type": "text/plain", "Authorization": `Bearer ${this.setting.khojApiKey}`, }, }) try { - if (response.body === null) { - throw new Error("Response body is null"); - } + if (response.body === null) throw new Error("Response body is null"); - // Clear loading status message - if (responseElement.getElementsByClassName("lds-ellipsis").length > 0 && loadingEllipsis) { - responseElement.removeChild(loadingEllipsis); - } - - // Reset collated chat result to empty string - this.result = ""; - responseElement.innerHTML = ""; - if (response.headers.get("content-type") === "application/json") { - let responseText = "" - try { - const responseAsJson = await response.json() as ChatJsonResult; - if (responseAsJson.image) { - // If response has image field, response is a generated image. 
- if (responseAsJson.intentType === "text-to-image") { - responseText += `![${query}](data:image/png;base64,${responseAsJson.image})`; - } else if (responseAsJson.intentType === "text-to-image2") { - responseText += `![${query}](${responseAsJson.image})`; - } else if (responseAsJson.intentType === "text-to-image-v3") { - responseText += `![${query}](data:image/webp;base64,${responseAsJson.image})`; - } - const inferredQuery = responseAsJson.inferredQueries?.[0]; - if (inferredQuery) { - responseText += `\n\n**Inferred Query**:\n\n${inferredQuery}`; - } - } else if (responseAsJson.detail) { - responseText = responseAsJson.detail; - } - } catch (error) { - // If the chunk is not a JSON object, just display it as is - responseText = await response.text(); - } finally { - await this.renderIncrementalMessage(responseElement, responseText); - } - } else { - // Stream and render chat response - await this.readChatStream(response, responseElement, isVoice); - } + // Stream and render chat response + await this.readChatStream(response); } catch (err) { - console.log(`Khoj chat response failed with\n${err}`); + console.error(`Khoj chat response failed with\n${err}`); let errorMsg = "Sorry, unable to get response from Khoj backend ❤️‍🩹. 
Retry or contact developers for help at team@khoj.dev or on Discord"; - responseElement.innerHTML = errorMsg + newResponseTextEl.textContent = errorMsg; } } @@ -1196,30 +1274,21 @@ export class KhojChatView extends KhojPaneView { handleStreamResponse(newResponseElement: HTMLElement | null, rawResponse: string, loadingEllipsis: HTMLElement | null, replace = true) { if (!newResponseElement) return; - if (newResponseElement.getElementsByClassName("lds-ellipsis").length > 0 && loadingEllipsis) { + // Remove loading ellipsis if it exists + if (newResponseElement.getElementsByClassName("lds-ellipsis").length > 0 && loadingEllipsis) newResponseElement.removeChild(loadingEllipsis); - } - if (replace) { - newResponseElement.innerHTML = ""; - } + // Clear the response element if replace is true + if (replace) newResponseElement.innerHTML = ""; + + // Append response to the response element newResponseElement.appendChild(this.formatHTMLMessage(rawResponse, false, replace)); + + // Append loading ellipsis if it exists + if (!replace && loadingEllipsis) newResponseElement.appendChild(loadingEllipsis); + // Scroll to bottom of chat view this.scrollChatToBottom(); } - handleCompiledReferences(rawResponseElement: HTMLElement | null, chunk: string, references: any, rawResponse: string) { - if (!rawResponseElement || !chunk) return { rawResponse, references }; - - const [additionalResponse, rawReference] = chunk.split("### compiled references:", 2); - rawResponse += additionalResponse; - rawResponseElement.innerHTML = ""; - rawResponseElement.appendChild(this.formatHTMLMessage(rawResponse)); - - const rawReferenceAsJson = JSON.parse(rawReference); - references = this.extractReferences(rawReferenceAsJson); - - return { rawResponse, references }; - } - handleImageResponse(imageJson: any, rawResponse: string) { if (imageJson.image) { const inferredQuery = imageJson.inferredQueries?.[0] ?? 
"generated image"; @@ -1236,33 +1305,10 @@ export class KhojChatView extends KhojPaneView { rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`; } } - let references = {}; - if (imageJson.context && imageJson.context.length > 0) { - references = this.extractReferences(imageJson.context); - } - if (imageJson.detail) { - // If response has detail field, response is an error message. - rawResponse += imageJson.detail; - } - return { rawResponse, references }; - } + // If response has detail field, response is an error message. + if (imageJson.detail) rawResponse += imageJson.detail; - extractReferences(rawReferenceAsJson: any): object { - let references: any = {}; - if (rawReferenceAsJson instanceof Array) { - references["notes"] = rawReferenceAsJson; - } else if (typeof rawReferenceAsJson === "object" && rawReferenceAsJson !== null) { - references["online"] = rawReferenceAsJson; - } - return references; - } - - addMessageToChatBody(rawResponse: string, newResponseElement: HTMLElement | null, references: any) { - if (!newResponseElement) return; - newResponseElement.innerHTML = ""; - newResponseElement.appendChild(this.formatHTMLMessage(rawResponse)); - - this.finalizeChatBodyResponse(references, newResponseElement); + return rawResponse; } finalizeChatBodyResponse(references: object, newResponseElement: HTMLElement | null) { diff --git a/src/interface/obsidian/styles.css b/src/interface/obsidian/styles.css index afd8fd19..42c1b3ce 100644 --- a/src/interface/obsidian/styles.css +++ b/src/interface/obsidian/styles.css @@ -85,6 +85,12 @@ If your plugin does not need CSS, delete this file. 
margin-left: auto; white-space: pre-line; } +/* Override white-space for ul, ol, li under khoj-chat-message-text.khoj */ +.khoj-chat-message-text.khoj ul, +.khoj-chat-message-text.khoj ol, +.khoj-chat-message-text.khoj li { + white-space: normal; +} /* add left protrusion to khoj chat bubble */ .khoj-chat-message-text.khoj:after { content: ''; diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index a0a13e8c..0e50169e 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -680,34 +680,18 @@ class ConversationAdapters: async def aget_conversation_by_user( user: KhojUser, client_application: ClientApplication = None, conversation_id: int = None, title: str = None ) -> Optional[Conversation]: + query = Conversation.objects.filter(user=user, client=client_application).prefetch_related("agent") + if conversation_id: - return ( - await Conversation.objects.filter(user=user, client=client_application, id=conversation_id) - .prefetch_related("agent") - .afirst() - ) + return await query.filter(id=conversation_id).afirst() elif title: - return ( - await Conversation.objects.filter(user=user, client=client_application, title=title) - .prefetch_related("agent") - .afirst() - ) - else: - conversation = ( - Conversation.objects.filter(user=user, client=client_application) - .prefetch_related("agent") - .order_by("-updated_at") - ) + return await query.filter(title=title).afirst() - if await conversation.aexists(): - return await conversation.prefetch_related("agent").afirst() + conversation = await query.order_by("-updated_at").afirst() - return await ( - Conversation.objects.filter(user=user, client=client_application) - .prefetch_related("agent") - .order_by("-updated_at") - .afirst() - ) or await Conversation.objects.prefetch_related("agent").acreate(user=user, client=client_application) + return conversation or await Conversation.objects.prefetch_related("agent").acreate( + user=user, 
client=client_application + ) @staticmethod async def adelete_conversation_by_user( diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html index ad8ced27..024af9ad 100644 --- a/src/khoj/interface/web/chat.html +++ b/src/khoj/interface/web/chat.html @@ -74,14 +74,13 @@ To get started, just start typing below. You can also type / to see a list of co }, 1000); }); } - var websocket = null; + let region = null; let city = null; let countryName = null; let timezone = null; let waitingForLocation = true; - - let websocketState = { + let chatMessageState = { newResponseTextEl: null, newResponseEl: null, loadingEllipsis: null, @@ -105,7 +104,7 @@ To get started, just start typing below. You can also type / to see a list of co .finally(() => { console.debug("Region:", region, "City:", city, "Country:", countryName, "Timezone:", timezone); waitingForLocation = false; - setupWebSocket(); + initMessageState(); }); function formatDate(date) { @@ -599,13 +598,8 @@ To get started, just start typing below. You can also type / to see a list of co } async function chat(isVoice=false) { - if (websocket) { - sendMessageViaWebSocket(isVoice); - return; - } - - let query = document.getElementById("chat-input").value.trim(); - let resultsCount = localStorage.getItem("khojResultsCount") || 5; + // Extract chat message from chat input form + var query = document.getElementById("chat-input").value.trim(); console.log(`Query: ${query}`); // Short circuit on empty query @@ -624,31 +618,30 @@ To get started, just start typing below. 
You can also type / to see a list of co document.getElementById("chat-input").value = ""; autoResize(); document.getElementById("chat-input").setAttribute("disabled", "disabled"); - let chat_body = document.getElementById("chat-body"); - - let conversationID = chat_body.dataset.conversationId; + let chatBody = document.getElementById("chat-body"); + let conversationID = chatBody.dataset.conversationId; if (!conversationID) { - let response = await fetch('/api/chat/sessions', { method: "POST" }); + let response = await fetch(`${hostURL}/api/chat/sessions`, { method: "POST" }); let data = await response.json(); conversationID = data.conversation_id; - chat_body.dataset.conversationId = conversationID; - refreshChatSessionsPanel(); + chatBody.dataset.conversationId = conversationID; + await refreshChatSessionsPanel(); } - let new_response = document.createElement("div"); - new_response.classList.add("chat-message", "khoj"); - new_response.attributes["data-meta"] = "🏮 Khoj at " + formatDate(new Date()); - chat_body.appendChild(new_response); + let newResponseEl = document.createElement("div"); + newResponseEl.classList.add("chat-message", "khoj"); + newResponseEl.attributes["data-meta"] = "🏮 Khoj at " + formatDate(new Date()); + chatBody.appendChild(newResponseEl); - let newResponseText = document.createElement("div"); - newResponseText.classList.add("chat-message-text", "khoj"); - new_response.appendChild(newResponseText); + let newResponseTextEl = document.createElement("div"); + newResponseTextEl.classList.add("chat-message-text", "khoj"); + newResponseEl.appendChild(newResponseTextEl); // Temporary status message to indicate that Khoj is thinking let loadingEllipsis = createLoadingEllipse(); - newResponseText.appendChild(loadingEllipsis); + newResponseTextEl.appendChild(loadingEllipsis); document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; let chatTooltip = document.getElementById("chat-tooltip"); @@ -657,65 +650,38 @@ 
To get started, just start typing below. You can also type / to see a list of co let chatInput = document.getElementById("chat-input"); chatInput.classList.remove("option-enabled"); - // Generate backend API URL to execute query - let url = `/api/chat?q=${encodeURIComponent(query)}&n=${resultsCount}&client=web&stream=true&conversation_id=${conversationID}&region=${region}&city=${city}&country=${countryName}&timezone=${timezone}`; - - // Call specified Khoj API - let response = await fetch(url); - let rawResponse = ""; - let references = null; - const contentType = response.headers.get("content-type"); - - if (contentType === "application/json") { - // Handle JSON response - try { - const responseAsJson = await response.json(); - if (responseAsJson.image || responseAsJson.detail) { - ({rawResponse, references } = handleImageResponse(responseAsJson, rawResponse)); - } else { - rawResponse = responseAsJson.response; - } - } catch (error) { - // If the chunk is not a JSON object, just display it as is - rawResponse += chunk; - } finally { - addMessageToChatBody(rawResponse, newResponseText, references); - } - } else { - // Handle streamed response of type text/event-stream or text/plain - const reader = response.body.getReader(); - const decoder = new TextDecoder(); - let references = {}; - - readStream(); - - function readStream() { - reader.read().then(({ done, value }) => { - if (done) { - // Append any references after all the data has been streamed - finalizeChatBodyResponse(references, newResponseText); - return; - } - - // Decode message chunk from stream - const chunk = decoder.decode(value, { stream: true }); - - if (chunk.includes("### compiled references:")) { - ({ rawResponse, references } = handleCompiledReferences(newResponseText, chunk, references, rawResponse)); - readStream(); - } else { - // If the chunk is not a JSON object, just display it as is - rawResponse += chunk; - handleStreamResponse(newResponseText, rawResponse, query, loadingEllipsis); - 
readStream(); - } - }); - - // Scroll to bottom of chat window as chat response is streamed - document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; - }; + // Setup chat message state + chatMessageState = { + newResponseTextEl, + newResponseEl, + loadingEllipsis, + references: {}, + rawResponse: "", + rawQuery: query, + isVoice: isVoice, } - }; + + // Call Khoj chat API + let chatApi = `/api/chat?q=${encodeURIComponent(query)}&conversation_id=${conversationID}&stream=true&client=web`; + chatApi += (!!region && !!city && !!countryName && !!timezone) + ? `&region=${region}&city=${city}&country=${countryName}&timezone=${timezone}` + : ''; + + const response = await fetch(chatApi); + + try { + if (!response.ok) throw new Error(response.statusText); + if (!response.body) throw new Error("Response body is empty"); + // Stream and render chat response + await readChatStream(response); + } catch (err) { + console.error(`Khoj chat response failed with\n${err}`); + if (chatMessageState.newResponseEl.getElementsByClassName("lds-ellipsis").length > 0 && chatMessageState.loadingEllipsis) + chatMessageState.newResponseTextEl.removeChild(chatMessageState.loadingEllipsis); + let errorMsg = "Sorry, unable to get response from Khoj backend ❤️‍🩹. Retry or contact developers for help at team@khoj.dev or on Discord"; + newResponseTextEl.innerHTML = errorMsg; + } + } function createLoadingEllipse() { // Temporary status message to indicate that Khoj is thinking @@ -743,32 +709,22 @@ To get started, just start typing below. 
You can also type / to see a list of co } function handleStreamResponse(newResponseElement, rawResponse, rawQuery, loadingEllipsis, replace=true) { - if (newResponseElement.getElementsByClassName("lds-ellipsis").length > 0 && loadingEllipsis) { + if (!newResponseElement) return; + // Remove loading ellipsis if it exists + if (newResponseElement.getElementsByClassName("lds-ellipsis").length > 0 && loadingEllipsis) newResponseElement.removeChild(loadingEllipsis); - } - if (replace) { - newResponseElement.innerHTML = ""; - } + // Clear the response element if replace is true + if (replace) newResponseElement.innerHTML = ""; + + // Append response to the response element newResponseElement.appendChild(formatHTMLMessage(rawResponse, false, replace, rawQuery)); + + // Append loading ellipsis if it exists + if (!replace && loadingEllipsis) newResponseElement.appendChild(loadingEllipsis); + // Scroll to bottom of chat view document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; } - function handleCompiledReferences(rawResponseElement, chunk, references, rawResponse) { - const additionalResponse = chunk.split("### compiled references:")[0]; - rawResponse += additionalResponse; - rawResponseElement.innerHTML = ""; - rawResponseElement.appendChild(formatHTMLMessage(rawResponse)); - - const rawReference = chunk.split("### compiled references:")[1]; - const rawReferenceAsJson = JSON.parse(rawReference); - if (rawReferenceAsJson instanceof Array) { - references["notes"] = rawReferenceAsJson; - } else if (typeof rawReferenceAsJson === "object" && rawReferenceAsJson !== null) { - references["online"] = rawReferenceAsJson; - } - return { rawResponse, references }; - } - function handleImageResponse(imageJson, rawResponse) { if (imageJson.image) { const inferredQuery = imageJson.inferredQueries?.[0] ?? "generated image"; @@ -785,35 +741,139 @@ To get started, just start typing below. 
You can also type / to see a list of co rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`; } } - let references = {}; - if (imageJson.context && imageJson.context.length > 0) { - const rawReferenceAsJson = imageJson.context; - if (rawReferenceAsJson instanceof Array) { - references["notes"] = rawReferenceAsJson; - } else if (typeof rawReferenceAsJson === "object" && rawReferenceAsJson !== null) { - references["online"] = rawReferenceAsJson; - } - } - if (imageJson.detail) { - // If response has detail field, response is an error message. - rawResponse += imageJson.detail; - } - return { rawResponse, references }; - } - function addMessageToChatBody(rawResponse, newResponseElement, references) { - newResponseElement.innerHTML = ""; - newResponseElement.appendChild(formatHTMLMessage(rawResponse)); + // If response has detail field, response is an error message. + if (imageJson.detail) rawResponse += imageJson.detail; - finalizeChatBodyResponse(references, newResponseElement); + return rawResponse; } function finalizeChatBodyResponse(references, newResponseElement) { - if (references != null && Object.keys(references).length > 0) { + if (!!newResponseElement && references != null && Object.keys(references).length > 0) { newResponseElement.appendChild(createReferenceSection(references)); } document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; - document.getElementById("chat-input").removeAttribute("disabled"); + document.getElementById("chat-input")?.removeAttribute("disabled"); + } + + function convertMessageChunkToJson(rawChunk) { + // Split the chunk into lines + console.debug("Raw Event:", rawChunk); + if (rawChunk?.startsWith("{") && rawChunk?.endsWith("}")) { + try { + let jsonChunk = JSON.parse(rawChunk); + if (!jsonChunk.type) + jsonChunk = {type: 'message', data: jsonChunk}; + return jsonChunk; + } catch (e) { + return {type: 'message', data: rawChunk}; + } + } else if (rawChunk.length > 0) { + return 
{type: 'message', data: rawChunk}; + } + } + + function processMessageChunk(rawChunk) { + const chunk = convertMessageChunkToJson(rawChunk); + console.debug("Json Event:", chunk); + if (!chunk || !chunk.type) return; + if (chunk.type ==='status') { + console.log(`status: ${chunk.data}`); + const statusMessage = chunk.data; + handleStreamResponse(chatMessageState.newResponseTextEl, statusMessage, chatMessageState.rawQuery, chatMessageState.loadingEllipsis, false); + } else if (chunk.type === 'start_llm_response') { + console.log("Started streaming", new Date()); + } else if (chunk.type === 'end_llm_response') { + console.log("Stopped streaming", new Date()); + + // Automatically respond with voice if the subscribed user has sent voice message + if (chatMessageState.isVoice && "{{ is_active }}" == "True") + textToSpeech(chatMessageState.rawResponse); + + // Append any references after all the data has been streamed + finalizeChatBodyResponse(chatMessageState.references, chatMessageState.newResponseTextEl); + + const liveQuery = chatMessageState.rawQuery; + // Reset variables + chatMessageState = { + newResponseTextEl: null, + newResponseEl: null, + loadingEllipsis: null, + references: {}, + rawResponse: "", + rawQuery: liveQuery, + isVoice: false, + } + } else if (chunk.type === "references") { + chatMessageState.references = {"notes": chunk.data.context, "online": chunk.data.onlineContext}; + } else if (chunk.type === 'message') { + const chunkData = chunk.data; + if (typeof chunkData === 'object' && chunkData !== null) { + // If chunkData is already a JSON object + handleJsonResponse(chunkData); + } else if (typeof chunkData === 'string' && chunkData.trim()?.startsWith("{") && chunkData.trim()?.endsWith("}")) { + // Try process chunk data as if it is a JSON object + try { + const jsonData = JSON.parse(chunkData.trim()); + handleJsonResponse(jsonData); + } catch (e) { + chatMessageState.rawResponse += chunkData; + 
handleStreamResponse(chatMessageState.newResponseTextEl, chatMessageState.rawResponse, chatMessageState.rawQuery, chatMessageState.loadingEllipsis); + } + } else { + chatMessageState.rawResponse += chunkData; + handleStreamResponse(chatMessageState.newResponseTextEl, chatMessageState.rawResponse, chatMessageState.rawQuery, chatMessageState.loadingEllipsis); + } + } + } + + function handleJsonResponse(jsonData) { + if (jsonData.image || jsonData.detail) { + chatMessageState.rawResponse = handleImageResponse(jsonData, chatMessageState.rawResponse); + } else if (jsonData.response) { + chatMessageState.rawResponse = jsonData.response; + } + + if (chatMessageState.newResponseTextEl) { + chatMessageState.newResponseTextEl.innerHTML = ""; + chatMessageState.newResponseTextEl.appendChild(formatHTMLMessage(chatMessageState.rawResponse)); + } + } + + async function readChatStream(response) { + if (!response.body) return; + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + const eventDelimiter = '␃🔚␗'; + let buffer = ''; + + while (true) { + const { value, done } = await reader.read(); + // If the stream is done + if (done) { + // Process the last chunk + processMessageChunk(buffer); + buffer = ''; + break; + } + + // Read chunk from stream and append it to the buffer + const chunk = decoder.decode(value, { stream: true }); + console.debug("Raw Chunk:", chunk) + // Start buffering chunks until complete event is received + buffer += chunk; + + // Once the buffer contains a complete event + let newEventIndex; + while ((newEventIndex = buffer.indexOf(eventDelimiter)) !== -1) { + // Extract the event from the buffer + const event = buffer.slice(0, newEventIndex); + buffer = buffer.slice(newEventIndex + eventDelimiter.length); + + // Process the event + if (event) processMessageChunk(event); + } + } } function incrementalChat(event) { @@ -1069,17 +1129,13 @@ To get started, just start typing below. 
You can also type / to see a list of co window.onload = loadChat; - function setupWebSocket(isVoice=false) { - let chatBody = document.getElementById("chat-body"); - let wsProtocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:'; - let webSocketUrl = `${wsProtocol}//${window.location.host}/api/chat/ws`; - + function initMessageState(isVoice=false) { if (waitingForLocation) { console.debug("Waiting for location data to be fetched. Will setup WebSocket once location data is available."); return; } - websocketState = { + chatMessageState = { newResponseTextEl: null, newResponseEl: null, loadingEllipsis: null, @@ -1088,174 +1144,8 @@ To get started, just start typing below. You can also type / to see a list of co rawQuery: "", isVoice: isVoice, } - - if (chatBody.dataset.conversationId) { - webSocketUrl += `?conversation_id=${chatBody.dataset.conversationId}`; - webSocketUrl += (!!region && !!city && !!countryName) && !!timezone ? `&region=${region}&city=${city}&country=${countryName}&timezone=${timezone}` : ''; - - websocket = new WebSocket(webSocketUrl); - websocket.onmessage = function(event) { - - // Get the last element in the chat-body - let chunk = event.data; - if (chunk == "start_llm_response") { - console.log("Started streaming", new Date()); - } else if (chunk == "end_llm_response") { - console.log("Stopped streaming", new Date()); - - // Automatically respond with voice if the subscribed user has sent voice message - if (websocketState.isVoice && "{{ is_active }}" == "True") - textToSpeech(websocketState.rawResponse); - - // Append any references after all the data has been streamed - finalizeChatBodyResponse(websocketState.references, websocketState.newResponseTextEl); - - const liveQuery = websocketState.rawQuery; - // Reset variables - websocketState = { - newResponseTextEl: null, - newResponseEl: null, - loadingEllipsis: null, - references: {}, - rawResponse: "", - rawQuery: liveQuery, - isVoice: false, - } - } else { - try { - if 
(chunk.includes("application/json")) - { - chunk = JSON.parse(chunk); - } - } catch (error) { - // If the chunk is not a JSON object, continue. - } - - const contentType = chunk["content-type"] - - if (contentType === "application/json") { - // Handle JSON response - try { - if (chunk.image || chunk.detail) { - ({rawResponse, references } = handleImageResponse(chunk, websocketState.rawResponse)); - websocketState.rawResponse = rawResponse; - websocketState.references = references; - } else if (chunk.type == "status") { - handleStreamResponse(websocketState.newResponseTextEl, chunk.message, websocketState.rawQuery, null, false); - } else if (chunk.type == "rate_limit") { - handleStreamResponse(websocketState.newResponseTextEl, chunk.message, websocketState.rawQuery, websocketState.loadingEllipsis, true); - } else { - rawResponse = chunk.response; - } - } catch (error) { - // If the chunk is not a JSON object, just display it as is - websocketState.rawResponse += chunk; - } finally { - if (chunk.type != "status" && chunk.type != "rate_limit") { - addMessageToChatBody(websocketState.rawResponse, websocketState.newResponseTextEl, websocketState.references); - } - } - } else { - - // Handle streamed response of type text/event-stream or text/plain - if (chunk && chunk.includes("### compiled references:")) { - ({ rawResponse, references } = handleCompiledReferences(websocketState.newResponseTextEl, chunk, websocketState.references, websocketState.rawResponse)); - websocketState.rawResponse = rawResponse; - websocketState.references = references; - } else { - // If the chunk is not a JSON object, just display it as is - websocketState.rawResponse += chunk; - if (websocketState.newResponseTextEl) { - handleStreamResponse(websocketState.newResponseTextEl, websocketState.rawResponse, websocketState.rawQuery, websocketState.loadingEllipsis); - } - } - - // Scroll to bottom of chat window as chat response is streamed - document.getElementById("chat-body").scrollTop = 
document.getElementById("chat-body").scrollHeight; - }; - } - } - }; - websocket.onclose = function(event) { - websocket = null; - console.log("WebSocket is closed now."); - let setupWebSocketButton = document.createElement("button"); - setupWebSocketButton.textContent = "Reconnect to Server"; - setupWebSocketButton.onclick = setupWebSocket; - let statusDotIcon = document.getElementById("connection-status-icon"); - statusDotIcon.style.backgroundColor = "red"; - let statusDotText = document.getElementById("connection-status-text"); - statusDotText.innerHTML = ""; - statusDotText.style.marginTop = "5px"; - statusDotText.appendChild(setupWebSocketButton); - } - websocket.onerror = function(event) { - console.log("WebSocket error observed:", event); - } - - websocket.onopen = function(event) { - console.log("WebSocket is open now.") - let statusDotIcon = document.getElementById("connection-status-icon"); - statusDotIcon.style.backgroundColor = "green"; - let statusDotText = document.getElementById("connection-status-text"); - statusDotText.textContent = "Connected to Server"; - } } - function sendMessageViaWebSocket(isVoice=false) { - let chatBody = document.getElementById("chat-body"); - - var query = document.getElementById("chat-input").value.trim(); - console.log(`Query: ${query}`); - - if (userMessages.length >= 10) { - userMessages.shift(); - } - userMessages.push(query); - resetUserMessageIndex(); - - // Add message by user to chat body - renderMessage(query, "you"); - document.getElementById("chat-input").value = ""; - autoResize(); - document.getElementById("chat-input").setAttribute("disabled", "disabled"); - - let newResponseEl = document.createElement("div"); - newResponseEl.classList.add("chat-message", "khoj"); - newResponseEl.attributes["data-meta"] = "🏮 Khoj at " + formatDate(new Date()); - chatBody.appendChild(newResponseEl); - - let newResponseTextEl = document.createElement("div"); - newResponseTextEl.classList.add("chat-message-text", "khoj"); - 
newResponseEl.appendChild(newResponseTextEl); - - // Temporary status message to indicate that Khoj is thinking - let loadingEllipsis = createLoadingEllipse(); - - newResponseTextEl.appendChild(loadingEllipsis); - document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; - - let chatTooltip = document.getElementById("chat-tooltip"); - chatTooltip.style.display = "none"; - - let chatInput = document.getElementById("chat-input"); - chatInput.classList.remove("option-enabled"); - - // Call specified Khoj API - websocket.send(query); - let rawResponse = ""; - let references = {}; - - websocketState = { - newResponseTextEl, - newResponseEl, - loadingEllipsis, - references, - rawResponse, - rawQuery: query, - isVoice: isVoice, - } - } var userMessages = []; var userMessageIndex = -1; function loadChat() { @@ -1265,7 +1155,7 @@ To get started, just start typing below. You can also type / to see a list of co let chatHistoryUrl = `/api/chat/history?client=web`; if (chatBody.dataset.conversationId) { chatHistoryUrl += `&conversation_id=${chatBody.dataset.conversationId}`; - setupWebSocket(); + initMessageState(); loadFileFiltersFromConversation(); } @@ -1305,7 +1195,7 @@ To get started, just start typing below. 
You can also type / to see a list of co let chatBody = document.getElementById("chat-body"); chatBody.dataset.conversationId = response.conversation_id; loadFileFiltersFromConversation(); - setupWebSocket(); + initMessageState(); chatBody.dataset.conversationTitle = response.slug || `New conversation 🌱`; let agentMetadata = response.agent; diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 797066d7..ea7368e6 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -62,10 +62,6 @@ class ThreadedGenerator: self.queue.put(data) def close(self): - if self.compiled_references and len(self.compiled_references) > 0: - self.queue.put(f"### compiled references:{json.dumps(self.compiled_references)}") - if self.online_results and len(self.online_results) > 0: - self.queue.put(f"### compiled references:{json.dumps(self.online_results)}") self.queue.put(StopIteration) diff --git a/src/khoj/processor/tools/online_search.py b/src/khoj/processor/tools/online_search.py index 72191077..c087de70 100644 --- a/src/khoj/processor/tools/online_search.py +++ b/src/khoj/processor/tools/online_search.py @@ -11,6 +11,7 @@ from bs4 import BeautifulSoup from markdownify import markdownify from khoj.routers.helpers import ( + ChatEvent, extract_relevant_info, generate_online_subqueries, infer_webpage_urls, @@ -56,7 +57,8 @@ async def search_online( query += " ".join(custom_filters) if not is_internet_connected(): logger.warn("Cannot search online as not connected to internet") - return {} + yield {} + return # Breakdown the query into subqueries to get the correct answer subqueries = await generate_online_subqueries(query, conversation_history, location) @@ -66,7 +68,8 @@ async def search_online( logger.info(f"🌐 Searching the Internet for {list(subqueries)}") if send_status_func: subqueries_str = "\n- " + "\n- ".join(list(subqueries)) - await send_status_func(f"**🌐 Searching the Internet for**: 
{subqueries_str}") + async for event in send_status_func(f"**🌐 Searching the Internet for**: {subqueries_str}"): + yield {ChatEvent.STATUS: event} with timer(f"Internet searches for {list(subqueries)} took", logger): search_func = search_with_google if SERPER_DEV_API_KEY else search_with_jina @@ -89,7 +92,8 @@ async def search_online( logger.info(f"🌐👀 Reading web pages at: {list(webpage_links)}") if send_status_func: webpage_links_str = "\n- " + "\n- ".join(list(webpage_links)) - await send_status_func(f"**📖 Reading web pages**: {webpage_links_str}") + async for event in send_status_func(f"**📖 Reading web pages**: {webpage_links_str}"): + yield {ChatEvent.STATUS: event} tasks = [read_webpage_and_extract_content(subquery, link, content) for link, subquery, content in webpages] results = await asyncio.gather(*tasks) @@ -98,7 +102,7 @@ async def search_online( if webpage_extract is not None: response_dict[subquery]["webpages"] = {"link": url, "snippet": webpage_extract} - return response_dict + yield response_dict async def search_with_google(query: str) -> Tuple[str, Dict[str, List[Dict]]]: @@ -127,13 +131,15 @@ async def read_webpages( "Infer web pages to read from the query and extract relevant information from them" logger.info(f"Inferring web pages to read") if send_status_func: - await send_status_func(f"**🧐 Inferring web pages to read**") + async for event in send_status_func(f"**🧐 Inferring web pages to read**"): + yield {ChatEvent.STATUS: event} urls = await infer_webpage_urls(query, conversation_history, location) logger.info(f"Reading web pages at: {urls}") if send_status_func: webpage_links_str = "\n- " + "\n- ".join(list(urls)) - await send_status_func(f"**📖 Reading web pages**: {webpage_links_str}") + async for event in send_status_func(f"**📖 Reading web pages**: {webpage_links_str}"): + yield {ChatEvent.STATUS: event} tasks = [read_webpage_and_extract_content(query, url) for url in urls] results = await asyncio.gather(*tasks) @@ -141,7 +147,7 @@ async 
def read_webpages( response[query]["webpages"] = [ {"query": q, "link": url, "snippet": web_extract} for q, web_extract, url in results if web_extract is not None ] - return response + yield response async def read_webpage_and_extract_content( diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index c9d76ae7..15d7cbc7 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -6,7 +6,6 @@ import os import threading import time import uuid -from random import random from typing import Any, Callable, List, Optional, Union import cron_descriptor @@ -37,6 +36,7 @@ from khoj.processor.conversation.openai.gpt import extract_questions from khoj.processor.conversation.openai.whisper import transcribe_audio from khoj.routers.helpers import ( ApiUserRateLimiter, + ChatEvent, CommonQueryParams, ConversationCommandRateLimiter, acreate_title_from_query, @@ -298,11 +298,13 @@ async def extract_references_and_questions( not ConversationCommand.Notes in conversation_commands and not ConversationCommand.Default in conversation_commands ): - return compiled_references, inferred_queries, q + yield compiled_references, inferred_queries, q + return if not await sync_to_async(EntryAdapters.user_has_entries)(user=user): logger.debug("No documents in knowledge base. 
Use a Khoj client to sync and chat with your docs.") - return compiled_references, inferred_queries, q + yield compiled_references, inferred_queries, q + return # Extract filter terms from user message defiltered_query = q @@ -313,7 +315,8 @@ async def extract_references_and_questions( if not conversation: logger.error(f"Conversation with id {conversation_id} not found.") - return compiled_references, inferred_queries, defiltered_query + yield compiled_references, inferred_queries, defiltered_query + return filters_in_query += " ".join([f'file:"{filter}"' for filter in conversation.file_filters]) using_offline_chat = False @@ -373,7 +376,8 @@ async def extract_references_and_questions( logger.info(f"🔍 Searching knowledge base with queries: {inferred_queries}") if send_status_func: inferred_queries_str = "\n- " + "\n- ".join(inferred_queries) - await send_status_func(f"**🔍 Searching Documents for:** {inferred_queries_str}") + async for event in send_status_func(f"**🔍 Searching Documents for:** {inferred_queries_str}"): + yield {ChatEvent.STATUS: event} for query in inferred_queries: n_items = min(n, 3) if using_offline_chat else n search_results.extend( @@ -392,7 +396,7 @@ async def extract_references_and_questions( {"compiled": item.additional["compiled"], "file": item.additional["file"]} for item in search_results ] - return compiled_references, inferred_queries, defiltered_query + yield compiled_references, inferred_queries, defiltered_query @api.get("/health", response_class=Response) diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index be28622b..63529b8e 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -1,17 +1,17 @@ +import asyncio import json import logging -import math +import time from datetime import datetime +from functools import partial from typing import Any, Dict, List, Optional from urllib.parse import unquote from asgiref.sync import sync_to_async -from fastapi import APIRouter, Depends, 
HTTPException, Request, WebSocket +from fastapi import APIRouter, Depends, HTTPException, Request from fastapi.requests import Request from fastapi.responses import Response, StreamingResponse from starlette.authentication import requires -from starlette.websockets import WebSocketDisconnect -from websockets import ConnectionClosedOK from khoj.app.settings import ALLOWED_HOSTS from khoj.database.adapters import ( @@ -23,19 +23,15 @@ from khoj.database.adapters import ( aget_user_name, ) from khoj.database.models import KhojUser -from khoj.processor.conversation.prompts import ( - help_message, - no_entries_found, - no_notes_found, -) +from khoj.processor.conversation.prompts import help_message, no_entries_found from khoj.processor.conversation.utils import save_to_conversation_log from khoj.processor.speech.text_to_speech import generate_text_to_speech from khoj.processor.tools.online_search import read_webpages, search_online from khoj.routers.api import extract_references_and_questions from khoj.routers.helpers import ( ApiUserRateLimiter, + ChatEvent, CommonQueryParams, - CommonQueryParamsClass, ConversationCommandRateLimiter, agenerate_chat_response, aget_relevant_information_sources, @@ -526,141 +522,142 @@ async def set_conversation_title( ) -@api_chat.websocket("/ws") -async def websocket_endpoint( - websocket: WebSocket, - conversation_id: int, +@api_chat.get("") +async def chat( + request: Request, + common: CommonQueryParams, + q: str, + n: int = 7, + d: float = 0.18, + stream: Optional[bool] = False, + title: Optional[str] = None, + conversation_id: Optional[int] = None, city: Optional[str] = None, region: Optional[str] = None, country: Optional[str] = None, timezone: Optional[str] = None, + rate_limiter_per_minute=Depends( + ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute") + ), + rate_limiter_per_day=Depends( + ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day") + ), ): - 
connection_alive = True + async def event_generator(q: str): + start_time = time.perf_counter() + ttft = None + chat_metadata: dict = {} + connection_alive = True + user: KhojUser = request.user.object + event_delimiter = "␃🔚␗" + q = unquote(q) - async def send_status_update(message: str): - nonlocal connection_alive - if not connection_alive: + async def send_event(event_type: ChatEvent, data: str | dict): + nonlocal connection_alive, ttft + if not connection_alive or await request.is_disconnected(): + connection_alive = False + logger.warn(f"User {user} disconnected from {common.client} client") + return + try: + if event_type == ChatEvent.END_LLM_RESPONSE: + collect_telemetry() + if event_type == ChatEvent.START_LLM_RESPONSE: + ttft = time.perf_counter() - start_time + if event_type == ChatEvent.MESSAGE: + yield data + elif event_type == ChatEvent.REFERENCES or stream: + yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False) + except asyncio.CancelledError as e: + connection_alive = False + logger.warn(f"User {user} disconnected from {common.client} client: {e}") + return + except Exception as e: + connection_alive = False + logger.error(f"Failed to stream chat API response to {user} on {common.client}: {e}", exc_info=True) + return + finally: + if stream: + yield event_delimiter + + async def send_llm_response(response: str): + async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""): + yield result + async for result in send_event(ChatEvent.MESSAGE, response): + yield result + async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""): + yield result + + def collect_telemetry(): + # Gather chat response telemetry + nonlocal chat_metadata + latency = time.perf_counter() - start_time + cmd_set = set([cmd.value for cmd in conversation_commands]) + chat_metadata = chat_metadata or {} + chat_metadata["conversation_command"] = cmd_set + chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None + 
chat_metadata["latency"] = f"{latency:.3f}" + chat_metadata["ttft_latency"] = f"{ttft:.3f}" + + logger.info(f"Chat response time to first token: {ttft:.3f} seconds") + logger.info(f"Chat response total time: {latency:.3f} seconds") + update_telemetry_state( + request=request, + telemetry_type="api", + api="chat", + client=request.user.client_app, + user_agent=request.headers.get("user-agent"), + host=request.headers.get("host"), + metadata=chat_metadata, + ) + + conversation = await ConversationAdapters.aget_conversation_by_user( + user, client_application=request.user.client_app, conversation_id=conversation_id, title=title + ) + if not conversation: + async for result in send_llm_response(f"Conversation {conversation_id} not found"): + yield result return - status_packet = { - "type": "status", - "message": message, - "content-type": "application/json", - } - try: - await websocket.send_text(json.dumps(status_packet)) - except ConnectionClosedOK: - connection_alive = False - logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread") + await is_ready_to_chat(user) - async def send_complete_llm_response(llm_response: str): - nonlocal connection_alive - if not connection_alive: - return - try: - await websocket.send_text("start_llm_response") - await websocket.send_text(llm_response) - await websocket.send_text("end_llm_response") - except ConnectionClosedOK: - connection_alive = False - logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread") - - async def send_message(message: str): - nonlocal connection_alive - if not connection_alive: - return - try: - await websocket.send_text(message) - except ConnectionClosedOK: - connection_alive = False - logger.info(f"User {user} disconnected web socket. 
Emitting rest of responses to clear thread") - - async def send_rate_limit_message(message: str): - nonlocal connection_alive - if not connection_alive: - return - - status_packet = { - "type": "rate_limit", - "message": message, - "content-type": "application/json", - } - try: - await websocket.send_text(json.dumps(status_packet)) - except ConnectionClosedOK: - connection_alive = False - logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread") - - user: KhojUser = websocket.user.object - conversation = await ConversationAdapters.aget_conversation_by_user( - user, client_application=websocket.user.client_app, conversation_id=conversation_id - ) - - hourly_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute") - - daily_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day") - - await is_ready_to_chat(user) - - user_name = await aget_user_name(user) - - location = None - - if city or region or country: - location = LocationData(city=city, region=region, country=country) - - await websocket.accept() - while connection_alive: - try: - if conversation: - await sync_to_async(conversation.refresh_from_db)(fields=["conversation_log"]) - q = await websocket.receive_text() - - # Refresh these because the connection to the database might have been closed - await conversation.arefresh_from_db() - - except WebSocketDisconnect: - logger.debug(f"User {user} disconnected web socket") - break - - try: - await sync_to_async(hourly_limiter)(websocket) - await sync_to_async(daily_limiter)(websocket) - except HTTPException as e: - await send_rate_limit_message(e.detail) - break + user_name = await aget_user_name(user) + location = None + if city or region or country: + location = LocationData(city=city, region=region, country=country) if is_query_empty(q): - await send_message("start_llm_response") - await send_message( - "It seems like your query is 
incomplete. Could you please provide more details or specify what you need help with?" - ) - await send_message("end_llm_response") - continue + async for result in send_llm_response("Please ask your query to get started."): + yield result + return user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") conversation_commands = [get_conversation_command(query=q, any_references=True)] - await send_status_update(f"**👀 Understanding Query**: {q}") + async for result in send_event(ChatEvent.STATUS, f"**👀 Understanding Query**: {q}"): + yield result meta_log = conversation.conversation_log is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask] - used_slash_summarize = conversation_commands == [ConversationCommand.Summarize] if conversation_commands == [ConversationCommand.Default] or is_automated_task: conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task) conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands]) - await send_status_update(f"**🗃️ Chose Data Sources to Search:** {conversation_commands_str}") + async for result in send_event( + ChatEvent.STATUS, f"**🗃️ Chose Data Sources to Search:** {conversation_commands_str}" + ): + yield result mode = await aget_relevant_output_modes(q, meta_log, is_automated_task) - await send_status_update(f"**🧑🏾‍💻 Decided Response Mode:** {mode.value}") + async for result in send_event(ChatEvent.STATUS, f"**🧑🏾‍💻 Decided Response Mode:** {mode.value}"): + yield result if mode not in conversation_commands: conversation_commands.append(mode) for cmd in conversation_commands: - await conversation_command_rate_limiter.update_and_check_if_valid(websocket, cmd) + await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd) q = q.replace(f"/{cmd.value}", "").strip() + used_slash_summarize = conversation_commands == [ConversationCommand.Summarize] file_filters = conversation.file_filters if conversation else [] # Skip 
trying to summarize if if ( @@ -676,28 +673,37 @@ async def websocket_endpoint( response_log = "" if len(file_filters) == 0: response_log = "No files selected for summarization. Please add files using the section on the left." - await send_complete_llm_response(response_log) + async for result in send_llm_response(response_log): + yield result elif len(file_filters) > 1: response_log = "Only one file can be selected for summarization." - await send_complete_llm_response(response_log) + async for result in send_llm_response(response_log): + yield result else: try: file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0]) if len(file_object) == 0: response_log = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again." - await send_complete_llm_response(response_log) - continue + async for result in send_llm_response(response_log): + yield result + return contextual_data = " ".join([file.raw_text for file in file_object]) if not q: q = "Create a general summary of the file" - await send_status_update(f"**🧑🏾‍💻 Constructing Summary Using:** {file_object[0].file_name}") + async for result in send_event( + ChatEvent.STATUS, f"**🧑🏾‍💻 Constructing Summary Using:** {file_object[0].file_name}" + ): + yield result + response = await extract_relevant_summary(q, contextual_data) response_log = str(response) - await send_complete_llm_response(response_log) + async for result in send_llm_response(response_log): + yield result except Exception as e: response_log = "Error summarizing file." 
logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True) - await send_complete_llm_response(response_log) + async for result in send_llm_response(response_log): + yield result await sync_to_async(save_to_conversation_log)( q, response_log, @@ -705,16 +711,10 @@ async def websocket_endpoint( meta_log, user_message_time, intent_type="summarize", - client_application=websocket.user.client_app, + client_application=request.user.client_app, conversation_id=conversation_id, ) - update_telemetry_state( - request=websocket, - telemetry_type="api", - api="chat", - metadata={"conversation_command": conversation_commands[0].value}, - ) - continue + return custom_filters = [] if conversation_commands == [ConversationCommand.Help]: @@ -724,8 +724,9 @@ async def websocket_endpoint( conversation_config = await ConversationAdapters.aget_default_conversation_config() model_type = conversation_config.model_type formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device()) - await send_complete_llm_response(formatted_help) - continue + async for result in send_llm_response(formatted_help): + yield result + return # Adding specification to search online specifically on khoj.dev pages. custom_filters.append("site:khoj.dev") conversation_commands.append(ConversationCommand.Online) @@ -733,14 +734,14 @@ async def websocket_endpoint( if ConversationCommand.Automation in conversation_commands: try: automation, crontime, query_to_run, subject = await create_automation( - q, timezone, user, websocket.url, meta_log + q, timezone, user, request.url, meta_log ) except Exception as e: logger.error(f"Error scheduling task {q} for {user.email}: {e}") - await send_complete_llm_response( - f"Unable to create automation. Ensure the automation doesn't already exist." - ) - continue + error_message = f"Unable to create automation. Ensure the automation doesn't already exist." 
+ async for result in send_llm_response(error_message): + yield result + return llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject) await sync_to_async(save_to_conversation_log)( @@ -750,57 +751,78 @@ async def websocket_endpoint( meta_log, user_message_time, intent_type="automation", - client_application=websocket.user.client_app, + client_application=request.user.client_app, conversation_id=conversation_id, inferred_queries=[query_to_run], automation_id=automation.id, ) - common = CommonQueryParamsClass( - client=websocket.user.client_app, - user_agent=websocket.headers.get("user-agent"), - host=websocket.headers.get("host"), - ) - update_telemetry_state( - request=websocket, - telemetry_type="api", - api="chat", - **common.__dict__, - ) - await send_complete_llm_response(llm_response) - continue + async for result in send_llm_response(llm_response): + yield result + return - compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions( - websocket, meta_log, q, 7, 0.18, conversation_id, conversation_commands, location, send_status_update - ) + # Gather Context + ## Extract Document References + compiled_references, inferred_queries, defiltered_query = [], [], None + async for result in extract_references_and_questions( + request, + meta_log, + q, + (n or 7), + (d or 0.18), + conversation_id, + conversation_commands, + location, + partial(send_event, ChatEvent.STATUS), + ): + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] + else: + compiled_references.extend(result[0]) + inferred_queries.extend(result[1]) + defiltered_query = result[2] - if compiled_references: + if not is_none_or_empty(compiled_references): headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references])) - await send_status_update(f"**📜 Found Relevant Notes**: {headings}") + async for result in send_event(ChatEvent.STATUS, f"**📜 
Found Relevant Notes**: {headings}"): + yield result online_results: Dict = dict() if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user): - await send_complete_llm_response(f"{no_entries_found.format()}") - continue + async for result in send_llm_response(f"{no_entries_found.format()}"): + yield result + return if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references): conversation_commands.remove(ConversationCommand.Notes) + ## Gather Online References if ConversationCommand.Online in conversation_commands: try: - online_results = await search_online( - defiltered_query, meta_log, location, send_status_update, custom_filters - ) + async for result in search_online( + defiltered_query, meta_log, location, partial(send_event, ChatEvent.STATUS), custom_filters + ): + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] + else: + online_results = result except ValueError as e: - logger.warning(f"Error searching online: {e}. Attempting to respond without online results") - await send_complete_llm_response( - f"Error searching online: {e}. Attempting to respond without online results" - ) - continue + error_message = f"Error searching online: {e}. 
Attempting to respond without online results" + logger.warning(error_message) + async for result in send_llm_response(error_message): + yield result + return + ## Gather Webpage References if ConversationCommand.Webpage in conversation_commands: try: - direct_web_pages = await read_webpages(defiltered_query, meta_log, location, send_status_update) + async for result in read_webpages( + defiltered_query, meta_log, location, partial(send_event, ChatEvent.STATUS) + ): + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] + else: + direct_web_pages = result webpages = [] for query in direct_web_pages: if online_results.get(query): @@ -810,38 +832,52 @@ async def websocket_endpoint( for webpage in direct_web_pages[query]["webpages"]: webpages.append(webpage["link"]) - - await send_status_update(f"**📚 Read web pages**: {webpages}") + async for result in send_event(ChatEvent.STATUS, f"**📚 Read web pages**: {webpages}"): + yield result except ValueError as e: logger.warning( - f"Error directly reading webpages: {e}. Attempting to respond without online results", exc_info=True + f"Error directly reading webpages: {e}. 
Attempting to respond without online results", + exc_info=True, ) + ## Send Gathered References + async for result in send_event( + ChatEvent.REFERENCES, + { + "inferredQueries": inferred_queries, + "context": compiled_references, + "onlineContext": online_results, + }, + ): + yield result + + # Generate Output + ## Generate Image Output if ConversationCommand.Image in conversation_commands: - update_telemetry_state( - request=websocket, - telemetry_type="api", - api="chat", - metadata={"conversation_command": conversation_commands[0].value}, - ) - image, status_code, improved_image_prompt, intent_type = await text_to_image( + async for result in text_to_image( q, user, meta_log, location_data=location, references=compiled_references, online_results=online_results, - send_status_func=send_status_update, - ) + send_status_func=partial(send_event, ChatEvent.STATUS), + ): + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] + else: + image, status_code, improved_image_prompt, intent_type = result + if image is None or status_code != 200: content_obj = { - "image": image, + "content-type": "application/json", "intentType": intent_type, "detail": improved_image_prompt, - "content-type": "application/json", + "image": image, } - await send_complete_llm_response(json.dumps(content_obj)) - continue + async for result in send_llm_response(json.dumps(content_obj)): + yield result + return await sync_to_async(save_to_conversation_log)( q, @@ -851,17 +887,23 @@ async def websocket_endpoint( user_message_time, intent_type=intent_type, inferred_queries=[improved_image_prompt], - client_application=websocket.user.client_app, + client_application=request.user.client_app, conversation_id=conversation_id, compiled_references=compiled_references, online_results=online_results, ) - content_obj = {"image": image, "intentType": intent_type, "inferredQueries": [improved_image_prompt], "context": compiled_references, "content-type": 
"application/json", "online_results": online_results} # type: ignore + content_obj = { + "intentType": intent_type, + "inferredQueries": [improved_image_prompt], + "image": image, + } + async for result in send_llm_response(json.dumps(content_obj)): + yield result + return - await send_complete_llm_response(json.dumps(content_obj)) - continue - - await send_status_update(f"**💭 Generating a well-informed response**") + ## Generate Text Output + async for result in send_event(ChatEvent.STATUS, f"**💭 Generating a well-informed response**"): + yield result llm_response, chat_metadata = await agenerate_chat_response( defiltered_query, meta_log, @@ -871,310 +913,49 @@ async def websocket_endpoint( inferred_queries, conversation_commands, user, - websocket.user.client_app, + request.user.client_app, conversation_id, location, user_name, ) - chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None + # Send Response + async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""): + yield result - update_telemetry_state( - request=websocket, - telemetry_type="api", - api="chat", - metadata=chat_metadata, - ) + continue_stream = True iterator = AsyncIteratorWrapper(llm_response) - - await send_message("start_llm_response") - async for item in iterator: if item is None: - break - if connection_alive: - try: - await send_message(f"{item}") - except ConnectionClosedOK: - connection_alive = False - logger.info(f"User {user} disconnected web socket. 
Emitting rest of responses to clear thread") - - await send_message("end_llm_response") - - -@api_chat.get("", response_class=Response) -@requires(["authenticated"]) -async def chat( - request: Request, - common: CommonQueryParams, - q: str, - n: Optional[int] = 5, - d: Optional[float] = 0.22, - stream: Optional[bool] = False, - title: Optional[str] = None, - conversation_id: Optional[int] = None, - city: Optional[str] = None, - region: Optional[str] = None, - country: Optional[str] = None, - timezone: Optional[str] = None, - rate_limiter_per_minute=Depends( - ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute") - ), - rate_limiter_per_day=Depends( - ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day") - ), -) -> Response: - user: KhojUser = request.user.object - q = unquote(q) - if is_query_empty(q): - return Response( - content="It seems like your query is incomplete. Could you please provide more details or specify what you need help with?", - media_type="text/plain", - status_code=400, - ) - user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - logger.info(f"Chat request by {user.username}: {q}") - - await is_ready_to_chat(user) - conversation_commands = [get_conversation_command(query=q, any_references=True)] - - _custom_filters = [] - if conversation_commands == [ConversationCommand.Help]: - help_str = "/" + ConversationCommand.Help - if q.strip() == help_str: - conversation_config = await ConversationAdapters.aget_user_conversation_config(user) - if conversation_config == None: - conversation_config = await ConversationAdapters.aget_default_conversation_config() - model_type = conversation_config.model_type - formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device()) - return StreamingResponse(iter([formatted_help]), media_type="text/event-stream", status_code=200) - # Adding specification to search online specifically on 
khoj.dev pages. - _custom_filters.append("site:khoj.dev") - conversation_commands.append(ConversationCommand.Online) - - conversation = await ConversationAdapters.aget_conversation_by_user( - user, request.user.client_app, conversation_id, title - ) - conversation_id = conversation.id if conversation else None - - if not conversation: - return Response( - content=f"No conversation found with requested id, title", media_type="text/plain", status_code=400 - ) - else: - meta_log = conversation.conversation_log - - if ConversationCommand.Summarize in conversation_commands: - file_filters = conversation.file_filters - llm_response = "" - if len(file_filters) == 0: - llm_response = "No files selected for summarization. Please add files using the section on the left." - elif len(file_filters) > 1: - llm_response = "Only one file can be selected for summarization." - else: + async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""): + yield result + logger.debug("Finished streaming response") + return + if not connection_alive or not continue_stream: + continue try: - file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0]) - if len(file_object) == 0: - llm_response = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again." - return StreamingResponse(content=llm_response, media_type="text/event-stream", status_code=200) - contextual_data = " ".join([file.raw_text for file in file_object]) - summarizeStr = "/" + ConversationCommand.Summarize - if q.strip() == summarizeStr: - q = "Create a general summary of the file" - response = await extract_relevant_summary(q, contextual_data) - llm_response = str(response) + async for result in send_event(ChatEvent.MESSAGE, f"{item}"): + yield result except Exception as e: - logger.error(f"Error summarizing file for {user.email}: {e}") - llm_response = "Error summarizing file." 
- await sync_to_async(save_to_conversation_log)( - q, - llm_response, - user, - conversation.conversation_log, - user_message_time, - intent_type="summarize", - client_application=request.user.client_app, - conversation_id=conversation_id, - ) - update_telemetry_state( - request=request, - telemetry_type="api", - api="chat", - metadata={"conversation_command": conversation_commands[0].value}, - **common.__dict__, - ) - return StreamingResponse(content=llm_response, media_type="text/event-stream", status_code=200) - - is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask] - - if conversation_commands == [ConversationCommand.Default] or is_automated_task: - conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task) - mode = await aget_relevant_output_modes(q, meta_log, is_automated_task) - if mode not in conversation_commands: - conversation_commands.append(mode) - - for cmd in conversation_commands: - await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd) - q = q.replace(f"/{cmd.value}", "").strip() - - location = None - - if city or region or country: - location = LocationData(city=city, region=region, country=country) - - user_name = await aget_user_name(user) - - if ConversationCommand.Automation in conversation_commands: - try: - automation, crontime, query_to_run, subject = await create_automation( - q, timezone, user, request.url, meta_log - ) - except Exception as e: - logger.error(f"Error creating automation {q} for {user.email}: {e}", exc_info=True) - return Response( - content=f"Unable to create automation. 
Ensure the automation doesn't already exist.", - media_type="text/plain", - status_code=500, - ) - - llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject) - await sync_to_async(save_to_conversation_log)( - q, - llm_response, - user, - meta_log, - user_message_time, - intent_type="automation", - client_application=request.user.client_app, - conversation_id=conversation_id, - inferred_queries=[query_to_run], - automation_id=automation.id, - ) - - if stream: - return StreamingResponse(llm_response, media_type="text/event-stream", status_code=200) - else: - return Response(content=llm_response, media_type="text/plain", status_code=200) - - compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions( - request, meta_log, q, (n or 5), (d or math.inf), conversation_id, conversation_commands, location - ) - online_results: Dict[str, Dict] = {} - - if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user): - no_entries_found_format = no_entries_found.format() - if stream: - return StreamingResponse(iter([no_entries_found_format]), media_type="text/event-stream", status_code=200) - else: - response_obj = {"response": no_entries_found_format} - return Response(content=json.dumps(response_obj), media_type="text/plain", status_code=200) - - if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references): - no_notes_found_format = no_notes_found.format() - if stream: - return StreamingResponse(iter([no_notes_found_format]), media_type="text/event-stream", status_code=200) - else: - response_obj = {"response": no_notes_found_format} - return Response(content=json.dumps(response_obj), media_type="text/plain", status_code=200) - - if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references): - conversation_commands.remove(ConversationCommand.Notes) - - if ConversationCommand.Online in 
conversation_commands: - try: - online_results = await search_online(defiltered_query, meta_log, location, custom_filters=_custom_filters) - except ValueError as e: - logger.warning(f"Error searching online: {e}. Attempting to respond without online results") - - if ConversationCommand.Webpage in conversation_commands: - try: - online_results = await read_webpages(defiltered_query, meta_log, location) - except ValueError as e: - logger.warning( - f"Error directly reading webpages: {e}. Attempting to respond without online results", exc_info=True - ) - - if ConversationCommand.Image in conversation_commands: - update_telemetry_state( - request=request, - telemetry_type="api", - api="chat", - metadata={"conversation_command": conversation_commands[0].value}, - **common.__dict__, - ) - image, status_code, improved_image_prompt, intent_type = await text_to_image( - q, user, meta_log, location_data=location, references=compiled_references, online_results=online_results - ) - if image is None: - content_obj = {"image": image, "intentType": intent_type, "detail": improved_image_prompt} - return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code) - - await sync_to_async(save_to_conversation_log)( - q, - image, - user, - meta_log, - user_message_time, - intent_type=intent_type, - inferred_queries=[improved_image_prompt], - client_application=request.user.client_app, - conversation_id=conversation.id, - compiled_references=compiled_references, - online_results=online_results, - ) - content_obj = {"image": image, "intentType": intent_type, "inferredQueries": [improved_image_prompt], "context": compiled_references, "online_results": online_results} # type: ignore - return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code) - - # Get the (streamed) chat response from the LLM of choice. 
- llm_response, chat_metadata = await agenerate_chat_response( - defiltered_query, - meta_log, - conversation, - compiled_references, - online_results, - inferred_queries, - conversation_commands, - user, - request.user.client_app, - conversation.id, - location, - user_name, - ) - - cmd_set = set([cmd.value for cmd in conversation_commands]) - chat_metadata["conversation_command"] = cmd_set - chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None - - update_telemetry_state( - request=request, - telemetry_type="api", - api="chat", - metadata=chat_metadata, - **common.__dict__, - ) - - if llm_response is None: - return Response(content=llm_response, media_type="text/plain", status_code=500) + continue_stream = False + logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}") + ## Stream Text Response if stream: - return StreamingResponse(llm_response, media_type="text/event-stream", status_code=200) + return StreamingResponse(event_generator(q), media_type="text/plain") + ## Non-Streaming Text Response + else: + # Get the full response from the generator if the stream is not requested. + response_obj = {} + actual_response = "" + iterator = event_generator(q) + async for item in iterator: + try: + item_json = json.loads(item) + if "type" in item_json and item_json["type"] == ChatEvent.REFERENCES.value: + response_obj = item_json["data"] + except: + actual_response += item + response_obj["response"] = actual_response - iterator = AsyncIteratorWrapper(llm_response) - - # Get the full response from the generator if the stream is not requested. 
- aggregated_gpt_response = "" - async for item in iterator: - if item is None: - break - aggregated_gpt_response += item - - actual_response = aggregated_gpt_response.split("### compiled references:")[0] - - response_obj = { - "response": actual_response, - "inferredQueries": inferred_queries, - "context": compiled_references, - "online_results": online_results, - } - - return Response(content=json.dumps(response_obj), media_type="application/json", status_code=200) + return Response(content=json.dumps(response_obj), media_type="application/json", status_code=200) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 25d21f29..846f5c8f 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -8,6 +8,7 @@ import math import re from concurrent.futures import ThreadPoolExecutor from datetime import datetime, timedelta, timezone +from enum import Enum from functools import partial from random import random from typing import ( @@ -753,7 +754,7 @@ async def text_to_image( references: List[Dict[str, Any]], online_results: Dict[str, Any], send_status_func: Optional[Callable] = None, -) -> Tuple[Optional[str], int, Optional[str], str]: +): status_code = 200 image = None response = None @@ -765,7 +766,8 @@ async def text_to_image( # If the user has not configured a text to image model, return an unsupported on server error status_code = 501 message = "Failed to generate image. Setup image generation on the server." 
- return image_url or image, status_code, message, intent_type.value + yield image_url or image, status_code, message, intent_type.value + return text2image_model = text_to_image_config.model_name chat_history = "" @@ -777,20 +779,21 @@ async def text_to_image( chat_history += f"Q: Prompt: {chat['intent']['query']}\n" chat_history += f"A: Improved Prompt: {chat['intent']['inferred-queries'][0]}\n" - with timer("Improve the original user query", logger): - if send_status_func: - await send_status_func("**✍🏽 Enhancing the Painting Prompt**") - improved_image_prompt = await generate_better_image_prompt( - message, - chat_history, - location_data=location_data, - note_references=references, - online_results=online_results, - model_type=text_to_image_config.model_type, - ) + if send_status_func: + async for event in send_status_func("**✍🏽 Enhancing the Painting Prompt**"): + yield {ChatEvent.STATUS: event} + improved_image_prompt = await generate_better_image_prompt( + message, + chat_history, + location_data=location_data, + note_references=references, + online_results=online_results, + model_type=text_to_image_config.model_type, + ) if send_status_func: - await send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}") + async for event in send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}"): + yield {ChatEvent.STATUS: event} if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI: with timer("Generate image with OpenAI", logger): @@ -815,12 +818,14 @@ async def text_to_image( logger.error(f"Image Generation blocked by OpenAI: {e}") status_code = e.status_code # type: ignore message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore - return image_url or image, status_code, message, intent_type.value + yield image_url or image, status_code, message, intent_type.value + return else: logger.error(f"Image Generation failed with {e}", exc_info=True) message = f"Image 
generation failed with OpenAI error: {e.message}" # type: ignore status_code = e.status_code # type: ignore - return image_url or image, status_code, message, intent_type.value + yield image_url or image, status_code, message, intent_type.value + return elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.STABILITYAI: with timer("Generate image with Stability AI", logger): @@ -842,7 +847,8 @@ async def text_to_image( logger.error(f"Image Generation failed with {e}", exc_info=True) message = f"Image generation failed with Stability AI error: {e}" status_code = e.status_code # type: ignore - return image_url or image, status_code, message, intent_type.value + yield image_url or image, status_code, message, intent_type.value + return with timer("Convert image to webp", logger): # Convert png to webp for faster loading @@ -862,7 +868,7 @@ async def text_to_image( intent_type = ImageIntentType.TEXT_TO_IMAGE_V3 image = base64.b64encode(webp_image_bytes).decode("utf-8") - return image_url or image, status_code, improved_image_prompt, intent_type.value + yield image_url or image, status_code, improved_image_prompt, intent_type.value class ApiUserRateLimiter: @@ -1184,3 +1190,11 @@ def construct_automation_created_message(automation: Job, crontime: str, query_t Manage your automations [here](/automations). 
""".strip() + + +class ChatEvent(Enum): + START_LLM_RESPONSE = "start_llm_response" + END_LLM_RESPONSE = "end_llm_response" + MESSAGE = "message" + REFERENCES = "references" + STATUS = "status" diff --git a/src/khoj/utils/fs_syncer.py b/src/khoj/utils/fs_syncer.py index 5a20f418..3177d7ee 100644 --- a/src/khoj/utils/fs_syncer.py +++ b/src/khoj/utils/fs_syncer.py @@ -22,7 +22,7 @@ magika = Magika() def collect_files(search_type: Optional[SearchType] = SearchType.All, user=None) -> dict: - files = {} + files: dict[str, dict] = {"docx": {}, "image": {}} if search_type == SearchType.All or search_type == SearchType.Org: org_config = LocalOrgConfig.objects.filter(user=user).first() diff --git a/tests/test_client.py b/tests/test_client.py index 24d2dff6..c4246a78 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -455,13 +455,13 @@ def test_user_no_data_returns_empty(client, sample_org_data, api_user3: KhojApiU @pytest.mark.skipif(os.getenv("OPENAI_API_KEY") is None, reason="requires OPENAI_API_KEY") @pytest.mark.django_db(transaction=True) -def test_chat_with_unauthenticated_user(chat_client_with_auth, api_user2: KhojApiUser): +async def test_chat_with_unauthenticated_user(chat_client_with_auth, api_user2: KhojApiUser): # Arrange headers = {"Authorization": f"Bearer {api_user2.token}"} # Act - auth_response = chat_client_with_auth.get(f'/api/chat?q="Hello!"&stream=true', headers=headers) - no_auth_response = chat_client_with_auth.get(f'/api/chat?q="Hello!"&stream=true') + auth_response = chat_client_with_auth.get(f'/api/chat?q="Hello!"', headers=headers) + no_auth_response = chat_client_with_auth.get(f'/api/chat?q="Hello!"') # Assert assert auth_response.status_code == 200 diff --git a/tests/test_offline_chat_director.py b/tests/test_offline_chat_director.py index 43e254e6..c0532c4b 100644 --- a/tests/test_offline_chat_director.py +++ b/tests/test_offline_chat_director.py @@ -67,10 +67,8 @@ def test_chat_with_online_content(client_offline_chat): # Act q = 
"/online give me the link to paul graham's essay how to do great work" encoded_q = quote(q, safe="") - response = client_offline_chat.get(f"/api/chat?q={encoded_q}&stream=true") - response_message = response.content.decode("utf-8") - - response_message = response_message.split("### compiled references")[0] + response = client_offline_chat.get(f"/api/chat?q={encoded_q}") + response_message = response.json()["response"] # Assert expected_responses = [ @@ -91,10 +89,8 @@ def test_chat_with_online_webpage_content(client_offline_chat): # Act q = "/online how many firefighters were involved in the great chicago fire and which year did it take place?" encoded_q = quote(q, safe="") - response = client_offline_chat.get(f"/api/chat?q={encoded_q}&stream=true") - response_message = response.content.decode("utf-8") - - response_message = response_message.split("### compiled references")[0] + response = client_offline_chat.get(f"/api/chat?q={encoded_q}") + response_message = response.json()["response"] # Assert expected_responses = ["185", "1871", "horse"] diff --git a/tests/test_openai_chat_director.py b/tests/test_openai_chat_director.py index 26d93d31..7a05a3dd 100644 --- a/tests/test_openai_chat_director.py +++ b/tests/test_openai_chat_director.py @@ -49,8 +49,8 @@ def create_conversation(message_list, user, agent=None): @pytest.mark.django_db(transaction=True) def test_chat_with_no_chat_history_or_retrieved_content(chat_client): # Act - response = chat_client.get(f'/api/chat?q="Hello, my name is Testatron. Who are you?"&stream=true') - response_message = response.content.decode("utf-8") + response = chat_client.get(f'/api/chat?q="Hello, my name is Testatron. 
Who are you?"') + response_message = response.json()["response"] # Assert expected_responses = ["Khoj", "khoj"] @@ -67,10 +67,8 @@ def test_chat_with_online_content(chat_client): # Act q = "/online give me the link to paul graham's essay how to do great work" encoded_q = quote(q, safe="") - response = chat_client.get(f"/api/chat?q={encoded_q}&stream=true") - response_message = response.content.decode("utf-8") - - response_message = response_message.split("### compiled references")[0] + response = chat_client.get(f"/api/chat?q={encoded_q}") + response_message = response.json()["response"] # Assert expected_responses = [ @@ -91,10 +89,8 @@ def test_chat_with_online_webpage_content(chat_client): # Act q = "/online how many firefighters were involved in the great chicago fire and which year did it take place?" encoded_q = quote(q, safe="") - response = chat_client.get(f"/api/chat?q={encoded_q}&stream=true") - response_message = response.content.decode("utf-8") - - response_message = response_message.split("### compiled references")[0] + response = chat_client.get(f"/api/chat?q={encoded_q}") + response_message = response.json()["response"] # Assert expected_responses = ["185", "1871", "horse"] @@ -144,7 +140,7 @@ def test_answer_from_currently_retrieved_content(chat_client, default_user2: Kho # Act response = chat_client.get(f'/api/chat?q="Where was Xi Li born?"') - response_message = response.content.decode("utf-8") + response_message = response.json()["response"] # Assert assert response.status_code == 200 @@ -168,7 +164,7 @@ def test_answer_from_chat_history_and_previously_retrieved_content(chat_client_n # Act response = chat_client_no_background.get(f'/api/chat?q="Where was I born?"') - response_message = response.content.decode("utf-8") + response_message = response.json()["response"] # Assert assert response.status_code == 200 @@ -191,7 +187,7 @@ def test_answer_from_chat_history_and_currently_retrieved_content(chat_client, d # Act response = 
chat_client.get(f'/api/chat?q="Where was I born?"') - response_message = response.content.decode("utf-8") + response_message = response.json()["response"] # Assert assert response.status_code == 200 @@ -215,8 +211,8 @@ def test_no_answer_in_chat_history_or_retrieved_content(chat_client, default_use create_conversation(message_list, default_user2) # Act - response = chat_client.get(f'/api/chat?q="Where was I born?"&stream=true') - response_message = response.content.decode("utf-8") + response = chat_client.get(f'/api/chat?q="Where was I born?"') + response_message = response.json()["response"] # Assert expected_responses = [ @@ -226,6 +222,7 @@ def test_no_answer_in_chat_history_or_retrieved_content(chat_client, default_use "do not have", "don't have", "where were you born?", + "where you were born?", ] assert response.status_code == 200 @@ -280,8 +277,8 @@ def test_answer_not_known_using_notes_command(chat_client_no_background, default create_conversation(message_list, default_user2) # Act - response = chat_client_no_background.get(f"/api/chat?q={query}&stream=true") - response_message = response.content.decode("utf-8") + response = chat_client_no_background.get(f"/api/chat?q={query}") + response_message = response.json()["response"] # Assert assert response.status_code == 200 @@ -527,8 +524,8 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c create_conversation(message_list, default_user2) # Act - response = chat_client.get(f'/api/chat?q="Write a haiku about unit testing. Do not say anything else."&stream=true') - response_message = response.content.decode("utf-8").split("### compiled references")[0] + response = chat_client.get(f'/api/chat?q="Write a haiku about unit testing. 
Do not say anything else.') + response_message = response.json()["response"] # Assert expected_responses = ["test", "Test"] @@ -544,9 +541,8 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c @pytest.mark.chatquality def test_ask_for_clarification_if_not_enough_context_in_question(chat_client_no_background): # Act - - response = chat_client_no_background.get(f'/api/chat?q="What is the name of Namitas older son?"&stream=true') - response_message = response.content.decode("utf-8").split("### compiled references")[0].lower() + response = chat_client_no_background.get(f'/api/chat?q="What is the name of Namitas older son?"') + response_message = response.json()["response"].lower() # Assert expected_responses = [ @@ -658,8 +654,8 @@ def test_answer_in_chat_history_by_conversation_id_with_agent( def test_answer_requires_multiple_independent_searches(chat_client): "Chat director should be able to answer by doing multiple independent searches for required information" # Act - response = chat_client.get(f'/api/chat?q="Is Xi older than Namita? Just the older persons full name"&stream=true') - response_message = response.content.decode("utf-8").split("### compiled references")[0].lower() + response = chat_client.get(f'/api/chat?q="Is Xi older than Namita? Just the older persons full name"') + response_message = response.json()["response"].lower() # Assert expected_responses = ["he is older than namita", "xi is older than namita", "xi li is older than namita"] @@ -683,8 +679,8 @@ def test_answer_using_file_filter(chat_client): 'Is Xi older than Namita? Just say the older persons full name. 
file:"Namita.markdown" file:"Xi Li.markdown"' ) - response = chat_client.get(f"/api/chat?q={query}&stream=true") - response_message = response.content.decode("utf-8").split("### compiled references")[0].lower() + response = chat_client.get(f"/api/chat?q={query}") + response_message = response.json()["response"].lower() # Assert expected_responses = ["he is older than namita", "xi is older than namita", "xi li is older than namita"]