diff --git a/src/interface/desktop/chat.html b/src/interface/desktop/chat.html
index 383fc536..57657ef1 100644
--- a/src/interface/desktop/chat.html
+++ b/src/interface/desktop/chat.html
@@ -61,6 +61,14 @@
let city = null;
let countryName = null;
let timezone = null;
+ let chatMessageState = {
+ newResponseTextEl: null,
+ newResponseEl: null,
+ loadingEllipsis: null,
+ references: {},
+ rawResponse: "",
+ isVoice: false,
+ }
fetch("https://ipapi.co/json")
.then(response => response.json())
@@ -75,10 +83,9 @@
return;
});
- async function chat() {
- // Extract required fields for search from form
+ async function chat(isVoice=false) {
+ // Extract chat message from chat input form
let query = document.getElementById("chat-input").value.trim();
- let resultsCount = localStorage.getItem("khojResultsCount") || 5;
console.log(`Query: ${query}`);
// Short circuit on empty query
@@ -106,9 +113,6 @@
await refreshChatSessionsPanel();
}
- // Generate backend API URL to execute query
- let chatApi = `${hostURL}/api/chat?q=${encodeURIComponent(query)}&n=${resultsCount}&client=web&stream=true&conversation_id=${conversationID}&region=${region}&city=${city}&country=${countryName}&timezone=${timezone}`;
-
let newResponseEl = document.createElement("div");
newResponseEl.classList.add("chat-message", "khoj");
newResponseEl.attributes["data-meta"] = "🏮 Khoj at " + formatDate(new Date());
@@ -119,25 +123,7 @@
newResponseEl.appendChild(newResponseTextEl);
// Temporary status message to indicate that Khoj is thinking
- let loadingEllipsis = document.createElement("div");
- loadingEllipsis.classList.add("lds-ellipsis");
-
- let firstEllipsis = document.createElement("div");
- firstEllipsis.classList.add("lds-ellipsis-item");
-
- let secondEllipsis = document.createElement("div");
- secondEllipsis.classList.add("lds-ellipsis-item");
-
- let thirdEllipsis = document.createElement("div");
- thirdEllipsis.classList.add("lds-ellipsis-item");
-
- let fourthEllipsis = document.createElement("div");
- fourthEllipsis.classList.add("lds-ellipsis-item");
-
- loadingEllipsis.appendChild(firstEllipsis);
- loadingEllipsis.appendChild(secondEllipsis);
- loadingEllipsis.appendChild(thirdEllipsis);
- loadingEllipsis.appendChild(fourthEllipsis);
+ let loadingEllipsis = createLoadingEllipsis();
newResponseTextEl.appendChild(loadingEllipsis);
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
@@ -148,107 +134,36 @@
let chatInput = document.getElementById("chat-input");
chatInput.classList.remove("option-enabled");
+ // Setup chat message state
+ chatMessageState = {
+ newResponseTextEl,
+ newResponseEl,
+ loadingEllipsis,
+ references: {},
+ rawResponse: "",
+ rawQuery: query,
+ isVoice: isVoice,
+ }
+
// Call Khoj chat API
- let response = await fetch(chatApi, { headers });
- let rawResponse = "";
- let references = null;
- const contentType = response.headers.get("content-type");
+ let chatApi = `${hostURL}/api/chat?q=${encodeURIComponent(query)}&conversation_id=${conversationID}&stream=true&client=desktop`;
+ chatApi += (!!region && !!city && !!countryName && !!timezone)
+ ? `&region=${region}&city=${city}&country=${countryName}&timezone=${timezone}`
+ : '';
- if (contentType === "application/json") {
- // Handle JSON response
- try {
- const responseAsJson = await response.json();
- if (responseAsJson.image) {
- // If response has image field, response is a generated image.
- if (responseAsJson.intentType === "text-to-image") {
- rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`;
- } else if (responseAsJson.intentType === "text-to-image2") {
- rawResponse += `![${query}](${responseAsJson.image})`;
- } else if (responseAsJson.intentType === "text-to-image-v3") {
- rawResponse += `![${query}](data:image/webp;base64,${responseAsJson.image})`;
- }
- const inferredQueries = responseAsJson.inferredQueries?.[0];
- if (inferredQueries) {
- rawResponse += `\n\n**Inferred Query**:\n\n${inferredQueries}`;
- }
- }
- if (responseAsJson.context) {
- const rawReferenceAsJson = responseAsJson.context;
- references = createReferenceSection(rawReferenceAsJson);
- }
- if (responseAsJson.detail) {
- // If response has detail field, response is an error message.
- rawResponse += responseAsJson.detail;
- }
- } catch (error) {
- // If the chunk is not a JSON object, just display it as is
- rawResponse += chunk;
- } finally {
- newResponseTextEl.innerHTML = "";
- newResponseTextEl.appendChild(formatHTMLMessage(rawResponse));
+ const response = await fetch(chatApi, { headers });
- if (references != null) {
- newResponseTextEl.appendChild(references);
- }
-
- document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
- document.getElementById("chat-input").removeAttribute("disabled");
- }
- } else {
- // Handle streamed response of type text/event-stream or text/plain
- const reader = response.body.getReader();
- const decoder = new TextDecoder();
- let references = {};
-
- readStream();
-
- function readStream() {
- reader.read().then(({ done, value }) => {
- if (done) {
- // Append any references after all the data has been streamed
- if (references != {}) {
- newResponseTextEl.appendChild(createReferenceSection(references));
- }
- document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
- document.getElementById("chat-input").removeAttribute("disabled");
- return;
- }
-
- // Decode message chunk from stream
- const chunk = decoder.decode(value, { stream: true });
-
- if (chunk.includes("### compiled references:")) {
- const additionalResponse = chunk.split("### compiled references:")[0];
- rawResponse += additionalResponse;
- newResponseTextEl.innerHTML = "";
- newResponseTextEl.appendChild(formatHTMLMessage(rawResponse));
-
- const rawReference = chunk.split("### compiled references:")[1];
- const rawReferenceAsJson = JSON.parse(rawReference);
- if (rawReferenceAsJson instanceof Array) {
- references["notes"] = rawReferenceAsJson;
- } else if (typeof rawReferenceAsJson === "object" && rawReferenceAsJson !== null) {
- references["online"] = rawReferenceAsJson;
- }
- readStream();
- } else {
- // Display response from Khoj
- if (newResponseTextEl.getElementsByClassName("lds-ellipsis").length > 0) {
- newResponseTextEl.removeChild(loadingEllipsis);
- }
-
- // If the chunk is not a JSON object, just display it as is
- rawResponse += chunk;
- newResponseTextEl.innerHTML = "";
- newResponseTextEl.appendChild(formatHTMLMessage(rawResponse));
-
- readStream();
- }
-
- // Scroll to bottom of chat window as chat response is streamed
- document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
- });
- }
+ try {
+ if (!response.ok) throw new Error(response.statusText);
+ if (!response.body) throw new Error("Response body is empty");
+ // Stream and render chat response
+ await readChatStream(response);
+ } catch (err) {
+ console.error(`Khoj chat response failed with\n${err}`);
+ if (chatMessageState.newResponseEl.getElementsByClassName("lds-ellipsis").length > 0 && chatMessageState.loadingEllipsis)
+ chatMessageState.newResponseTextEl.removeChild(chatMessageState.loadingEllipsis);
+ let errorMsg = "Sorry, unable to get response from Khoj backend ❤️🩹. Retry or contact developers for help at team@khoj.dev or on Discord";
+ newResponseTextEl.textContent = errorMsg;
}
}
diff --git a/src/interface/desktop/chatutils.js b/src/interface/desktop/chatutils.js
index 42cfa986..5213979f 100644
--- a/src/interface/desktop/chatutils.js
+++ b/src/interface/desktop/chatutils.js
@@ -364,3 +364,194 @@ function createReferenceSection(references, createLinkerSection=false) {
return referencesDiv;
}
+
+function createLoadingEllipsis() {
+ let loadingEllipsis = document.createElement("div");
+ loadingEllipsis.classList.add("lds-ellipsis");
+
+ let firstEllipsis = document.createElement("div");
+ firstEllipsis.classList.add("lds-ellipsis-item");
+
+ let secondEllipsis = document.createElement("div");
+ secondEllipsis.classList.add("lds-ellipsis-item");
+
+ let thirdEllipsis = document.createElement("div");
+ thirdEllipsis.classList.add("lds-ellipsis-item");
+
+ let fourthEllipsis = document.createElement("div");
+ fourthEllipsis.classList.add("lds-ellipsis-item");
+
+ loadingEllipsis.appendChild(firstEllipsis);
+ loadingEllipsis.appendChild(secondEllipsis);
+ loadingEllipsis.appendChild(thirdEllipsis);
+ loadingEllipsis.appendChild(fourthEllipsis);
+
+ return loadingEllipsis;
+}
+
+function handleStreamResponse(newResponseElement, rawResponse, rawQuery, loadingEllipsis, replace=true) {
+ if (!newResponseElement) return;
+ // Remove loading ellipsis if it exists
+ if (newResponseElement.getElementsByClassName("lds-ellipsis").length > 0 && loadingEllipsis)
+ newResponseElement.removeChild(loadingEllipsis);
+ // Clear the response element if replace is true
+ if (replace) newResponseElement.innerHTML = "";
+
+ // Append response to the response element
+ newResponseElement.appendChild(formatHTMLMessage(rawResponse, false, replace, rawQuery));
+
+ // Append loading ellipsis if it exists
+ if (!replace && loadingEllipsis) newResponseElement.appendChild(loadingEllipsis);
+ // Scroll to bottom of chat view
+ document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
+}
+
+function handleImageResponse(imageJson, rawResponse) {
+ if (imageJson.image) {
+ const inferredQuery = imageJson.inferredQueries?.[0] ?? "generated image";
+
+ // If response has image field, response is a generated image.
+ if (imageJson.intentType === "text-to-image") {
+ rawResponse += `![generated_image](data:image/png;base64,${imageJson.image})`;
+ } else if (imageJson.intentType === "text-to-image2") {
+ rawResponse += `![generated_image](${imageJson.image})`;
+ } else if (imageJson.intentType === "text-to-image-v3") {
+ rawResponse += `![generated_image](data:image/webp;base64,${imageJson.image})`;
+ }
+ if (inferredQuery) {
+ rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
+ }
+ }
+
+ // If response has detail field, response is an error message.
+ if (imageJson.detail) rawResponse += imageJson.detail;
+
+ return rawResponse;
+}
+
+function finalizeChatBodyResponse(references, newResponseElement) {
+ if (!!newResponseElement && references != null && Object.keys(references).length > 0) {
+ newResponseElement.appendChild(createReferenceSection(references));
+ }
+ document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
+ document.getElementById("chat-input")?.removeAttribute("disabled");
+}
+
+function convertMessageChunkToJson(rawChunk) {
+ // Split the chunk into lines
+ if (rawChunk?.startsWith("{") && rawChunk?.endsWith("}")) {
+ try {
+ let jsonChunk = JSON.parse(rawChunk);
+ if (!jsonChunk.type)
+ jsonChunk = {type: 'message', data: jsonChunk};
+ return jsonChunk;
+ } catch (e) {
+ return {type: 'message', data: rawChunk};
+ }
+ } else if (rawChunk.length > 0) {
+ return {type: 'message', data: rawChunk};
+ }
+}
+
+function processMessageChunk(rawChunk) {
+ const chunk = convertMessageChunkToJson(rawChunk);
+ console.debug("Chunk:", chunk);
+ if (!chunk || !chunk.type) return;
+ if (chunk.type === 'status') {
+ console.log(`status: ${chunk.data}`);
+ const statusMessage = chunk.data;
+ handleStreamResponse(chatMessageState.newResponseTextEl, statusMessage, chatMessageState.rawQuery, chatMessageState.loadingEllipsis, false);
+ } else if (chunk.type === 'start_llm_response') {
+ console.log("Started streaming", new Date());
+ } else if (chunk.type === 'end_llm_response') {
+ console.log("Stopped streaming", new Date());
+
+ // Automatically respond with voice if the subscribed user has sent voice message
+ if (chatMessageState.isVoice && "{{ is_active }}" == "True") // FIXME: Jinja placeholder does not render in desktop JS; this condition is always false
+ textToSpeech(chatMessageState.rawResponse);
+
+ // Append any references after all the data has been streamed
+ finalizeChatBodyResponse(chatMessageState.references, chatMessageState.newResponseTextEl);
+
+ const liveQuery = chatMessageState.rawQuery;
+ // Reset variables
+ chatMessageState = {
+ newResponseTextEl: null,
+ newResponseEl: null,
+ loadingEllipsis: null,
+ references: {},
+ rawResponse: "",
+ rawQuery: liveQuery,
+ isVoice: false,
+ }
+ } else if (chunk.type === "references") {
+ chatMessageState.references = {"notes": chunk.data.context, "online": chunk.data.onlineContext};
+ } else if (chunk.type === 'message') {
+ const chunkData = chunk.data;
+ if (typeof chunkData === 'object' && chunkData !== null) {
+ // If chunkData is already a JSON object
+ handleJsonResponse(chunkData);
+ } else if (typeof chunkData === 'string' && chunkData.trim()?.startsWith("{") && chunkData.trim()?.endsWith("}")) {
+ // Try process chunk data as if it is a JSON object
+ try {
+ const jsonData = JSON.parse(chunkData.trim());
+ handleJsonResponse(jsonData);
+ } catch (e) {
+ chatMessageState.rawResponse += chunkData;
+ handleStreamResponse(chatMessageState.newResponseTextEl, chatMessageState.rawResponse, chatMessageState.rawQuery, chatMessageState.loadingEllipsis);
+ }
+ } else {
+ chatMessageState.rawResponse += chunkData;
+ handleStreamResponse(chatMessageState.newResponseTextEl, chatMessageState.rawResponse, chatMessageState.rawQuery, chatMessageState.loadingEllipsis);
+ }
+ }
+}
+
+function handleJsonResponse(jsonData) {
+ if (jsonData.image || jsonData.detail) {
+ chatMessageState.rawResponse = handleImageResponse(jsonData, chatMessageState.rawResponse);
+ } else if (jsonData.response) {
+ chatMessageState.rawResponse = jsonData.response;
+ }
+
+ if (chatMessageState.newResponseTextEl) {
+ chatMessageState.newResponseTextEl.innerHTML = "";
+ chatMessageState.newResponseTextEl.appendChild(formatHTMLMessage(chatMessageState.rawResponse));
+ }
+}
+
+async function readChatStream(response) {
+ if (!response.body) return;
+ const reader = response.body.getReader();
+ const decoder = new TextDecoder();
+ const eventDelimiter = '␃🔚␗';
+ let buffer = '';
+
+ while (true) {
+ const { value, done } = await reader.read();
+ // If the stream is done
+ if (done) {
+ // Process the last chunk
+ processMessageChunk(buffer);
+ buffer = '';
+ break;
+ }
+
+ // Read chunk from stream and append it to the buffer
+ const chunk = decoder.decode(value, { stream: true });
+ console.debug("Raw Chunk:", chunk)
+ // Start buffering chunks until complete event is received
+ buffer += chunk;
+
+ // Once the buffer contains a complete event
+ let newEventIndex;
+ while ((newEventIndex = buffer.indexOf(eventDelimiter)) !== -1) {
+ // Extract the event from the buffer
+ const event = buffer.slice(0, newEventIndex);
+ buffer = buffer.slice(newEventIndex + eventDelimiter.length);
+
+ // Process the event
+ if (event) processMessageChunk(event);
+ }
+ }
+}
diff --git a/src/interface/desktop/shortcut.html b/src/interface/desktop/shortcut.html
index 4af26f0d..52207f20 100644
--- a/src/interface/desktop/shortcut.html
+++ b/src/interface/desktop/shortcut.html
@@ -346,7 +346,7 @@
inp.focus();
}
- async function chat() {
+ async function chat(isVoice=false) {
//set chat body to empty
let chatBody = document.getElementById("chat-body");
chatBody.innerHTML = "";
@@ -375,9 +375,6 @@
chat_body.dataset.conversationId = conversationID;
}
- // Generate backend API URL to execute query
- let chatApi = `${hostURL}/api/chat?q=${encodeURIComponent(query)}&n=${resultsCount}&client=web&stream=true&conversation_id=${conversationID}&region=${region}&city=${city}&country=${countryName}&timezone=${timezone}`;
-
let newResponseEl = document.createElement("div");
newResponseEl.classList.add("chat-message", "khoj");
newResponseEl.attributes["data-meta"] = "🏮 Khoj at " + formatDate(new Date());
@@ -388,128 +385,41 @@
newResponseEl.appendChild(newResponseTextEl);
// Temporary status message to indicate that Khoj is thinking
- let loadingEllipsis = document.createElement("div");
- loadingEllipsis.classList.add("lds-ellipsis");
-
- let firstEllipsis = document.createElement("div");
- firstEllipsis.classList.add("lds-ellipsis-item");
-
- let secondEllipsis = document.createElement("div");
- secondEllipsis.classList.add("lds-ellipsis-item");
-
- let thirdEllipsis = document.createElement("div");
- thirdEllipsis.classList.add("lds-ellipsis-item");
-
- let fourthEllipsis = document.createElement("div");
- fourthEllipsis.classList.add("lds-ellipsis-item");
-
- loadingEllipsis.appendChild(firstEllipsis);
- loadingEllipsis.appendChild(secondEllipsis);
- loadingEllipsis.appendChild(thirdEllipsis);
- loadingEllipsis.appendChild(fourthEllipsis);
-
- newResponseTextEl.appendChild(loadingEllipsis);
+ let loadingEllipsis = createLoadingEllipsis();
document.body.scrollTop = document.getElementById("chat-body").scrollHeight;
- // Call Khoj chat API
- let response = await fetch(chatApi, { headers });
- let rawResponse = "";
- let references = null;
- const contentType = response.headers.get("content-type");
toggleLoading();
- if (contentType === "application/json") {
- // Handle JSON response
- try {
- const responseAsJson = await response.json();
- if (responseAsJson.image) {
- // If response has image field, response is a generated image.
- if (responseAsJson.intentType === "text-to-image") {
- rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`;
- } else if (responseAsJson.intentType === "text-to-image2") {
- rawResponse += `![${query}](${responseAsJson.image})`;
- } else if (responseAsJson.intentType === "text-to-image-v3") {
- rawResponse += `![${query}](data:image/webp;base64,${responseAsJson.image})`;
- }
- const inferredQueries = responseAsJson.inferredQueries?.[0];
- if (inferredQueries) {
- rawResponse += `\n\n**Inferred Query**:\n\n${inferredQueries}`;
- }
- }
- if (responseAsJson.context) {
- const rawReferenceAsJson = responseAsJson.context;
- references = createReferenceSection(rawReferenceAsJson, createLinkerSection=true);
- }
- if (responseAsJson.detail) {
- // If response has detail field, response is an error message.
- rawResponse += responseAsJson.detail;
- }
- } catch (error) {
- // If the chunk is not a JSON object, just display it as is
- rawResponse += chunk;
- } finally {
- newResponseTextEl.innerHTML = "";
- newResponseTextEl.appendChild(formatHTMLMessage(rawResponse));
- if (references != null) {
- newResponseTextEl.appendChild(references);
- }
+ // Setup chat message state
+ chatMessageState = {
+ newResponseTextEl,
+ newResponseEl,
+ loadingEllipsis,
+ references: {},
+ rawResponse: "",
+ rawQuery: query,
+ isVoice: isVoice,
+ }
- document.body.scrollTop = document.getElementById("chat-body").scrollHeight;
- }
- } else {
- // Handle streamed response of type text/event-stream or text/plain
- const reader = response.body.getReader();
- const decoder = new TextDecoder();
- let references = {};
+ // Construct API URL to execute chat query
+ let chatApi = `${hostURL}/api/chat?q=${encodeURIComponent(query)}&conversation_id=${conversationID}&stream=true&client=desktop`;
+ chatApi += (!!region && !!city && !!countryName && !!timezone)
+ ? `&region=${region}&city=${city}&country=${countryName}&timezone=${timezone}`
+ : '';
- readStream();
+ const response = await fetch(chatApi, { headers });
- function readStream() {
- reader.read().then(({ done, value }) => {
- if (done) {
- // Append any references after all the data has been streamed
- if (references != {}) {
- newResponseTextEl.appendChild(createReferenceSection(references, createLinkerSection=true));
- }
- document.body.scrollTop = document.getElementById("chat-body").scrollHeight;
- return;
- }
-
- // Decode message chunk from stream
- const chunk = decoder.decode(value, { stream: true });
-
- if (chunk.includes("### compiled references:")) {
- const additionalResponse = chunk.split("### compiled references:")[0];
- rawResponse += additionalResponse;
- newResponseTextEl.innerHTML = "";
- newResponseTextEl.appendChild(formatHTMLMessage(rawResponse));
-
- const rawReference = chunk.split("### compiled references:")[1];
- const rawReferenceAsJson = JSON.parse(rawReference);
- if (rawReferenceAsJson instanceof Array) {
- references["notes"] = rawReferenceAsJson;
- } else if (typeof rawReferenceAsJson === "object" && rawReferenceAsJson !== null) {
- references["online"] = rawReferenceAsJson;
- }
- readStream();
- } else {
- // Display response from Khoj
- if (newResponseTextEl.getElementsByClassName("lds-ellipsis").length > 0) {
- newResponseTextEl.removeChild(loadingEllipsis);
- }
-
- // If the chunk is not a JSON object, just display it as is
- rawResponse += chunk;
- newResponseTextEl.innerHTML = "";
- newResponseTextEl.appendChild(formatHTMLMessage(rawResponse));
-
- readStream();
- }
-
- // Scroll to bottom of chat window as chat response is streamed
- document.body.scrollTop = document.getElementById("chat-body").scrollHeight;
- });
- }
+ try {
+ if (!response.ok) throw new Error(response.statusText);
+ if (!response.body) throw new Error("Response body is empty");
+ // Stream and render chat response
+ await readChatStream(response);
+ } catch (err) {
+ console.error(`Khoj chat response failed with\n${err}`);
+ if (chatMessageState.newResponseEl.getElementsByClassName("lds-ellipsis").length > 0 && chatMessageState.loadingEllipsis)
+ chatMessageState.newResponseTextEl.removeChild(chatMessageState.loadingEllipsis);
+ let errorMsg = "Sorry, unable to get response from Khoj backend ❤️🩹. Retry or contact developers for help at team@khoj.dev or on Discord";
+ newResponseTextEl.textContent = errorMsg;
}
document.body.scrollTop = document.getElementById("chat-body").scrollHeight;
}
diff --git a/src/interface/desktop/utils.js b/src/interface/desktop/utils.js
index c880a7cd..af0234ea 100644
--- a/src/interface/desktop/utils.js
+++ b/src/interface/desktop/utils.js
@@ -34,8 +34,8 @@ function toggleNavMenu() {
document.addEventListener('click', function(event) {
let menu = document.getElementById("khoj-nav-menu");
let menuContainer = document.getElementById("khoj-nav-menu-container");
- let isClickOnMenu = menuContainer.contains(event.target) || menuContainer === event.target;
- if (isClickOnMenu === false && menu.classList.contains("show")) {
+ let isClickOnMenu = menuContainer?.contains(event.target) || menuContainer === event.target;
+ if (menu && isClickOnMenu === false && menu.classList.contains("show")) {
menu.classList.remove("show");
}
});
diff --git a/src/interface/obsidian/src/chat_view.ts b/src/interface/obsidian/src/chat_view.ts
index b8d95d6b..cbd0f7bf 100644
--- a/src/interface/obsidian/src/chat_view.ts
+++ b/src/interface/obsidian/src/chat_view.ts
@@ -12,6 +12,25 @@ export interface ChatJsonResult {
inferredQueries?: string[];
}
+interface ChunkResult {
+ objects: string[];
+ remainder: string;
+}
+
+interface MessageChunk {
+ type: string;
+ data: any;
+}
+
+interface ChatMessageState {
+ newResponseTextEl: HTMLElement | null;
+ newResponseEl: HTMLElement | null;
+ loadingEllipsis: HTMLElement | null;
+ references: any;
+ rawResponse: string;
+ rawQuery: string;
+ isVoice: boolean;
+}
interface Location {
region: string;
@@ -26,6 +45,7 @@ export class KhojChatView extends KhojPaneView {
waitingForLocation: boolean;
location: Location;
keyPressTimeout: NodeJS.Timeout | null = null;
+ chatMessageState: ChatMessageState;
constructor(leaf: WorkspaceLeaf, setting: KhojSetting) {
super(leaf, setting);
@@ -409,16 +429,15 @@ export class KhojChatView extends KhojPaneView {
message = DOMPurify.sanitize(message);
// Convert the message to html, sanitize the message html and render it to the real DOM
- let chat_message_body_text_el = this.contentEl.createDiv();
- chat_message_body_text_el.className = "chat-message-text-response";
- chat_message_body_text_el.innerHTML = this.markdownTextToSanitizedHtml(message, this);
+ let chatMessageBodyTextEl = this.contentEl.createDiv();
+ chatMessageBodyTextEl.innerHTML = this.markdownTextToSanitizedHtml(message, this);
// Add a copy button to each chat message, if it doesn't already exist
if (willReplace === true) {
- this.renderActionButtons(message, chat_message_body_text_el);
+ this.renderActionButtons(message, chatMessageBodyTextEl);
}
- return chat_message_body_text_el;
+ return chatMessageBodyTextEl;
}
markdownTextToSanitizedHtml(markdownText: string, component: ItemView): string {
@@ -502,23 +521,23 @@ export class KhojChatView extends KhojPaneView {
class: `khoj-chat-message ${sender}`
},
})
- let chat_message_body_el = chatMessageEl.createDiv();
- chat_message_body_el.addClasses(["khoj-chat-message-text", sender]);
- let chat_message_body_text_el = chat_message_body_el.createDiv();
+ let chatMessageBodyEl = chatMessageEl.createDiv();
+ chatMessageBodyEl.addClasses(["khoj-chat-message-text", sender]);
+ let chatMessageBodyTextEl = chatMessageBodyEl.createDiv();
// Sanitize the markdown to render
message = DOMPurify.sanitize(message);
if (raw) {
- chat_message_body_text_el.innerHTML = message;
+ chatMessageBodyTextEl.innerHTML = message;
} else {
// @ts-ignore
- chat_message_body_text_el.innerHTML = this.markdownTextToSanitizedHtml(message, this);
+ chatMessageBodyTextEl.innerHTML = this.markdownTextToSanitizedHtml(message, this);
}
// Add action buttons to each chat message element
if (willReplace === true) {
- this.renderActionButtons(message, chat_message_body_text_el);
+ this.renderActionButtons(message, chatMessageBodyTextEl);
}
// Remove user-select: none property to make text selectable
@@ -531,42 +550,38 @@ export class KhojChatView extends KhojPaneView {
}
createKhojResponseDiv(dt?: Date): HTMLDivElement {
- let message_time = this.formatDate(dt ?? new Date());
+ let messageTime = this.formatDate(dt ?? new Date());
// Append message to conversation history HTML element.
// The chat logs should display above the message input box to follow standard UI semantics
- let chat_body_el = this.contentEl.getElementsByClassName("khoj-chat-body")[0];
- let chat_message_el = chat_body_el.createDiv({
+ let chatBodyEl = this.contentEl.getElementsByClassName("khoj-chat-body")[0];
+ let chatMessageEl = chatBodyEl.createDiv({
attr: {
- "data-meta": `🏮 Khoj at ${message_time}`,
+ "data-meta": `🏮 Khoj at ${messageTime}`,
class: `khoj-chat-message khoj`
},
- }).createDiv({
- attr: {
- class: `khoj-chat-message-text khoj`
- },
- }).createDiv();
+ })
// Scroll to bottom after inserting chat messages
this.scrollChatToBottom();
- return chat_message_el;
+ return chatMessageEl;
}
async renderIncrementalMessage(htmlElement: HTMLDivElement, additionalMessage: string) {
- this.result += additionalMessage;
+ this.chatMessageState.rawResponse += additionalMessage;
htmlElement.innerHTML = "";
// Sanitize the markdown to render
- this.result = DOMPurify.sanitize(this.result);
+ this.chatMessageState.rawResponse = DOMPurify.sanitize(this.chatMessageState.rawResponse);
// @ts-ignore
- htmlElement.innerHTML = this.markdownTextToSanitizedHtml(this.result, this);
+ htmlElement.innerHTML = this.markdownTextToSanitizedHtml(this.chatMessageState.rawResponse, this);
// Render action buttons for the message
- this.renderActionButtons(this.result, htmlElement);
+ this.renderActionButtons(this.chatMessageState.rawResponse, htmlElement);
// Scroll to bottom of modal, till the send message input box
this.scrollChatToBottom();
}
- renderActionButtons(message: string, chat_message_body_text_el: HTMLElement) {
+ renderActionButtons(message: string, chatMessageBodyTextEl: HTMLElement) {
let copyButton = this.contentEl.createEl('button');
copyButton.classList.add("chat-action-button");
copyButton.title = "Copy Message to Clipboard";
@@ -593,10 +608,10 @@ export class KhojChatView extends KhojPaneView {
}
// Append buttons to parent element
- chat_message_body_text_el.append(copyButton, pasteToFile);
+ chatMessageBodyTextEl.append(copyButton, pasteToFile);
if (speechButton) {
- chat_message_body_text_el.append(speechButton);
+ chatMessageBodyTextEl.append(speechButton);
}
}
@@ -854,35 +869,122 @@ export class KhojChatView extends KhojPaneView {
return true;
}
- async readChatStream(response: Response, responseElement: HTMLDivElement, isVoice: boolean = false): Promise<void> {
+ convertMessageChunkToJson(rawChunk: string): MessageChunk {
+ if (rawChunk?.startsWith("{") && rawChunk?.endsWith("}")) {
+ try {
+ let jsonChunk = JSON.parse(rawChunk);
+ if (!jsonChunk.type)
+ jsonChunk = {type: 'message', data: jsonChunk};
+ return jsonChunk;
+ } catch (e) {
+ return {type: 'message', data: rawChunk};
+ }
+ } else if (rawChunk.length > 0) {
+ return {type: 'message', data: rawChunk};
+ }
+ return {type: '', data: ''};
+ }
+
+ processMessageChunk(rawChunk: string): void {
+ const chunk = this.convertMessageChunkToJson(rawChunk);
+ console.debug("Chunk:", chunk);
+ if (!chunk || !chunk.type) return;
+ if (chunk.type === 'status') {
+ console.log(`status: ${chunk.data}`);
+ const statusMessage = chunk.data;
+ this.handleStreamResponse(this.chatMessageState.newResponseTextEl, statusMessage, this.chatMessageState.loadingEllipsis, false);
+ } else if (chunk.type === 'start_llm_response') {
+ console.log("Started streaming", new Date());
+ } else if (chunk.type === 'end_llm_response') {
+ console.log("Stopped streaming", new Date());
+
+ // Automatically respond with voice if the subscribed user has sent voice message
+ if (this.chatMessageState.isVoice && this.setting.userInfo?.is_active)
+ this.textToSpeech(this.chatMessageState.rawResponse);
+
+ // Append any references after all the data has been streamed
+ this.finalizeChatBodyResponse(this.chatMessageState.references, this.chatMessageState.newResponseTextEl);
+
+ const liveQuery = this.chatMessageState.rawQuery;
+ // Reset variables
+ this.chatMessageState = {
+ newResponseTextEl: null,
+ newResponseEl: null,
+ loadingEllipsis: null,
+ references: {},
+ rawResponse: "",
+ rawQuery: liveQuery,
+ isVoice: false,
+ };
+ } else if (chunk.type === "references") {
+ this.chatMessageState.references = {"notes": chunk.data.context, "online": chunk.data.onlineContext};
+ } else if (chunk.type === 'message') {
+ const chunkData = chunk.data;
+ if (typeof chunkData === 'object' && chunkData !== null) {
+ // If chunkData is already a JSON object
+ this.handleJsonResponse(chunkData);
+ } else if (typeof chunkData === 'string' && chunkData.trim()?.startsWith("{") && chunkData.trim()?.endsWith("}")) {
+ // Try process chunk data as if it is a JSON object
+ try {
+ const jsonData = JSON.parse(chunkData.trim());
+ this.handleJsonResponse(jsonData);
+ } catch (e) {
+ this.chatMessageState.rawResponse += chunkData;
+ this.handleStreamResponse(this.chatMessageState.newResponseTextEl, this.chatMessageState.rawResponse, this.chatMessageState.loadingEllipsis);
+ }
+ } else {
+ this.chatMessageState.rawResponse += chunkData;
+ this.handleStreamResponse(this.chatMessageState.newResponseTextEl, this.chatMessageState.rawResponse, this.chatMessageState.loadingEllipsis);
+ }
+ }
+ }
+
+ handleJsonResponse(jsonData: any): void {
+ if (jsonData.image || jsonData.detail) {
+ this.chatMessageState.rawResponse = this.handleImageResponse(jsonData, this.chatMessageState.rawResponse);
+ } else if (jsonData.response) {
+ this.chatMessageState.rawResponse = jsonData.response;
+ }
+
+ if (this.chatMessageState.newResponseTextEl) {
+ this.chatMessageState.newResponseTextEl.innerHTML = "";
+ this.chatMessageState.newResponseTextEl.appendChild(this.formatHTMLMessage(this.chatMessageState.rawResponse));
+ }
+ }
+
+ async readChatStream(response: Response): Promise<void> {
// Exit if response body is empty
if (response.body == null) return;
const reader = response.body.getReader();
const decoder = new TextDecoder();
+ const eventDelimiter = '␃🔚␗';
+ let buffer = '';
while (true) {
const { value, done } = await reader.read();
if (done) {
- // Automatically respond with voice if the subscribed user has sent voice message
- if (isVoice && this.setting.userInfo?.is_active) this.textToSpeech(this.result);
+ this.processMessageChunk(buffer);
+ buffer = '';
// Break if the stream is done
break;
}
- let responseText = decoder.decode(value);
- if (responseText.includes("### compiled references:")) {
- // Render any references used to generate the response
- const [additionalResponse, rawReference] = responseText.split("### compiled references:", 2);
- await this.renderIncrementalMessage(responseElement, additionalResponse);
+ const chunk = decoder.decode(value, { stream: true });
+ console.debug("Raw Chunk:", chunk)
+ // Start buffering chunks until complete event is received
+ buffer += chunk;
- const rawReferenceAsJson = JSON.parse(rawReference);
- let references = this.extractReferences(rawReferenceAsJson);
- responseElement.appendChild(this.createReferenceSection(references));
- } else {
- // Render incremental chat response
- await this.renderIncrementalMessage(responseElement, responseText);
+ // Once the buffer contains a complete event
+ let newEventIndex;
+ while ((newEventIndex = buffer.indexOf(eventDelimiter)) !== -1) {
+ // Extract the event from the buffer
+ const event = buffer.slice(0, newEventIndex);
+ buffer = buffer.slice(newEventIndex + eventDelimiter.length);
+
+ // Process the event
+ if (event) this.processMessageChunk(event);
}
}
}
@@ -895,83 +997,59 @@ export class KhojChatView extends KhojPaneView {
let chatBodyEl = this.contentEl.getElementsByClassName("khoj-chat-body")[0] as HTMLElement;
this.renderMessage(chatBodyEl, query, "you");
- let conversationID = chatBodyEl.dataset.conversationId;
- if (!conversationID) {
+ let conversationId = chatBodyEl.dataset.conversationId;
+ if (!conversationId) {
let chatUrl = `${this.setting.khojUrl}/api/chat/sessions?client=obsidian`;
let response = await fetch(chatUrl, {
method: "POST",
headers: { "Authorization": `Bearer ${this.setting.khojApiKey}` },
});
let data = await response.json();
- conversationID = data.conversation_id;
- chatBodyEl.dataset.conversationId = conversationID;
+ conversationId = data.conversation_id;
+ chatBodyEl.dataset.conversationId = conversationId;
}
// Get chat response from Khoj backend
let encodedQuery = encodeURIComponent(query);
- let chatUrl = `${this.setting.khojUrl}/api/chat?q=${encodedQuery}&n=${this.setting.resultsCount}&client=obsidian&stream=true&region=${this.location.region}&city=${this.location.city}&country=${this.location.countryName}&timezone=${this.location.timezone}`;
- let responseElement = this.createKhojResponseDiv();
+ let chatUrl = `${this.setting.khojUrl}/api/chat?q=${encodedQuery}&conversation_id=${conversationId}&n=${this.setting.resultsCount}&stream=true&client=obsidian`;
+ if (!!this.location) chatUrl += `&region=${this.location.region}&city=${this.location.city}&country=${this.location.countryName}&timezone=${this.location.timezone}`;
+
+ let newResponseEl = this.createKhojResponseDiv();
+ let newResponseTextEl = newResponseEl.createDiv();
+ newResponseTextEl.classList.add("khoj-chat-message-text", "khoj");
// Temporary status message to indicate that Khoj is thinking
- this.result = "";
let loadingEllipsis = this.createLoadingEllipse();
- responseElement.appendChild(loadingEllipsis);
+ newResponseTextEl.appendChild(loadingEllipsis);
+
+ // Set chat message state
+ this.chatMessageState = {
+ newResponseEl: newResponseEl,
+ newResponseTextEl: newResponseTextEl,
+ loadingEllipsis: loadingEllipsis,
+ references: {},
+ rawQuery: query,
+ rawResponse: "",
+ isVoice: isVoice,
+ };
let response = await fetch(chatUrl, {
method: "GET",
headers: {
- "Content-Type": "text/event-stream",
+ "Content-Type": "text/plain",
"Authorization": `Bearer ${this.setting.khojApiKey}`,
},
})
try {
- if (response.body === null) {
- throw new Error("Response body is null");
- }
+ if (response.body === null) throw new Error("Response body is null");
- // Clear loading status message
- if (responseElement.getElementsByClassName("lds-ellipsis").length > 0 && loadingEllipsis) {
- responseElement.removeChild(loadingEllipsis);
- }
-
- // Reset collated chat result to empty string
- this.result = "";
- responseElement.innerHTML = "";
- if (response.headers.get("content-type") === "application/json") {
- let responseText = ""
- try {
- const responseAsJson = await response.json() as ChatJsonResult;
- if (responseAsJson.image) {
- // If response has image field, response is a generated image.
- if (responseAsJson.intentType === "text-to-image") {
- responseText += `![${query}](data:image/png;base64,${responseAsJson.image})`;
- } else if (responseAsJson.intentType === "text-to-image2") {
- responseText += `![${query}](${responseAsJson.image})`;
- } else if (responseAsJson.intentType === "text-to-image-v3") {
- responseText += `![${query}](data:image/webp;base64,${responseAsJson.image})`;
- }
- const inferredQuery = responseAsJson.inferredQueries?.[0];
- if (inferredQuery) {
- responseText += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
- }
- } else if (responseAsJson.detail) {
- responseText = responseAsJson.detail;
- }
- } catch (error) {
- // If the chunk is not a JSON object, just display it as is
- responseText = await response.text();
- } finally {
- await this.renderIncrementalMessage(responseElement, responseText);
- }
- } else {
- // Stream and render chat response
- await this.readChatStream(response, responseElement, isVoice);
- }
+ // Stream and render chat response
+ await this.readChatStream(response);
} catch (err) {
- console.log(`Khoj chat response failed with\n${err}`);
+ console.error(`Khoj chat response failed with\n${err}`);
let errorMsg = "Sorry, unable to get response from Khoj backend ❤️🩹. Retry or contact developers for help at team@khoj.dev or on Discord";
- responseElement.innerHTML = errorMsg
+ newResponseTextEl.textContent = errorMsg;
}
}
@@ -1196,30 +1274,21 @@ export class KhojChatView extends KhojPaneView {
handleStreamResponse(newResponseElement: HTMLElement | null, rawResponse: string, loadingEllipsis: HTMLElement | null, replace = true) {
if (!newResponseElement) return;
- if (newResponseElement.getElementsByClassName("lds-ellipsis").length > 0 && loadingEllipsis) {
+ // Remove loading ellipsis if it exists
+ if (newResponseElement.getElementsByClassName("lds-ellipsis").length > 0 && loadingEllipsis)
newResponseElement.removeChild(loadingEllipsis);
- }
- if (replace) {
- newResponseElement.innerHTML = "";
- }
+ // Clear the response element if replace is true
+ if (replace) newResponseElement.innerHTML = "";
+
+ // Append response to the response element
newResponseElement.appendChild(this.formatHTMLMessage(rawResponse, false, replace));
+
+ // Append loading ellipsis if it exists
+ if (!replace && loadingEllipsis) newResponseElement.appendChild(loadingEllipsis);
+ // Scroll to bottom of chat view
this.scrollChatToBottom();
}
- handleCompiledReferences(rawResponseElement: HTMLElement | null, chunk: string, references: any, rawResponse: string) {
- if (!rawResponseElement || !chunk) return { rawResponse, references };
-
- const [additionalResponse, rawReference] = chunk.split("### compiled references:", 2);
- rawResponse += additionalResponse;
- rawResponseElement.innerHTML = "";
- rawResponseElement.appendChild(this.formatHTMLMessage(rawResponse));
-
- const rawReferenceAsJson = JSON.parse(rawReference);
- references = this.extractReferences(rawReferenceAsJson);
-
- return { rawResponse, references };
- }
-
handleImageResponse(imageJson: any, rawResponse: string) {
if (imageJson.image) {
const inferredQuery = imageJson.inferredQueries?.[0] ?? "generated image";
@@ -1236,33 +1305,10 @@ export class KhojChatView extends KhojPaneView {
rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
}
}
- let references = {};
- if (imageJson.context && imageJson.context.length > 0) {
- references = this.extractReferences(imageJson.context);
- }
- if (imageJson.detail) {
- // If response has detail field, response is an error message.
- rawResponse += imageJson.detail;
- }
- return { rawResponse, references };
- }
+ // If response has detail field, response is an error message.
+ if (imageJson.detail) rawResponse += imageJson.detail;
- extractReferences(rawReferenceAsJson: any): object {
- let references: any = {};
- if (rawReferenceAsJson instanceof Array) {
- references["notes"] = rawReferenceAsJson;
- } else if (typeof rawReferenceAsJson === "object" && rawReferenceAsJson !== null) {
- references["online"] = rawReferenceAsJson;
- }
- return references;
- }
-
- addMessageToChatBody(rawResponse: string, newResponseElement: HTMLElement | null, references: any) {
- if (!newResponseElement) return;
- newResponseElement.innerHTML = "";
- newResponseElement.appendChild(this.formatHTMLMessage(rawResponse));
-
- this.finalizeChatBodyResponse(references, newResponseElement);
+ return rawResponse;
}
finalizeChatBodyResponse(references: object, newResponseElement: HTMLElement | null) {
diff --git a/src/interface/obsidian/styles.css b/src/interface/obsidian/styles.css
index afd8fd19..42c1b3ce 100644
--- a/src/interface/obsidian/styles.css
+++ b/src/interface/obsidian/styles.css
@@ -85,6 +85,12 @@ If your plugin does not need CSS, delete this file.
margin-left: auto;
white-space: pre-line;
}
+/* Override white-space for ul, ol, li under khoj-chat-message-text.khoj */
+.khoj-chat-message-text.khoj ul,
+.khoj-chat-message-text.khoj ol,
+.khoj-chat-message-text.khoj li {
+ white-space: normal;
+}
/* add left protrusion to khoj chat bubble */
.khoj-chat-message-text.khoj:after {
content: '';
diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py
index a0a13e8c..0e50169e 100644
--- a/src/khoj/database/adapters/__init__.py
+++ b/src/khoj/database/adapters/__init__.py
@@ -680,34 +680,18 @@ class ConversationAdapters:
async def aget_conversation_by_user(
user: KhojUser, client_application: ClientApplication = None, conversation_id: int = None, title: str = None
) -> Optional[Conversation]:
+ query = Conversation.objects.filter(user=user, client=client_application).prefetch_related("agent")
+
if conversation_id:
- return (
- await Conversation.objects.filter(user=user, client=client_application, id=conversation_id)
- .prefetch_related("agent")
- .afirst()
- )
+ return await query.filter(id=conversation_id).afirst()
elif title:
- return (
- await Conversation.objects.filter(user=user, client=client_application, title=title)
- .prefetch_related("agent")
- .afirst()
- )
- else:
- conversation = (
- Conversation.objects.filter(user=user, client=client_application)
- .prefetch_related("agent")
- .order_by("-updated_at")
- )
+ return await query.filter(title=title).afirst()
- if await conversation.aexists():
- return await conversation.prefetch_related("agent").afirst()
+ conversation = await query.order_by("-updated_at").afirst()
- return await (
- Conversation.objects.filter(user=user, client=client_application)
- .prefetch_related("agent")
- .order_by("-updated_at")
- .afirst()
- ) or await Conversation.objects.prefetch_related("agent").acreate(user=user, client=client_application)
+ return conversation or await Conversation.objects.prefetch_related("agent").acreate(
+ user=user, client=client_application
+ )
@staticmethod
async def adelete_conversation_by_user(
diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html
index ad8ced27..024af9ad 100644
--- a/src/khoj/interface/web/chat.html
+++ b/src/khoj/interface/web/chat.html
@@ -74,14 +74,13 @@ To get started, just start typing below. You can also type / to see a list of co
}, 1000);
});
}
- var websocket = null;
+
let region = null;
let city = null;
let countryName = null;
let timezone = null;
let waitingForLocation = true;
-
- let websocketState = {
+ let chatMessageState = {
newResponseTextEl: null,
newResponseEl: null,
loadingEllipsis: null,
@@ -105,7 +104,7 @@ To get started, just start typing below. You can also type / to see a list of co
.finally(() => {
console.debug("Region:", region, "City:", city, "Country:", countryName, "Timezone:", timezone);
waitingForLocation = false;
- setupWebSocket();
+ initMessageState();
});
function formatDate(date) {
@@ -599,13 +598,8 @@ To get started, just start typing below. You can also type / to see a list of co
}
async function chat(isVoice=false) {
- if (websocket) {
- sendMessageViaWebSocket(isVoice);
- return;
- }
-
- let query = document.getElementById("chat-input").value.trim();
- let resultsCount = localStorage.getItem("khojResultsCount") || 5;
+ // Extract chat message from chat input form
+ var query = document.getElementById("chat-input").value.trim();
console.log(`Query: ${query}`);
// Short circuit on empty query
@@ -624,31 +618,30 @@ To get started, just start typing below. You can also type / to see a list of co
document.getElementById("chat-input").value = "";
autoResize();
document.getElementById("chat-input").setAttribute("disabled", "disabled");
- let chat_body = document.getElementById("chat-body");
-
- let conversationID = chat_body.dataset.conversationId;
+ let chatBody = document.getElementById("chat-body");
+ let conversationID = chatBody.dataset.conversationId;
if (!conversationID) {
- let response = await fetch('/api/chat/sessions', { method: "POST" });
+ let response = await fetch(`${hostURL}/api/chat/sessions`, { method: "POST" });
let data = await response.json();
conversationID = data.conversation_id;
- chat_body.dataset.conversationId = conversationID;
- refreshChatSessionsPanel();
+ chatBody.dataset.conversationId = conversationID;
+ await refreshChatSessionsPanel();
}
- let new_response = document.createElement("div");
- new_response.classList.add("chat-message", "khoj");
- new_response.attributes["data-meta"] = "🏮 Khoj at " + formatDate(new Date());
- chat_body.appendChild(new_response);
+ let newResponseEl = document.createElement("div");
+ newResponseEl.classList.add("chat-message", "khoj");
+ newResponseEl.attributes["data-meta"] = "🏮 Khoj at " + formatDate(new Date());
+ chatBody.appendChild(newResponseEl);
- let newResponseText = document.createElement("div");
- newResponseText.classList.add("chat-message-text", "khoj");
- new_response.appendChild(newResponseText);
+ let newResponseTextEl = document.createElement("div");
+ newResponseTextEl.classList.add("chat-message-text", "khoj");
+ newResponseEl.appendChild(newResponseTextEl);
// Temporary status message to indicate that Khoj is thinking
let loadingEllipsis = createLoadingEllipse();
- newResponseText.appendChild(loadingEllipsis);
+ newResponseTextEl.appendChild(loadingEllipsis);
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
let chatTooltip = document.getElementById("chat-tooltip");
@@ -657,65 +650,38 @@ To get started, just start typing below. You can also type / to see a list of co
let chatInput = document.getElementById("chat-input");
chatInput.classList.remove("option-enabled");
- // Generate backend API URL to execute query
- let url = `/api/chat?q=${encodeURIComponent(query)}&n=${resultsCount}&client=web&stream=true&conversation_id=${conversationID}&region=${region}&city=${city}&country=${countryName}&timezone=${timezone}`;
-
- // Call specified Khoj API
- let response = await fetch(url);
- let rawResponse = "";
- let references = null;
- const contentType = response.headers.get("content-type");
-
- if (contentType === "application/json") {
- // Handle JSON response
- try {
- const responseAsJson = await response.json();
- if (responseAsJson.image || responseAsJson.detail) {
- ({rawResponse, references } = handleImageResponse(responseAsJson, rawResponse));
- } else {
- rawResponse = responseAsJson.response;
- }
- } catch (error) {
- // If the chunk is not a JSON object, just display it as is
- rawResponse += chunk;
- } finally {
- addMessageToChatBody(rawResponse, newResponseText, references);
- }
- } else {
- // Handle streamed response of type text/event-stream or text/plain
- const reader = response.body.getReader();
- const decoder = new TextDecoder();
- let references = {};
-
- readStream();
-
- function readStream() {
- reader.read().then(({ done, value }) => {
- if (done) {
- // Append any references after all the data has been streamed
- finalizeChatBodyResponse(references, newResponseText);
- return;
- }
-
- // Decode message chunk from stream
- const chunk = decoder.decode(value, { stream: true });
-
- if (chunk.includes("### compiled references:")) {
- ({ rawResponse, references } = handleCompiledReferences(newResponseText, chunk, references, rawResponse));
- readStream();
- } else {
- // If the chunk is not a JSON object, just display it as is
- rawResponse += chunk;
- handleStreamResponse(newResponseText, rawResponse, query, loadingEllipsis);
- readStream();
- }
- });
-
- // Scroll to bottom of chat window as chat response is streamed
- document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
- };
+ // Setup chat message state
+ chatMessageState = {
+ newResponseTextEl,
+ newResponseEl,
+ loadingEllipsis,
+ references: {},
+ rawResponse: "",
+ rawQuery: query,
+ isVoice: isVoice,
}
- };
+
+ // Call Khoj chat API
+ let chatApi = `/api/chat?q=${encodeURIComponent(query)}&conversation_id=${conversationID}&stream=true&client=web`;
+ chatApi += (!!region && !!city && !!countryName && !!timezone)
+ ? `&region=${region}&city=${city}&country=${countryName}&timezone=${timezone}`
+ : '';
+
+ const response = await fetch(chatApi);
+
+ try {
+ if (!response.ok) throw new Error(response.statusText);
+ if (!response.body) throw new Error("Response body is empty");
+ // Stream and render chat response
+ await readChatStream(response);
+ } catch (err) {
+ console.error(`Khoj chat response failed with\n${err}`);
+ if (chatMessageState.newResponseEl.getElementsByClassName("lds-ellipsis").length > 0 && chatMessageState.loadingEllipsis)
+ chatMessageState.newResponseTextEl.removeChild(chatMessageState.loadingEllipsis);
+ let errorMsg = "Sorry, unable to get response from Khoj backend ❤️🩹. Retry or contact developers for help at team@khoj.dev or on Discord";
+ newResponseTextEl.innerHTML = errorMsg;
+ }
+ }
function createLoadingEllipse() {
// Temporary status message to indicate that Khoj is thinking
@@ -743,32 +709,22 @@ To get started, just start typing below. You can also type / to see a list of co
}
function handleStreamResponse(newResponseElement, rawResponse, rawQuery, loadingEllipsis, replace=true) {
- if (newResponseElement.getElementsByClassName("lds-ellipsis").length > 0 && loadingEllipsis) {
+ if (!newResponseElement) return;
+ // Remove loading ellipsis if it exists
+ if (newResponseElement.getElementsByClassName("lds-ellipsis").length > 0 && loadingEllipsis)
newResponseElement.removeChild(loadingEllipsis);
- }
- if (replace) {
- newResponseElement.innerHTML = "";
- }
+ // Clear the response element if replace is true
+ if (replace) newResponseElement.innerHTML = "";
+
+ // Append response to the response element
newResponseElement.appendChild(formatHTMLMessage(rawResponse, false, replace, rawQuery));
+
+ // Append loading ellipsis if it exists
+ if (!replace && loadingEllipsis) newResponseElement.appendChild(loadingEllipsis);
+ // Scroll to bottom of chat view
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
}
- function handleCompiledReferences(rawResponseElement, chunk, references, rawResponse) {
- const additionalResponse = chunk.split("### compiled references:")[0];
- rawResponse += additionalResponse;
- rawResponseElement.innerHTML = "";
- rawResponseElement.appendChild(formatHTMLMessage(rawResponse));
-
- const rawReference = chunk.split("### compiled references:")[1];
- const rawReferenceAsJson = JSON.parse(rawReference);
- if (rawReferenceAsJson instanceof Array) {
- references["notes"] = rawReferenceAsJson;
- } else if (typeof rawReferenceAsJson === "object" && rawReferenceAsJson !== null) {
- references["online"] = rawReferenceAsJson;
- }
- return { rawResponse, references };
- }
-
function handleImageResponse(imageJson, rawResponse) {
if (imageJson.image) {
const inferredQuery = imageJson.inferredQueries?.[0] ?? "generated image";
@@ -785,35 +741,139 @@ To get started, just start typing below. You can also type / to see a list of co
rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
}
}
- let references = {};
- if (imageJson.context && imageJson.context.length > 0) {
- const rawReferenceAsJson = imageJson.context;
- if (rawReferenceAsJson instanceof Array) {
- references["notes"] = rawReferenceAsJson;
- } else if (typeof rawReferenceAsJson === "object" && rawReferenceAsJson !== null) {
- references["online"] = rawReferenceAsJson;
- }
- }
- if (imageJson.detail) {
- // If response has detail field, response is an error message.
- rawResponse += imageJson.detail;
- }
- return { rawResponse, references };
- }
- function addMessageToChatBody(rawResponse, newResponseElement, references) {
- newResponseElement.innerHTML = "";
- newResponseElement.appendChild(formatHTMLMessage(rawResponse));
+ // If response has detail field, response is an error message.
+ if (imageJson.detail) rawResponse += imageJson.detail;
- finalizeChatBodyResponse(references, newResponseElement);
+ return rawResponse;
}
function finalizeChatBodyResponse(references, newResponseElement) {
- if (references != null && Object.keys(references).length > 0) {
+ if (!!newResponseElement && references != null && Object.keys(references).length > 0) {
newResponseElement.appendChild(createReferenceSection(references));
}
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
- document.getElementById("chat-input").removeAttribute("disabled");
+ document.getElementById("chat-input")?.removeAttribute("disabled");
+ }
+
+ function convertMessageChunkToJson(rawChunk) {
+ // Split the chunk into lines
+ console.debug("Raw Event:", rawChunk);
+ if (rawChunk?.startsWith("{") && rawChunk?.endsWith("}")) {
+ try {
+ let jsonChunk = JSON.parse(rawChunk);
+ if (!jsonChunk.type)
+ jsonChunk = {type: 'message', data: jsonChunk};
+ return jsonChunk;
+ } catch (e) {
+ return {type: 'message', data: rawChunk};
+ }
+ } else if (rawChunk.length > 0) {
+ return {type: 'message', data: rawChunk};
+ }
+ }
+
+ function processMessageChunk(rawChunk) {
+ const chunk = convertMessageChunkToJson(rawChunk);
+ console.debug("Json Event:", chunk);
+ if (!chunk || !chunk.type) return;
+ if (chunk.type ==='status') {
+ console.log(`status: ${chunk.data}`);
+ const statusMessage = chunk.data;
+ handleStreamResponse(chatMessageState.newResponseTextEl, statusMessage, chatMessageState.rawQuery, chatMessageState.loadingEllipsis, false);
+ } else if (chunk.type === 'start_llm_response') {
+ console.log("Started streaming", new Date());
+ } else if (chunk.type === 'end_llm_response') {
+ console.log("Stopped streaming", new Date());
+
+ // Automatically respond with voice if the subscribed user has sent voice message
+ if (chatMessageState.isVoice && "{{ is_active }}" == "True")
+ textToSpeech(chatMessageState.rawResponse);
+
+ // Append any references after all the data has been streamed
+ finalizeChatBodyResponse(chatMessageState.references, chatMessageState.newResponseTextEl);
+
+ const liveQuery = chatMessageState.rawQuery;
+ // Reset variables
+ chatMessageState = {
+ newResponseTextEl: null,
+ newResponseEl: null,
+ loadingEllipsis: null,
+ references: {},
+ rawResponse: "",
+ rawQuery: liveQuery,
+ isVoice: false,
+ }
+ } else if (chunk.type === "references") {
+ chatMessageState.references = {"notes": chunk.data.context, "online": chunk.data.onlineContext};
+ } else if (chunk.type === 'message') {
+ const chunkData = chunk.data;
+ if (typeof chunkData === 'object' && chunkData !== null) {
+ // If chunkData is already a JSON object
+ handleJsonResponse(chunkData);
+ } else if (typeof chunkData === 'string' && chunkData.trim()?.startsWith("{") && chunkData.trim()?.endsWith("}")) {
+ // Try process chunk data as if it is a JSON object
+ try {
+ const jsonData = JSON.parse(chunkData.trim());
+ handleJsonResponse(jsonData);
+ } catch (e) {
+ chatMessageState.rawResponse += chunkData;
+ handleStreamResponse(chatMessageState.newResponseTextEl, chatMessageState.rawResponse, chatMessageState.rawQuery, chatMessageState.loadingEllipsis);
+ }
+ } else {
+ chatMessageState.rawResponse += chunkData;
+ handleStreamResponse(chatMessageState.newResponseTextEl, chatMessageState.rawResponse, chatMessageState.rawQuery, chatMessageState.loadingEllipsis);
+ }
+ }
+ }
+
+ function handleJsonResponse(jsonData) {
+ if (jsonData.image || jsonData.detail) {
+ chatMessageState.rawResponse = handleImageResponse(jsonData, chatMessageState.rawResponse);
+ } else if (jsonData.response) {
+ chatMessageState.rawResponse = jsonData.response;
+ }
+
+ if (chatMessageState.newResponseTextEl) {
+ chatMessageState.newResponseTextEl.innerHTML = "";
+ chatMessageState.newResponseTextEl.appendChild(formatHTMLMessage(chatMessageState.rawResponse));
+ }
+ }
+
+ async function readChatStream(response) {
+ if (!response.body) return;
+ const reader = response.body.getReader();
+ const decoder = new TextDecoder();
+ const eventDelimiter = '␃🔚␗';
+ let buffer = '';
+
+ while (true) {
+ const { value, done } = await reader.read();
+ // If the stream is done
+ if (done) {
+ // Process the last chunk
+ processMessageChunk(buffer);
+ buffer = '';
+ break;
+ }
+
+ // Read chunk from stream and append it to the buffer
+ const chunk = decoder.decode(value, { stream: true });
+ console.debug("Raw Chunk:", chunk)
+ // Start buffering chunks until complete event is received
+ buffer += chunk;
+
+ // Once the buffer contains a complete event
+ let newEventIndex;
+ while ((newEventIndex = buffer.indexOf(eventDelimiter)) !== -1) {
+ // Extract the event from the buffer
+ const event = buffer.slice(0, newEventIndex);
+ buffer = buffer.slice(newEventIndex + eventDelimiter.length);
+
+ // Process the event
+ if (event) processMessageChunk(event);
+ }
+ }
}
function incrementalChat(event) {
@@ -1069,17 +1129,13 @@ To get started, just start typing below. You can also type / to see a list of co
window.onload = loadChat;
- function setupWebSocket(isVoice=false) {
- let chatBody = document.getElementById("chat-body");
- let wsProtocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
- let webSocketUrl = `${wsProtocol}//${window.location.host}/api/chat/ws`;
-
+ function initMessageState(isVoice=false) {
if (waitingForLocation) {
console.debug("Waiting for location data to be fetched. Will setup WebSocket once location data is available.");
return;
}
- websocketState = {
+ chatMessageState = {
newResponseTextEl: null,
newResponseEl: null,
loadingEllipsis: null,
@@ -1088,174 +1144,8 @@ To get started, just start typing below. You can also type / to see a list of co
rawQuery: "",
isVoice: isVoice,
}
-
- if (chatBody.dataset.conversationId) {
- webSocketUrl += `?conversation_id=${chatBody.dataset.conversationId}`;
- webSocketUrl += (!!region && !!city && !!countryName) && !!timezone ? `&region=${region}&city=${city}&country=${countryName}&timezone=${timezone}` : '';
-
- websocket = new WebSocket(webSocketUrl);
- websocket.onmessage = function(event) {
-
- // Get the last element in the chat-body
- let chunk = event.data;
- if (chunk == "start_llm_response") {
- console.log("Started streaming", new Date());
- } else if (chunk == "end_llm_response") {
- console.log("Stopped streaming", new Date());
-
- // Automatically respond with voice if the subscribed user has sent voice message
- if (websocketState.isVoice && "{{ is_active }}" == "True")
- textToSpeech(websocketState.rawResponse);
-
- // Append any references after all the data has been streamed
- finalizeChatBodyResponse(websocketState.references, websocketState.newResponseTextEl);
-
- const liveQuery = websocketState.rawQuery;
- // Reset variables
- websocketState = {
- newResponseTextEl: null,
- newResponseEl: null,
- loadingEllipsis: null,
- references: {},
- rawResponse: "",
- rawQuery: liveQuery,
- isVoice: false,
- }
- } else {
- try {
- if (chunk.includes("application/json"))
- {
- chunk = JSON.parse(chunk);
- }
- } catch (error) {
- // If the chunk is not a JSON object, continue.
- }
-
- const contentType = chunk["content-type"]
-
- if (contentType === "application/json") {
- // Handle JSON response
- try {
- if (chunk.image || chunk.detail) {
- ({rawResponse, references } = handleImageResponse(chunk, websocketState.rawResponse));
- websocketState.rawResponse = rawResponse;
- websocketState.references = references;
- } else if (chunk.type == "status") {
- handleStreamResponse(websocketState.newResponseTextEl, chunk.message, websocketState.rawQuery, null, false);
- } else if (chunk.type == "rate_limit") {
- handleStreamResponse(websocketState.newResponseTextEl, chunk.message, websocketState.rawQuery, websocketState.loadingEllipsis, true);
- } else {
- rawResponse = chunk.response;
- }
- } catch (error) {
- // If the chunk is not a JSON object, just display it as is
- websocketState.rawResponse += chunk;
- } finally {
- if (chunk.type != "status" && chunk.type != "rate_limit") {
- addMessageToChatBody(websocketState.rawResponse, websocketState.newResponseTextEl, websocketState.references);
- }
- }
- } else {
-
- // Handle streamed response of type text/event-stream or text/plain
- if (chunk && chunk.includes("### compiled references:")) {
- ({ rawResponse, references } = handleCompiledReferences(websocketState.newResponseTextEl, chunk, websocketState.references, websocketState.rawResponse));
- websocketState.rawResponse = rawResponse;
- websocketState.references = references;
- } else {
- // If the chunk is not a JSON object, just display it as is
- websocketState.rawResponse += chunk;
- if (websocketState.newResponseTextEl) {
- handleStreamResponse(websocketState.newResponseTextEl, websocketState.rawResponse, websocketState.rawQuery, websocketState.loadingEllipsis);
- }
- }
-
- // Scroll to bottom of chat window as chat response is streamed
- document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
- };
- }
- }
- };
- websocket.onclose = function(event) {
- websocket = null;
- console.log("WebSocket is closed now.");
- let setupWebSocketButton = document.createElement("button");
- setupWebSocketButton.textContent = "Reconnect to Server";
- setupWebSocketButton.onclick = setupWebSocket;
- let statusDotIcon = document.getElementById("connection-status-icon");
- statusDotIcon.style.backgroundColor = "red";
- let statusDotText = document.getElementById("connection-status-text");
- statusDotText.innerHTML = "";
- statusDotText.style.marginTop = "5px";
- statusDotText.appendChild(setupWebSocketButton);
- }
- websocket.onerror = function(event) {
- console.log("WebSocket error observed:", event);
- }
-
- websocket.onopen = function(event) {
- console.log("WebSocket is open now.")
- let statusDotIcon = document.getElementById("connection-status-icon");
- statusDotIcon.style.backgroundColor = "green";
- let statusDotText = document.getElementById("connection-status-text");
- statusDotText.textContent = "Connected to Server";
- }
}
- function sendMessageViaWebSocket(isVoice=false) {
- let chatBody = document.getElementById("chat-body");
-
- var query = document.getElementById("chat-input").value.trim();
- console.log(`Query: ${query}`);
-
- if (userMessages.length >= 10) {
- userMessages.shift();
- }
- userMessages.push(query);
- resetUserMessageIndex();
-
- // Add message by user to chat body
- renderMessage(query, "you");
- document.getElementById("chat-input").value = "";
- autoResize();
- document.getElementById("chat-input").setAttribute("disabled", "disabled");
-
- let newResponseEl = document.createElement("div");
- newResponseEl.classList.add("chat-message", "khoj");
- newResponseEl.attributes["data-meta"] = "🏮 Khoj at " + formatDate(new Date());
- chatBody.appendChild(newResponseEl);
-
- let newResponseTextEl = document.createElement("div");
- newResponseTextEl.classList.add("chat-message-text", "khoj");
- newResponseEl.appendChild(newResponseTextEl);
-
- // Temporary status message to indicate that Khoj is thinking
- let loadingEllipsis = createLoadingEllipse();
-
- newResponseTextEl.appendChild(loadingEllipsis);
- document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
-
- let chatTooltip = document.getElementById("chat-tooltip");
- chatTooltip.style.display = "none";
-
- let chatInput = document.getElementById("chat-input");
- chatInput.classList.remove("option-enabled");
-
- // Call specified Khoj API
- websocket.send(query);
- let rawResponse = "";
- let references = {};
-
- websocketState = {
- newResponseTextEl,
- newResponseEl,
- loadingEllipsis,
- references,
- rawResponse,
- rawQuery: query,
- isVoice: isVoice,
- }
- }
var userMessages = [];
var userMessageIndex = -1;
function loadChat() {
@@ -1265,7 +1155,7 @@ To get started, just start typing below. You can also type / to see a list of co
let chatHistoryUrl = `/api/chat/history?client=web`;
if (chatBody.dataset.conversationId) {
chatHistoryUrl += `&conversation_id=${chatBody.dataset.conversationId}`;
- setupWebSocket();
+ initMessageState();
loadFileFiltersFromConversation();
}
@@ -1305,7 +1195,7 @@ To get started, just start typing below. You can also type / to see a list of co
let chatBody = document.getElementById("chat-body");
chatBody.dataset.conversationId = response.conversation_id;
loadFileFiltersFromConversation();
- setupWebSocket();
+ initMessageState();
chatBody.dataset.conversationTitle = response.slug || `New conversation 🌱`;
let agentMetadata = response.agent;
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index 797066d7..ea7368e6 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -62,10 +62,6 @@ class ThreadedGenerator:
self.queue.put(data)
def close(self):
- if self.compiled_references and len(self.compiled_references) > 0:
- self.queue.put(f"### compiled references:{json.dumps(self.compiled_references)}")
- if self.online_results and len(self.online_results) > 0:
- self.queue.put(f"### compiled references:{json.dumps(self.online_results)}")
self.queue.put(StopIteration)
diff --git a/src/khoj/processor/tools/online_search.py b/src/khoj/processor/tools/online_search.py
index 72191077..c087de70 100644
--- a/src/khoj/processor/tools/online_search.py
+++ b/src/khoj/processor/tools/online_search.py
@@ -11,6 +11,7 @@ from bs4 import BeautifulSoup
from markdownify import markdownify
from khoj.routers.helpers import (
+ ChatEvent,
extract_relevant_info,
generate_online_subqueries,
infer_webpage_urls,
@@ -56,7 +57,8 @@ async def search_online(
query += " ".join(custom_filters)
if not is_internet_connected():
logger.warn("Cannot search online as not connected to internet")
- return {}
+ yield {}
+ return
# Breakdown the query into subqueries to get the correct answer
subqueries = await generate_online_subqueries(query, conversation_history, location)
@@ -66,7 +68,8 @@ async def search_online(
logger.info(f"🌐 Searching the Internet for {list(subqueries)}")
if send_status_func:
subqueries_str = "\n- " + "\n- ".join(list(subqueries))
- await send_status_func(f"**🌐 Searching the Internet for**: {subqueries_str}")
+ async for event in send_status_func(f"**🌐 Searching the Internet for**: {subqueries_str}"):
+ yield {ChatEvent.STATUS: event}
with timer(f"Internet searches for {list(subqueries)} took", logger):
search_func = search_with_google if SERPER_DEV_API_KEY else search_with_jina
@@ -89,7 +92,8 @@ async def search_online(
logger.info(f"🌐👀 Reading web pages at: {list(webpage_links)}")
if send_status_func:
webpage_links_str = "\n- " + "\n- ".join(list(webpage_links))
- await send_status_func(f"**📖 Reading web pages**: {webpage_links_str}")
+ async for event in send_status_func(f"**📖 Reading web pages**: {webpage_links_str}"):
+ yield {ChatEvent.STATUS: event}
tasks = [read_webpage_and_extract_content(subquery, link, content) for link, subquery, content in webpages]
results = await asyncio.gather(*tasks)
@@ -98,7 +102,7 @@ async def search_online(
if webpage_extract is not None:
response_dict[subquery]["webpages"] = {"link": url, "snippet": webpage_extract}
- return response_dict
+ yield response_dict
async def search_with_google(query: str) -> Tuple[str, Dict[str, List[Dict]]]:
@@ -127,13 +131,15 @@ async def read_webpages(
"Infer web pages to read from the query and extract relevant information from them"
logger.info(f"Inferring web pages to read")
if send_status_func:
- await send_status_func(f"**🧐 Inferring web pages to read**")
+ async for event in send_status_func(f"**🧐 Inferring web pages to read**"):
+ yield {ChatEvent.STATUS: event}
urls = await infer_webpage_urls(query, conversation_history, location)
logger.info(f"Reading web pages at: {urls}")
if send_status_func:
webpage_links_str = "\n- " + "\n- ".join(list(urls))
- await send_status_func(f"**📖 Reading web pages**: {webpage_links_str}")
+ async for event in send_status_func(f"**📖 Reading web pages**: {webpage_links_str}"):
+ yield {ChatEvent.STATUS: event}
tasks = [read_webpage_and_extract_content(query, url) for url in urls]
results = await asyncio.gather(*tasks)
@@ -141,7 +147,7 @@ async def read_webpages(
response[query]["webpages"] = [
{"query": q, "link": url, "snippet": web_extract} for q, web_extract, url in results if web_extract is not None
]
- return response
+ yield response
async def read_webpage_and_extract_content(
diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index c9d76ae7..15d7cbc7 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -6,7 +6,6 @@ import os
import threading
import time
import uuid
-from random import random
from typing import Any, Callable, List, Optional, Union
import cron_descriptor
@@ -37,6 +36,7 @@ from khoj.processor.conversation.openai.gpt import extract_questions
from khoj.processor.conversation.openai.whisper import transcribe_audio
from khoj.routers.helpers import (
ApiUserRateLimiter,
+ ChatEvent,
CommonQueryParams,
ConversationCommandRateLimiter,
acreate_title_from_query,
@@ -298,11 +298,13 @@ async def extract_references_and_questions(
not ConversationCommand.Notes in conversation_commands
and not ConversationCommand.Default in conversation_commands
):
- return compiled_references, inferred_queries, q
+ yield compiled_references, inferred_queries, q
+ return
if not await sync_to_async(EntryAdapters.user_has_entries)(user=user):
logger.debug("No documents in knowledge base. Use a Khoj client to sync and chat with your docs.")
- return compiled_references, inferred_queries, q
+ yield compiled_references, inferred_queries, q
+ return
# Extract filter terms from user message
defiltered_query = q
@@ -313,7 +315,8 @@ async def extract_references_and_questions(
if not conversation:
logger.error(f"Conversation with id {conversation_id} not found.")
- return compiled_references, inferred_queries, defiltered_query
+ yield compiled_references, inferred_queries, defiltered_query
+ return
filters_in_query += " ".join([f'file:"{filter}"' for filter in conversation.file_filters])
using_offline_chat = False
@@ -373,7 +376,8 @@ async def extract_references_and_questions(
logger.info(f"🔍 Searching knowledge base with queries: {inferred_queries}")
if send_status_func:
inferred_queries_str = "\n- " + "\n- ".join(inferred_queries)
- await send_status_func(f"**🔍 Searching Documents for:** {inferred_queries_str}")
+ async for event in send_status_func(f"**🔍 Searching Documents for:** {inferred_queries_str}"):
+ yield {ChatEvent.STATUS: event}
for query in inferred_queries:
n_items = min(n, 3) if using_offline_chat else n
search_results.extend(
@@ -392,7 +396,7 @@ async def extract_references_and_questions(
{"compiled": item.additional["compiled"], "file": item.additional["file"]} for item in search_results
]
- return compiled_references, inferred_queries, defiltered_query
+ yield compiled_references, inferred_queries, defiltered_query
@api.get("/health", response_class=Response)
diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py
index be28622b..63529b8e 100644
--- a/src/khoj/routers/api_chat.py
+++ b/src/khoj/routers/api_chat.py
@@ -1,17 +1,17 @@
+import asyncio
import json
import logging
-import math
+import time
from datetime import datetime
+from functools import partial
from typing import Any, Dict, List, Optional
from urllib.parse import unquote
from asgiref.sync import sync_to_async
-from fastapi import APIRouter, Depends, HTTPException, Request, WebSocket
+from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.requests import Request
from fastapi.responses import Response, StreamingResponse
from starlette.authentication import requires
-from starlette.websockets import WebSocketDisconnect
-from websockets import ConnectionClosedOK
from khoj.app.settings import ALLOWED_HOSTS
from khoj.database.adapters import (
@@ -23,19 +23,15 @@ from khoj.database.adapters import (
aget_user_name,
)
from khoj.database.models import KhojUser
-from khoj.processor.conversation.prompts import (
- help_message,
- no_entries_found,
- no_notes_found,
-)
+from khoj.processor.conversation.prompts import help_message, no_entries_found
from khoj.processor.conversation.utils import save_to_conversation_log
from khoj.processor.speech.text_to_speech import generate_text_to_speech
from khoj.processor.tools.online_search import read_webpages, search_online
from khoj.routers.api import extract_references_and_questions
from khoj.routers.helpers import (
ApiUserRateLimiter,
+ ChatEvent,
CommonQueryParams,
- CommonQueryParamsClass,
ConversationCommandRateLimiter,
agenerate_chat_response,
aget_relevant_information_sources,
@@ -526,141 +522,142 @@ async def set_conversation_title(
)
-@api_chat.websocket("/ws")
-async def websocket_endpoint(
- websocket: WebSocket,
- conversation_id: int,
+@api_chat.get("")
+async def chat(
+ request: Request,
+ common: CommonQueryParams,
+ q: str,
+ n: int = 7,
+ d: float = 0.18,
+ stream: Optional[bool] = False,
+ title: Optional[str] = None,
+ conversation_id: Optional[int] = None,
city: Optional[str] = None,
region: Optional[str] = None,
country: Optional[str] = None,
timezone: Optional[str] = None,
+ rate_limiter_per_minute=Depends(
+ ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute")
+ ),
+ rate_limiter_per_day=Depends(
+ ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
+ ),
):
- connection_alive = True
+ async def event_generator(q: str):
+ start_time = time.perf_counter()
+ ttft = None
+ chat_metadata: dict = {}
+ connection_alive = True
+ user: KhojUser = request.user.object
+ event_delimiter = "␃🔚␗"
+ q = unquote(q)
- async def send_status_update(message: str):
- nonlocal connection_alive
- if not connection_alive:
+ async def send_event(event_type: ChatEvent, data: str | dict):
+ nonlocal connection_alive, ttft
+ if not connection_alive or await request.is_disconnected():
+ connection_alive = False
+ logger.warn(f"User {user} disconnected from {common.client} client")
+ return
+ try:
+ if event_type == ChatEvent.END_LLM_RESPONSE:
+ collect_telemetry()
+ if event_type == ChatEvent.START_LLM_RESPONSE:
+ ttft = time.perf_counter() - start_time
+ if event_type == ChatEvent.MESSAGE:
+ yield data
+ elif event_type == ChatEvent.REFERENCES or stream:
+ yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False)
+ except asyncio.CancelledError as e:
+ connection_alive = False
+ logger.warn(f"User {user} disconnected from {common.client} client: {e}")
+ return
+ except Exception as e:
+ connection_alive = False
+ logger.error(f"Failed to stream chat API response to {user} on {common.client}: {e}", exc_info=True)
+ return
+ finally:
+ if stream:
+ yield event_delimiter
+
+ async def send_llm_response(response: str):
+ async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
+ yield result
+ async for result in send_event(ChatEvent.MESSAGE, response):
+ yield result
+ async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
+ yield result
+
+ def collect_telemetry():
+ # Gather chat response telemetry
+ nonlocal chat_metadata
+ latency = time.perf_counter() - start_time
+ cmd_set = set([cmd.value for cmd in conversation_commands])
+ chat_metadata = chat_metadata or {}
+ chat_metadata["conversation_command"] = cmd_set
+ chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None
+ chat_metadata["latency"] = f"{latency:.3f}"
+ chat_metadata["ttft_latency"] = f"{ttft:.3f}"
+
+ logger.info(f"Chat response time to first token: {ttft:.3f} seconds")
+ logger.info(f"Chat response total time: {latency:.3f} seconds")
+ update_telemetry_state(
+ request=request,
+ telemetry_type="api",
+ api="chat",
+ client=request.user.client_app,
+ user_agent=request.headers.get("user-agent"),
+ host=request.headers.get("host"),
+ metadata=chat_metadata,
+ )
+
+ conversation = await ConversationAdapters.aget_conversation_by_user(
+ user, client_application=request.user.client_app, conversation_id=conversation_id, title=title
+ )
+ if not conversation:
+ async for result in send_llm_response(f"Conversation {conversation_id} not found"):
+ yield result
return
- status_packet = {
- "type": "status",
- "message": message,
- "content-type": "application/json",
- }
- try:
- await websocket.send_text(json.dumps(status_packet))
- except ConnectionClosedOK:
- connection_alive = False
- logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")
+ await is_ready_to_chat(user)
- async def send_complete_llm_response(llm_response: str):
- nonlocal connection_alive
- if not connection_alive:
- return
- try:
- await websocket.send_text("start_llm_response")
- await websocket.send_text(llm_response)
- await websocket.send_text("end_llm_response")
- except ConnectionClosedOK:
- connection_alive = False
- logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")
-
- async def send_message(message: str):
- nonlocal connection_alive
- if not connection_alive:
- return
- try:
- await websocket.send_text(message)
- except ConnectionClosedOK:
- connection_alive = False
- logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")
-
- async def send_rate_limit_message(message: str):
- nonlocal connection_alive
- if not connection_alive:
- return
-
- status_packet = {
- "type": "rate_limit",
- "message": message,
- "content-type": "application/json",
- }
- try:
- await websocket.send_text(json.dumps(status_packet))
- except ConnectionClosedOK:
- connection_alive = False
- logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")
-
- user: KhojUser = websocket.user.object
- conversation = await ConversationAdapters.aget_conversation_by_user(
- user, client_application=websocket.user.client_app, conversation_id=conversation_id
- )
-
- hourly_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute")
-
- daily_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
-
- await is_ready_to_chat(user)
-
- user_name = await aget_user_name(user)
-
- location = None
-
- if city or region or country:
- location = LocationData(city=city, region=region, country=country)
-
- await websocket.accept()
- while connection_alive:
- try:
- if conversation:
- await sync_to_async(conversation.refresh_from_db)(fields=["conversation_log"])
- q = await websocket.receive_text()
-
- # Refresh these because the connection to the database might have been closed
- await conversation.arefresh_from_db()
-
- except WebSocketDisconnect:
- logger.debug(f"User {user} disconnected web socket")
- break
-
- try:
- await sync_to_async(hourly_limiter)(websocket)
- await sync_to_async(daily_limiter)(websocket)
- except HTTPException as e:
- await send_rate_limit_message(e.detail)
- break
+ user_name = await aget_user_name(user)
+ location = None
+ if city or region or country:
+ location = LocationData(city=city, region=region, country=country)
if is_query_empty(q):
- await send_message("start_llm_response")
- await send_message(
- "It seems like your query is incomplete. Could you please provide more details or specify what you need help with?"
- )
- await send_message("end_llm_response")
- continue
+ async for result in send_llm_response("Please ask your query to get started."):
+ yield result
+ return
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
conversation_commands = [get_conversation_command(query=q, any_references=True)]
- await send_status_update(f"**👀 Understanding Query**: {q}")
+ async for result in send_event(ChatEvent.STATUS, f"**👀 Understanding Query**: {q}"):
+ yield result
meta_log = conversation.conversation_log
is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask]
- used_slash_summarize = conversation_commands == [ConversationCommand.Summarize]
if conversation_commands == [ConversationCommand.Default] or is_automated_task:
conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task)
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
- await send_status_update(f"**🗃️ Chose Data Sources to Search:** {conversation_commands_str}")
+ async for result in send_event(
+ ChatEvent.STATUS, f"**🗃️ Chose Data Sources to Search:** {conversation_commands_str}"
+ ):
+ yield result
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task)
- await send_status_update(f"**🧑🏾💻 Decided Response Mode:** {mode.value}")
+ async for result in send_event(ChatEvent.STATUS, f"**🧑🏾💻 Decided Response Mode:** {mode.value}"):
+ yield result
if mode not in conversation_commands:
conversation_commands.append(mode)
for cmd in conversation_commands:
- await conversation_command_rate_limiter.update_and_check_if_valid(websocket, cmd)
+ await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
q = q.replace(f"/{cmd.value}", "").strip()
+ used_slash_summarize = conversation_commands == [ConversationCommand.Summarize]
file_filters = conversation.file_filters if conversation else []
# Skip trying to summarize if
if (
@@ -676,28 +673,37 @@ async def websocket_endpoint(
response_log = ""
if len(file_filters) == 0:
response_log = "No files selected for summarization. Please add files using the section on the left."
- await send_complete_llm_response(response_log)
+ async for result in send_llm_response(response_log):
+ yield result
elif len(file_filters) > 1:
response_log = "Only one file can be selected for summarization."
- await send_complete_llm_response(response_log)
+ async for result in send_llm_response(response_log):
+ yield result
else:
try:
file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
if len(file_object) == 0:
response_log = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again."
- await send_complete_llm_response(response_log)
- continue
+ async for result in send_llm_response(response_log):
+ yield result
+ return
contextual_data = " ".join([file.raw_text for file in file_object])
if not q:
q = "Create a general summary of the file"
- await send_status_update(f"**🧑🏾💻 Constructing Summary Using:** {file_object[0].file_name}")
+ async for result in send_event(
+ ChatEvent.STATUS, f"**🧑🏾💻 Constructing Summary Using:** {file_object[0].file_name}"
+ ):
+ yield result
+
response = await extract_relevant_summary(q, contextual_data)
response_log = str(response)
- await send_complete_llm_response(response_log)
+ async for result in send_llm_response(response_log):
+ yield result
except Exception as e:
response_log = "Error summarizing file."
logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)
- await send_complete_llm_response(response_log)
+ async for result in send_llm_response(response_log):
+ yield result
await sync_to_async(save_to_conversation_log)(
q,
response_log,
@@ -705,16 +711,10 @@ async def websocket_endpoint(
meta_log,
user_message_time,
intent_type="summarize",
- client_application=websocket.user.client_app,
+ client_application=request.user.client_app,
conversation_id=conversation_id,
)
- update_telemetry_state(
- request=websocket,
- telemetry_type="api",
- api="chat",
- metadata={"conversation_command": conversation_commands[0].value},
- )
- continue
+ return
custom_filters = []
if conversation_commands == [ConversationCommand.Help]:
@@ -724,8 +724,9 @@ async def websocket_endpoint(
conversation_config = await ConversationAdapters.aget_default_conversation_config()
model_type = conversation_config.model_type
formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
- await send_complete_llm_response(formatted_help)
- continue
+ async for result in send_llm_response(formatted_help):
+ yield result
+ return
# Adding specification to search online specifically on khoj.dev pages.
custom_filters.append("site:khoj.dev")
conversation_commands.append(ConversationCommand.Online)
@@ -733,14 +734,14 @@ async def websocket_endpoint(
if ConversationCommand.Automation in conversation_commands:
try:
automation, crontime, query_to_run, subject = await create_automation(
- q, timezone, user, websocket.url, meta_log
+ q, timezone, user, request.url, meta_log
)
except Exception as e:
logger.error(f"Error scheduling task {q} for {user.email}: {e}")
- await send_complete_llm_response(
- f"Unable to create automation. Ensure the automation doesn't already exist."
- )
- continue
+ error_message = f"Unable to create automation. Ensure the automation doesn't already exist."
+ async for result in send_llm_response(error_message):
+ yield result
+ return
llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject)
await sync_to_async(save_to_conversation_log)(
@@ -750,57 +751,78 @@ async def websocket_endpoint(
meta_log,
user_message_time,
intent_type="automation",
- client_application=websocket.user.client_app,
+ client_application=request.user.client_app,
conversation_id=conversation_id,
inferred_queries=[query_to_run],
automation_id=automation.id,
)
- common = CommonQueryParamsClass(
- client=websocket.user.client_app,
- user_agent=websocket.headers.get("user-agent"),
- host=websocket.headers.get("host"),
- )
- update_telemetry_state(
- request=websocket,
- telemetry_type="api",
- api="chat",
- **common.__dict__,
- )
- await send_complete_llm_response(llm_response)
- continue
+ async for result in send_llm_response(llm_response):
+ yield result
+ return
- compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
- websocket, meta_log, q, 7, 0.18, conversation_id, conversation_commands, location, send_status_update
- )
+ # Gather Context
+ ## Extract Document References
+ compiled_references, inferred_queries, defiltered_query = [], [], None
+ async for result in extract_references_and_questions(
+ request,
+ meta_log,
+ q,
+ (n or 7),
+ (d or 0.18),
+ conversation_id,
+ conversation_commands,
+ location,
+ partial(send_event, ChatEvent.STATUS),
+ ):
+ if isinstance(result, dict) and ChatEvent.STATUS in result:
+ yield result[ChatEvent.STATUS]
+ else:
+ compiled_references.extend(result[0])
+ inferred_queries.extend(result[1])
+ defiltered_query = result[2]
- if compiled_references:
+ if not is_none_or_empty(compiled_references):
headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references]))
- await send_status_update(f"**📜 Found Relevant Notes**: {headings}")
+ async for result in send_event(ChatEvent.STATUS, f"**📜 Found Relevant Notes**: {headings}"):
+ yield result
online_results: Dict = dict()
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
- await send_complete_llm_response(f"{no_entries_found.format()}")
- continue
+ async for result in send_llm_response(f"{no_entries_found.format()}"):
+ yield result
+ return
if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
conversation_commands.remove(ConversationCommand.Notes)
+ ## Gather Online References
if ConversationCommand.Online in conversation_commands:
try:
- online_results = await search_online(
- defiltered_query, meta_log, location, send_status_update, custom_filters
- )
+ async for result in search_online(
+ defiltered_query, meta_log, location, partial(send_event, ChatEvent.STATUS), custom_filters
+ ):
+ if isinstance(result, dict) and ChatEvent.STATUS in result:
+ yield result[ChatEvent.STATUS]
+ else:
+ online_results = result
except ValueError as e:
- logger.warning(f"Error searching online: {e}. Attempting to respond without online results")
- await send_complete_llm_response(
- f"Error searching online: {e}. Attempting to respond without online results"
- )
- continue
+ error_message = f"Error searching online: {e}. Attempting to respond without online results"
+ logger.warning(error_message)
+ async for result in send_llm_response(error_message):
+ yield result
+ return
+ ## Gather Webpage References
if ConversationCommand.Webpage in conversation_commands:
try:
- direct_web_pages = await read_webpages(defiltered_query, meta_log, location, send_status_update)
+ async for result in read_webpages(
+ defiltered_query, meta_log, location, partial(send_event, ChatEvent.STATUS)
+ ):
+ if isinstance(result, dict) and ChatEvent.STATUS in result:
+ yield result[ChatEvent.STATUS]
+ else:
+ direct_web_pages = result
webpages = []
for query in direct_web_pages:
if online_results.get(query):
@@ -810,38 +832,52 @@ async def websocket_endpoint(
for webpage in direct_web_pages[query]["webpages"]:
webpages.append(webpage["link"])
-
- await send_status_update(f"**📚 Read web pages**: {webpages}")
+ async for result in send_event(ChatEvent.STATUS, f"**📚 Read web pages**: {webpages}"):
+ yield result
except ValueError as e:
logger.warning(
- f"Error directly reading webpages: {e}. Attempting to respond without online results", exc_info=True
+ f"Error directly reading webpages: {e}. Attempting to respond without online results",
+ exc_info=True,
)
+ ## Send Gathered References
+ async for result in send_event(
+ ChatEvent.REFERENCES,
+ {
+ "inferredQueries": inferred_queries,
+ "context": compiled_references,
+ "onlineContext": online_results,
+ },
+ ):
+ yield result
+
+ # Generate Output
+ ## Generate Image Output
if ConversationCommand.Image in conversation_commands:
- update_telemetry_state(
- request=websocket,
- telemetry_type="api",
- api="chat",
- metadata={"conversation_command": conversation_commands[0].value},
- )
- image, status_code, improved_image_prompt, intent_type = await text_to_image(
+ async for result in text_to_image(
q,
user,
meta_log,
location_data=location,
references=compiled_references,
online_results=online_results,
- send_status_func=send_status_update,
- )
+ send_status_func=partial(send_event, ChatEvent.STATUS),
+ ):
+ if isinstance(result, dict) and ChatEvent.STATUS in result:
+ yield result[ChatEvent.STATUS]
+ else:
+ image, status_code, improved_image_prompt, intent_type = result
+
if image is None or status_code != 200:
content_obj = {
- "image": image,
+ "content-type": "application/json",
"intentType": intent_type,
"detail": improved_image_prompt,
- "content-type": "application/json",
+ "image": image,
}
- await send_complete_llm_response(json.dumps(content_obj))
- continue
+ async for result in send_llm_response(json.dumps(content_obj)):
+ yield result
+ return
await sync_to_async(save_to_conversation_log)(
q,
@@ -851,17 +887,23 @@ async def websocket_endpoint(
user_message_time,
intent_type=intent_type,
inferred_queries=[improved_image_prompt],
- client_application=websocket.user.client_app,
+ client_application=request.user.client_app,
conversation_id=conversation_id,
compiled_references=compiled_references,
online_results=online_results,
)
- content_obj = {"image": image, "intentType": intent_type, "inferredQueries": [improved_image_prompt], "context": compiled_references, "content-type": "application/json", "online_results": online_results} # type: ignore
+ content_obj = {
+ "intentType": intent_type,
+ "inferredQueries": [improved_image_prompt],
+ "image": image,
+ }
+ async for result in send_llm_response(json.dumps(content_obj)):
+ yield result
+ return
- await send_complete_llm_response(json.dumps(content_obj))
- continue
-
- await send_status_update(f"**💭 Generating a well-informed response**")
+ ## Generate Text Output
+ async for result in send_event(ChatEvent.STATUS, f"**💭 Generating a well-informed response**"):
+ yield result
llm_response, chat_metadata = await agenerate_chat_response(
defiltered_query,
meta_log,
@@ -871,310 +913,49 @@ async def websocket_endpoint(
inferred_queries,
conversation_commands,
user,
- websocket.user.client_app,
+ request.user.client_app,
conversation_id,
location,
user_name,
)
- chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None
+ # Send Response
+ async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
+ yield result
- update_telemetry_state(
- request=websocket,
- telemetry_type="api",
- api="chat",
- metadata=chat_metadata,
- )
+ continue_stream = True
iterator = AsyncIteratorWrapper(llm_response)
-
- await send_message("start_llm_response")
-
async for item in iterator:
if item is None:
- break
- if connection_alive:
- try:
- await send_message(f"{item}")
- except ConnectionClosedOK:
- connection_alive = False
- logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")
-
- await send_message("end_llm_response")
-
-
-@api_chat.get("", response_class=Response)
-@requires(["authenticated"])
-async def chat(
- request: Request,
- common: CommonQueryParams,
- q: str,
- n: Optional[int] = 5,
- d: Optional[float] = 0.22,
- stream: Optional[bool] = False,
- title: Optional[str] = None,
- conversation_id: Optional[int] = None,
- city: Optional[str] = None,
- region: Optional[str] = None,
- country: Optional[str] = None,
- timezone: Optional[str] = None,
- rate_limiter_per_minute=Depends(
- ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute")
- ),
- rate_limiter_per_day=Depends(
- ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
- ),
-) -> Response:
- user: KhojUser = request.user.object
- q = unquote(q)
- if is_query_empty(q):
- return Response(
- content="It seems like your query is incomplete. Could you please provide more details or specify what you need help with?",
- media_type="text/plain",
- status_code=400,
- )
- user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
- logger.info(f"Chat request by {user.username}: {q}")
-
- await is_ready_to_chat(user)
- conversation_commands = [get_conversation_command(query=q, any_references=True)]
-
- _custom_filters = []
- if conversation_commands == [ConversationCommand.Help]:
- help_str = "/" + ConversationCommand.Help
- if q.strip() == help_str:
- conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
- if conversation_config == None:
- conversation_config = await ConversationAdapters.aget_default_conversation_config()
- model_type = conversation_config.model_type
- formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
- return StreamingResponse(iter([formatted_help]), media_type="text/event-stream", status_code=200)
- # Adding specification to search online specifically on khoj.dev pages.
- _custom_filters.append("site:khoj.dev")
- conversation_commands.append(ConversationCommand.Online)
-
- conversation = await ConversationAdapters.aget_conversation_by_user(
- user, request.user.client_app, conversation_id, title
- )
- conversation_id = conversation.id if conversation else None
-
- if not conversation:
- return Response(
- content=f"No conversation found with requested id, title", media_type="text/plain", status_code=400
- )
- else:
- meta_log = conversation.conversation_log
-
- if ConversationCommand.Summarize in conversation_commands:
- file_filters = conversation.file_filters
- llm_response = ""
- if len(file_filters) == 0:
- llm_response = "No files selected for summarization. Please add files using the section on the left."
- elif len(file_filters) > 1:
- llm_response = "Only one file can be selected for summarization."
- else:
+ async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
+ yield result
+ logger.debug("Finished streaming response")
+ return
+ if not connection_alive or not continue_stream:
+ continue
try:
- file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
- if len(file_object) == 0:
- llm_response = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again."
- return StreamingResponse(content=llm_response, media_type="text/event-stream", status_code=200)
- contextual_data = " ".join([file.raw_text for file in file_object])
- summarizeStr = "/" + ConversationCommand.Summarize
- if q.strip() == summarizeStr:
- q = "Create a general summary of the file"
- response = await extract_relevant_summary(q, contextual_data)
- llm_response = str(response)
+ async for result in send_event(ChatEvent.MESSAGE, f"{item}"):
+ yield result
except Exception as e:
- logger.error(f"Error summarizing file for {user.email}: {e}")
- llm_response = "Error summarizing file."
- await sync_to_async(save_to_conversation_log)(
- q,
- llm_response,
- user,
- conversation.conversation_log,
- user_message_time,
- intent_type="summarize",
- client_application=request.user.client_app,
- conversation_id=conversation_id,
- )
- update_telemetry_state(
- request=request,
- telemetry_type="api",
- api="chat",
- metadata={"conversation_command": conversation_commands[0].value},
- **common.__dict__,
- )
- return StreamingResponse(content=llm_response, media_type="text/event-stream", status_code=200)
-
- is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask]
-
- if conversation_commands == [ConversationCommand.Default] or is_automated_task:
- conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task)
- mode = await aget_relevant_output_modes(q, meta_log, is_automated_task)
- if mode not in conversation_commands:
- conversation_commands.append(mode)
-
- for cmd in conversation_commands:
- await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
- q = q.replace(f"/{cmd.value}", "").strip()
-
- location = None
-
- if city or region or country:
- location = LocationData(city=city, region=region, country=country)
-
- user_name = await aget_user_name(user)
-
- if ConversationCommand.Automation in conversation_commands:
- try:
- automation, crontime, query_to_run, subject = await create_automation(
- q, timezone, user, request.url, meta_log
- )
- except Exception as e:
- logger.error(f"Error creating automation {q} for {user.email}: {e}", exc_info=True)
- return Response(
- content=f"Unable to create automation. Ensure the automation doesn't already exist.",
- media_type="text/plain",
- status_code=500,
- )
-
- llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject)
- await sync_to_async(save_to_conversation_log)(
- q,
- llm_response,
- user,
- meta_log,
- user_message_time,
- intent_type="automation",
- client_application=request.user.client_app,
- conversation_id=conversation_id,
- inferred_queries=[query_to_run],
- automation_id=automation.id,
- )
-
- if stream:
- return StreamingResponse(llm_response, media_type="text/event-stream", status_code=200)
- else:
- return Response(content=llm_response, media_type="text/plain", status_code=200)
-
- compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
- request, meta_log, q, (n or 5), (d or math.inf), conversation_id, conversation_commands, location
- )
- online_results: Dict[str, Dict] = {}
-
- if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
- no_entries_found_format = no_entries_found.format()
- if stream:
- return StreamingResponse(iter([no_entries_found_format]), media_type="text/event-stream", status_code=200)
- else:
- response_obj = {"response": no_entries_found_format}
- return Response(content=json.dumps(response_obj), media_type="text/plain", status_code=200)
-
- if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references):
- no_notes_found_format = no_notes_found.format()
- if stream:
- return StreamingResponse(iter([no_notes_found_format]), media_type="text/event-stream", status_code=200)
- else:
- response_obj = {"response": no_notes_found_format}
- return Response(content=json.dumps(response_obj), media_type="text/plain", status_code=200)
-
- if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
- conversation_commands.remove(ConversationCommand.Notes)
-
- if ConversationCommand.Online in conversation_commands:
- try:
- online_results = await search_online(defiltered_query, meta_log, location, custom_filters=_custom_filters)
- except ValueError as e:
- logger.warning(f"Error searching online: {e}. Attempting to respond without online results")
-
- if ConversationCommand.Webpage in conversation_commands:
- try:
- online_results = await read_webpages(defiltered_query, meta_log, location)
- except ValueError as e:
- logger.warning(
- f"Error directly reading webpages: {e}. Attempting to respond without online results", exc_info=True
- )
-
- if ConversationCommand.Image in conversation_commands:
- update_telemetry_state(
- request=request,
- telemetry_type="api",
- api="chat",
- metadata={"conversation_command": conversation_commands[0].value},
- **common.__dict__,
- )
- image, status_code, improved_image_prompt, intent_type = await text_to_image(
- q, user, meta_log, location_data=location, references=compiled_references, online_results=online_results
- )
- if image is None:
- content_obj = {"image": image, "intentType": intent_type, "detail": improved_image_prompt}
- return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
-
- await sync_to_async(save_to_conversation_log)(
- q,
- image,
- user,
- meta_log,
- user_message_time,
- intent_type=intent_type,
- inferred_queries=[improved_image_prompt],
- client_application=request.user.client_app,
- conversation_id=conversation.id,
- compiled_references=compiled_references,
- online_results=online_results,
- )
- content_obj = {"image": image, "intentType": intent_type, "inferredQueries": [improved_image_prompt], "context": compiled_references, "online_results": online_results} # type: ignore
- return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
-
- # Get the (streamed) chat response from the LLM of choice.
- llm_response, chat_metadata = await agenerate_chat_response(
- defiltered_query,
- meta_log,
- conversation,
- compiled_references,
- online_results,
- inferred_queries,
- conversation_commands,
- user,
- request.user.client_app,
- conversation.id,
- location,
- user_name,
- )
-
- cmd_set = set([cmd.value for cmd in conversation_commands])
- chat_metadata["conversation_command"] = cmd_set
- chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None
-
- update_telemetry_state(
- request=request,
- telemetry_type="api",
- api="chat",
- metadata=chat_metadata,
- **common.__dict__,
- )
-
- if llm_response is None:
- return Response(content=llm_response, media_type="text/plain", status_code=500)
+ continue_stream = False
+ logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}")
+ ## Stream Text Response
if stream:
- return StreamingResponse(llm_response, media_type="text/event-stream", status_code=200)
+ return StreamingResponse(event_generator(q), media_type="text/plain")
+ ## Non-Streaming Text Response
+ else:
+ # Get the full response from the generator if the stream is not requested.
+ response_obj = {}
+ actual_response = ""
+ iterator = event_generator(q)
+ async for item in iterator:
+ try:
+ item_json = json.loads(item)
+ if "type" in item_json and item_json["type"] == ChatEvent.REFERENCES.value:
+ response_obj = item_json["data"]
+ except:
+ actual_response += item
+ response_obj["response"] = actual_response
- iterator = AsyncIteratorWrapper(llm_response)
-
- # Get the full response from the generator if the stream is not requested.
- aggregated_gpt_response = ""
- async for item in iterator:
- if item is None:
- break
- aggregated_gpt_response += item
-
- actual_response = aggregated_gpt_response.split("### compiled references:")[0]
-
- response_obj = {
- "response": actual_response,
- "inferredQueries": inferred_queries,
- "context": compiled_references,
- "online_results": online_results,
- }
-
- return Response(content=json.dumps(response_obj), media_type="application/json", status_code=200)
+ return Response(content=json.dumps(response_obj), media_type="application/json", status_code=200)
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index 25d21f29..846f5c8f 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -8,6 +8,7 @@ import math
import re
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timedelta, timezone
+from enum import Enum
from functools import partial
from random import random
from typing import (
@@ -753,7 +754,7 @@ async def text_to_image(
references: List[Dict[str, Any]],
online_results: Dict[str, Any],
send_status_func: Optional[Callable] = None,
-) -> Tuple[Optional[str], int, Optional[str], str]:
+):
status_code = 200
image = None
response = None
@@ -765,7 +766,8 @@ async def text_to_image(
# If the user has not configured a text to image model, return an unsupported on server error
status_code = 501
message = "Failed to generate image. Setup image generation on the server."
- return image_url or image, status_code, message, intent_type.value
+ yield image_url or image, status_code, message, intent_type.value
+ return
text2image_model = text_to_image_config.model_name
chat_history = ""
@@ -777,20 +779,21 @@ async def text_to_image(
chat_history += f"Q: Prompt: {chat['intent']['query']}\n"
chat_history += f"A: Improved Prompt: {chat['intent']['inferred-queries'][0]}\n"
- with timer("Improve the original user query", logger):
- if send_status_func:
- await send_status_func("**✍🏽 Enhancing the Painting Prompt**")
- improved_image_prompt = await generate_better_image_prompt(
- message,
- chat_history,
- location_data=location_data,
- note_references=references,
- online_results=online_results,
- model_type=text_to_image_config.model_type,
- )
+ if send_status_func:
+ async for event in send_status_func("**✍🏽 Enhancing the Painting Prompt**"):
+ yield {ChatEvent.STATUS: event}
+ improved_image_prompt = await generate_better_image_prompt(
+ message,
+ chat_history,
+ location_data=location_data,
+ note_references=references,
+ online_results=online_results,
+ model_type=text_to_image_config.model_type,
+ )
if send_status_func:
- await send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}")
+ async for event in send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}"):
+ yield {ChatEvent.STATUS: event}
if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
with timer("Generate image with OpenAI", logger):
@@ -815,12 +818,14 @@ async def text_to_image(
logger.error(f"Image Generation blocked by OpenAI: {e}")
status_code = e.status_code # type: ignore
message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore
- return image_url or image, status_code, message, intent_type.value
+ yield image_url or image, status_code, message, intent_type.value
+ return
else:
logger.error(f"Image Generation failed with {e}", exc_info=True)
message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore
status_code = e.status_code # type: ignore
- return image_url or image, status_code, message, intent_type.value
+ yield image_url or image, status_code, message, intent_type.value
+ return
elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.STABILITYAI:
with timer("Generate image with Stability AI", logger):
@@ -842,7 +847,8 @@ async def text_to_image(
logger.error(f"Image Generation failed with {e}", exc_info=True)
message = f"Image generation failed with Stability AI error: {e}"
status_code = e.status_code # type: ignore
- return image_url or image, status_code, message, intent_type.value
+ yield image_url or image, status_code, message, intent_type.value
+ return
with timer("Convert image to webp", logger):
# Convert png to webp for faster loading
@@ -862,7 +868,7 @@ async def text_to_image(
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
image = base64.b64encode(webp_image_bytes).decode("utf-8")
- return image_url or image, status_code, improved_image_prompt, intent_type.value
+ yield image_url or image, status_code, improved_image_prompt, intent_type.value
class ApiUserRateLimiter:
@@ -1184,3 +1190,11 @@ def construct_automation_created_message(automation: Job, crontime: str, query_t
Manage your automations [here](/automations).
""".strip()
+
+
+class ChatEvent(Enum):
+ START_LLM_RESPONSE = "start_llm_response"
+ END_LLM_RESPONSE = "end_llm_response"
+ MESSAGE = "message"
+ REFERENCES = "references"
+ STATUS = "status"
diff --git a/src/khoj/utils/fs_syncer.py b/src/khoj/utils/fs_syncer.py
index 5a20f418..3177d7ee 100644
--- a/src/khoj/utils/fs_syncer.py
+++ b/src/khoj/utils/fs_syncer.py
@@ -22,7 +22,7 @@ magika = Magika()
def collect_files(search_type: Optional[SearchType] = SearchType.All, user=None) -> dict:
- files = {}
+ files: dict[str, dict] = {"docx": {}, "image": {}}
if search_type == SearchType.All or search_type == SearchType.Org:
org_config = LocalOrgConfig.objects.filter(user=user).first()
diff --git a/tests/test_client.py b/tests/test_client.py
index 24d2dff6..c4246a78 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -455,13 +455,13 @@ def test_user_no_data_returns_empty(client, sample_org_data, api_user3: KhojApiU
@pytest.mark.skipif(os.getenv("OPENAI_API_KEY") is None, reason="requires OPENAI_API_KEY")
@pytest.mark.django_db(transaction=True)
-def test_chat_with_unauthenticated_user(chat_client_with_auth, api_user2: KhojApiUser):
+async def test_chat_with_unauthenticated_user(chat_client_with_auth, api_user2: KhojApiUser):
# Arrange
headers = {"Authorization": f"Bearer {api_user2.token}"}
# Act
- auth_response = chat_client_with_auth.get(f'/api/chat?q="Hello!"&stream=true', headers=headers)
- no_auth_response = chat_client_with_auth.get(f'/api/chat?q="Hello!"&stream=true')
+ auth_response = chat_client_with_auth.get(f'/api/chat?q="Hello!"', headers=headers)
+ no_auth_response = chat_client_with_auth.get(f'/api/chat?q="Hello!"')
# Assert
assert auth_response.status_code == 200
diff --git a/tests/test_offline_chat_director.py b/tests/test_offline_chat_director.py
index 43e254e6..c0532c4b 100644
--- a/tests/test_offline_chat_director.py
+++ b/tests/test_offline_chat_director.py
@@ -67,10 +67,8 @@ def test_chat_with_online_content(client_offline_chat):
# Act
q = "/online give me the link to paul graham's essay how to do great work"
encoded_q = quote(q, safe="")
- response = client_offline_chat.get(f"/api/chat?q={encoded_q}&stream=true")
- response_message = response.content.decode("utf-8")
-
- response_message = response_message.split("### compiled references")[0]
+ response = client_offline_chat.get(f"/api/chat?q={encoded_q}")
+ response_message = response.json()["response"]
# Assert
expected_responses = [
@@ -91,10 +89,8 @@ def test_chat_with_online_webpage_content(client_offline_chat):
# Act
q = "/online how many firefighters were involved in the great chicago fire and which year did it take place?"
encoded_q = quote(q, safe="")
- response = client_offline_chat.get(f"/api/chat?q={encoded_q}&stream=true")
- response_message = response.content.decode("utf-8")
-
- response_message = response_message.split("### compiled references")[0]
+ response = client_offline_chat.get(f"/api/chat?q={encoded_q}")
+ response_message = response.json()["response"]
# Assert
expected_responses = ["185", "1871", "horse"]
diff --git a/tests/test_openai_chat_director.py b/tests/test_openai_chat_director.py
index 26d93d31..7a05a3dd 100644
--- a/tests/test_openai_chat_director.py
+++ b/tests/test_openai_chat_director.py
@@ -49,8 +49,8 @@ def create_conversation(message_list, user, agent=None):
@pytest.mark.django_db(transaction=True)
def test_chat_with_no_chat_history_or_retrieved_content(chat_client):
# Act
- response = chat_client.get(f'/api/chat?q="Hello, my name is Testatron. Who are you?"&stream=true')
- response_message = response.content.decode("utf-8")
+ response = chat_client.get(f'/api/chat?q="Hello, my name is Testatron. Who are you?"')
+ response_message = response.json()["response"]
# Assert
expected_responses = ["Khoj", "khoj"]
@@ -67,10 +67,8 @@ def test_chat_with_online_content(chat_client):
# Act
q = "/online give me the link to paul graham's essay how to do great work"
encoded_q = quote(q, safe="")
- response = chat_client.get(f"/api/chat?q={encoded_q}&stream=true")
- response_message = response.content.decode("utf-8")
-
- response_message = response_message.split("### compiled references")[0]
+ response = chat_client.get(f"/api/chat?q={encoded_q}")
+ response_message = response.json()["response"]
# Assert
expected_responses = [
@@ -91,10 +89,8 @@ def test_chat_with_online_webpage_content(chat_client):
# Act
q = "/online how many firefighters were involved in the great chicago fire and which year did it take place?"
encoded_q = quote(q, safe="")
- response = chat_client.get(f"/api/chat?q={encoded_q}&stream=true")
- response_message = response.content.decode("utf-8")
-
- response_message = response_message.split("### compiled references")[0]
+ response = chat_client.get(f"/api/chat?q={encoded_q}")
+ response_message = response.json()["response"]
# Assert
expected_responses = ["185", "1871", "horse"]
@@ -144,7 +140,7 @@ def test_answer_from_currently_retrieved_content(chat_client, default_user2: Kho
# Act
response = chat_client.get(f'/api/chat?q="Where was Xi Li born?"')
- response_message = response.content.decode("utf-8")
+ response_message = response.json()["response"]
# Assert
assert response.status_code == 200
@@ -168,7 +164,7 @@ def test_answer_from_chat_history_and_previously_retrieved_content(chat_client_n
# Act
response = chat_client_no_background.get(f'/api/chat?q="Where was I born?"')
- response_message = response.content.decode("utf-8")
+ response_message = response.json()["response"]
# Assert
assert response.status_code == 200
@@ -191,7 +187,7 @@ def test_answer_from_chat_history_and_currently_retrieved_content(chat_client, d
# Act
response = chat_client.get(f'/api/chat?q="Where was I born?"')
- response_message = response.content.decode("utf-8")
+ response_message = response.json()["response"]
# Assert
assert response.status_code == 200
@@ -215,8 +211,8 @@ def test_no_answer_in_chat_history_or_retrieved_content(chat_client, default_use
create_conversation(message_list, default_user2)
# Act
- response = chat_client.get(f'/api/chat?q="Where was I born?"&stream=true')
- response_message = response.content.decode("utf-8")
+ response = chat_client.get(f'/api/chat?q="Where was I born?"')
+ response_message = response.json()["response"]
# Assert
expected_responses = [
@@ -226,6 +222,7 @@ def test_no_answer_in_chat_history_or_retrieved_content(chat_client, default_use
"do not have",
"don't have",
"where were you born?",
+ "where you were born?",
]
assert response.status_code == 200
@@ -280,8 +277,8 @@ def test_answer_not_known_using_notes_command(chat_client_no_background, default
create_conversation(message_list, default_user2)
# Act
- response = chat_client_no_background.get(f"/api/chat?q={query}&stream=true")
- response_message = response.content.decode("utf-8")
+ response = chat_client_no_background.get(f"/api/chat?q={query}")
+ response_message = response.json()["response"]
# Assert
assert response.status_code == 200
@@ -527,8 +524,8 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c
create_conversation(message_list, default_user2)
# Act
- response = chat_client.get(f'/api/chat?q="Write a haiku about unit testing. Do not say anything else."&stream=true')
- response_message = response.content.decode("utf-8").split("### compiled references")[0]
+ response = chat_client.get(f'/api/chat?q="Write a haiku about unit testing. Do not say anything else.')
+ response_message = response.json()["response"]
# Assert
expected_responses = ["test", "Test"]
@@ -544,9 +541,8 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c
@pytest.mark.chatquality
def test_ask_for_clarification_if_not_enough_context_in_question(chat_client_no_background):
# Act
-
- response = chat_client_no_background.get(f'/api/chat?q="What is the name of Namitas older son?"&stream=true')
- response_message = response.content.decode("utf-8").split("### compiled references")[0].lower()
+ response = chat_client_no_background.get(f'/api/chat?q="What is the name of Namitas older son?"')
+ response_message = response.json()["response"].lower()
# Assert
expected_responses = [
@@ -658,8 +654,8 @@ def test_answer_in_chat_history_by_conversation_id_with_agent(
def test_answer_requires_multiple_independent_searches(chat_client):
"Chat director should be able to answer by doing multiple independent searches for required information"
# Act
- response = chat_client.get(f'/api/chat?q="Is Xi older than Namita? Just the older persons full name"&stream=true')
- response_message = response.content.decode("utf-8").split("### compiled references")[0].lower()
+ response = chat_client.get(f'/api/chat?q="Is Xi older than Namita? Just the older persons full name"')
+ response_message = response.json()["response"].lower()
# Assert
expected_responses = ["he is older than namita", "xi is older than namita", "xi li is older than namita"]
@@ -683,8 +679,8 @@ def test_answer_using_file_filter(chat_client):
'Is Xi older than Namita? Just say the older persons full name. file:"Namita.markdown" file:"Xi Li.markdown"'
)
- response = chat_client.get(f"/api/chat?q={query}&stream=true")
- response_message = response.content.decode("utf-8").split("### compiled references")[0].lower()
+ response = chat_client.get(f"/api/chat?q={query}")
+ response_message = response.json()["response"].lower()
# Assert
expected_responses = ["he is older than namita", "xi is older than namita", "xi li is older than namita"]