Merge pull request #858 from khoj-ai/use-sse-instead-of-websocket

Use Single HTTP API for Robust, Generalizable Chat Streaming
This commit is contained in:
sabaimran 2024-07-26 07:11:54 -07:00 committed by GitHub
commit 377f7668c5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 996 additions and 1261 deletions

View file

@ -61,6 +61,14 @@
let city = null; let city = null;
let countryName = null; let countryName = null;
let timezone = null; let timezone = null;
let chatMessageState = {
newResponseTextEl: null,
newResponseEl: null,
loadingEllipsis: null,
references: {},
rawResponse: "",
isVoice: false,
}
fetch("https://ipapi.co/json") fetch("https://ipapi.co/json")
.then(response => response.json()) .then(response => response.json())
@ -75,10 +83,9 @@
return; return;
}); });
async function chat() { async function chat(isVoice=false) {
// Extract required fields for search from form // Extract chat message from chat input form
let query = document.getElementById("chat-input").value.trim(); let query = document.getElementById("chat-input").value.trim();
let resultsCount = localStorage.getItem("khojResultsCount") || 5;
console.log(`Query: ${query}`); console.log(`Query: ${query}`);
// Short circuit on empty query // Short circuit on empty query
@ -106,9 +113,6 @@
await refreshChatSessionsPanel(); await refreshChatSessionsPanel();
} }
// Generate backend API URL to execute query
let chatApi = `${hostURL}/api/chat?q=${encodeURIComponent(query)}&n=${resultsCount}&client=web&stream=true&conversation_id=${conversationID}&region=${region}&city=${city}&country=${countryName}&timezone=${timezone}`;
let newResponseEl = document.createElement("div"); let newResponseEl = document.createElement("div");
newResponseEl.classList.add("chat-message", "khoj"); newResponseEl.classList.add("chat-message", "khoj");
newResponseEl.attributes["data-meta"] = "🏮 Khoj at " + formatDate(new Date()); newResponseEl.attributes["data-meta"] = "🏮 Khoj at " + formatDate(new Date());
@ -119,25 +123,7 @@
newResponseEl.appendChild(newResponseTextEl); newResponseEl.appendChild(newResponseTextEl);
// Temporary status message to indicate that Khoj is thinking // Temporary status message to indicate that Khoj is thinking
let loadingEllipsis = document.createElement("div"); let loadingEllipsis = createLoadingEllipsis();
loadingEllipsis.classList.add("lds-ellipsis");
let firstEllipsis = document.createElement("div");
firstEllipsis.classList.add("lds-ellipsis-item");
let secondEllipsis = document.createElement("div");
secondEllipsis.classList.add("lds-ellipsis-item");
let thirdEllipsis = document.createElement("div");
thirdEllipsis.classList.add("lds-ellipsis-item");
let fourthEllipsis = document.createElement("div");
fourthEllipsis.classList.add("lds-ellipsis-item");
loadingEllipsis.appendChild(firstEllipsis);
loadingEllipsis.appendChild(secondEllipsis);
loadingEllipsis.appendChild(thirdEllipsis);
loadingEllipsis.appendChild(fourthEllipsis);
newResponseTextEl.appendChild(loadingEllipsis); newResponseTextEl.appendChild(loadingEllipsis);
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
@ -148,107 +134,36 @@
let chatInput = document.getElementById("chat-input"); let chatInput = document.getElementById("chat-input");
chatInput.classList.remove("option-enabled"); chatInput.classList.remove("option-enabled");
// Setup chat message state
chatMessageState = {
newResponseTextEl,
newResponseEl,
loadingEllipsis,
references: {},
rawResponse: "",
rawQuery: query,
isVoice: isVoice,
}
// Call Khoj chat API // Call Khoj chat API
let response = await fetch(chatApi, { headers }); let chatApi = `${hostURL}/api/chat?q=${encodeURIComponent(query)}&conversation_id=${conversationID}&stream=true&client=desktop`;
let rawResponse = ""; chatApi += (!!region && !!city && !!countryName && !!timezone)
let references = null; ? `&region=${region}&city=${city}&country=${countryName}&timezone=${timezone}`
const contentType = response.headers.get("content-type"); : '';
const response = await fetch(chatApi, { headers });
if (contentType === "application/json") {
// Handle JSON response
try { try {
const responseAsJson = await response.json(); if (!response.ok) throw new Error(response.statusText);
if (responseAsJson.image) { if (!response.body) throw new Error("Response body is empty");
// If response has image field, response is a generated image. // Stream and render chat response
if (responseAsJson.intentType === "text-to-image") { await readChatStream(response);
rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`; } catch (err) {
} else if (responseAsJson.intentType === "text-to-image2") { console.error(`Khoj chat response failed with\n${err}`);
rawResponse += `![${query}](${responseAsJson.image})`; if (chatMessageState.newResponseEl.getElementsByClassName("lds-ellipsis").length > 0 && chatMessageState.loadingEllipsis)
} else if (responseAsJson.intentType === "text-to-image-v3") { chatMessageState.newResponseTextEl.removeChild(chatMessageState.loadingEllipsis);
rawResponse += `![${query}](data:image/webp;base64,${responseAsJson.image})`; let errorMsg = "Sorry, unable to get response from Khoj backend ❤️‍🩹. Retry or contact developers for help at <a href=mailto:'team@khoj.dev'>team@khoj.dev</a> or <a href='https://discord.gg/BDgyabRM6e'>on Discord</a>";
} newResponseTextEl.textContent = errorMsg;
const inferredQueries = responseAsJson.inferredQueries?.[0];
if (inferredQueries) {
rawResponse += `\n\n**Inferred Query**:\n\n${inferredQueries}`;
}
}
if (responseAsJson.context) {
const rawReferenceAsJson = responseAsJson.context;
references = createReferenceSection(rawReferenceAsJson);
}
if (responseAsJson.detail) {
// If response has detail field, response is an error message.
rawResponse += responseAsJson.detail;
}
} catch (error) {
// If the chunk is not a JSON object, just display it as is
rawResponse += chunk;
} finally {
newResponseTextEl.innerHTML = "";
newResponseTextEl.appendChild(formatHTMLMessage(rawResponse));
if (references != null) {
newResponseTextEl.appendChild(references);
}
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
document.getElementById("chat-input").removeAttribute("disabled");
}
} else {
// Handle streamed response of type text/event-stream or text/plain
const reader = response.body.getReader();
const decoder = new TextDecoder();
let references = {};
readStream();
function readStream() {
reader.read().then(({ done, value }) => {
if (done) {
// Append any references after all the data has been streamed
if (references != {}) {
newResponseTextEl.appendChild(createReferenceSection(references));
}
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
document.getElementById("chat-input").removeAttribute("disabled");
return;
}
// Decode message chunk from stream
const chunk = decoder.decode(value, { stream: true });
if (chunk.includes("### compiled references:")) {
const additionalResponse = chunk.split("### compiled references:")[0];
rawResponse += additionalResponse;
newResponseTextEl.innerHTML = "";
newResponseTextEl.appendChild(formatHTMLMessage(rawResponse));
const rawReference = chunk.split("### compiled references:")[1];
const rawReferenceAsJson = JSON.parse(rawReference);
if (rawReferenceAsJson instanceof Array) {
references["notes"] = rawReferenceAsJson;
} else if (typeof rawReferenceAsJson === "object" && rawReferenceAsJson !== null) {
references["online"] = rawReferenceAsJson;
}
readStream();
} else {
// Display response from Khoj
if (newResponseTextEl.getElementsByClassName("lds-ellipsis").length > 0) {
newResponseTextEl.removeChild(loadingEllipsis);
}
// If the chunk is not a JSON object, just display it as is
rawResponse += chunk;
newResponseTextEl.innerHTML = "";
newResponseTextEl.appendChild(formatHTMLMessage(rawResponse));
readStream();
}
// Scroll to bottom of chat window as chat response is streamed
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
});
}
} }
} }

View file

@ -364,3 +364,194 @@ function createReferenceSection(references, createLinkerSection=false) {
return referencesDiv; return referencesDiv;
} }
function createLoadingEllipsis() {
    // Build the animated "Khoj is thinking" indicator shown while awaiting a response.
    const loadingEllipsis = document.createElement("div");
    loadingEllipsis.classList.add("lds-ellipsis");

    // The lds-ellipsis CSS animation expects exactly four dot items inside the container.
    for (let i = 0; i < 4; i++) {
        const ellipsisItem = document.createElement("div");
        ellipsisItem.classList.add("lds-ellipsis-item");
        loadingEllipsis.appendChild(ellipsisItem);
    }

    return loadingEllipsis;
}
function handleStreamResponse(newResponseElement, rawResponse, rawQuery, loadingEllipsis, replace=true) {
    // Nothing to render into
    if (!newResponseElement) return;

    // Drop the "thinking" ellipsis before rendering actual content
    const ellipsisVisible = newResponseElement.getElementsByClassName("lds-ellipsis").length > 0;
    if (ellipsisVisible && loadingEllipsis)
        newResponseElement.removeChild(loadingEllipsis);

    // Either replace the rendered message or append to it
    if (replace) newResponseElement.innerHTML = "";
    newResponseElement.appendChild(formatHTMLMessage(rawResponse, false, replace, rawQuery));

    // Re-attach the ellipsis while more chunks are still expected
    if (!replace && loadingEllipsis)
        newResponseElement.appendChild(loadingEllipsis);

    // Keep the chat view pinned to the newest message
    document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
}
// Convert an image (or error) JSON payload into markdown appended to rawResponse.
// - imageJson: server payload; may carry `image` (png base64 / URL / webp base64 by
//   intentType), `inferredQueries` and/or `detail` (error message).
// - rawResponse: markdown accumulated so far for this chat message.
// Returns the updated markdown string.
function handleImageResponse(imageJson, rawResponse) {
    if (imageJson.image) {
        const inferredQuery = imageJson.inferredQueries?.[0] ?? "generated image";
        // If response has image field, response is a generated image.
        if (imageJson.intentType === "text-to-image") {
            rawResponse += `![generated_image](data:image/png;base64,${imageJson.image})`;
        } else if (imageJson.intentType === "text-to-image2") {
            rawResponse += `![generated_image](${imageJson.image})`;
        } else if (imageJson.intentType === "text-to-image-v3") {
            // Append like the sibling branches; previously this branch overwrote the
            // accumulated response and emitted an empty alt text.
            rawResponse += `![generated_image](data:image/webp;base64,${imageJson.image})`;
        }
        // inferredQuery always has a value thanks to the "generated image" fallback.
        if (inferredQuery) {
            rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
        }
    }
    // If response has detail field, response is an error message.
    if (imageJson.detail) rawResponse += imageJson.detail;
    return rawResponse;
}
function finalizeChatBodyResponse(references, newResponseElement) {
    // Attach the collected references, if any, under the finished response.
    const hasReferences = references != null && Object.keys(references).length > 0;
    if (newResponseElement && hasReferences) {
        newResponseElement.appendChild(createReferenceSection(references));
    }
    // Scroll chat to the bottom and re-enable the chat input box.
    document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
    document.getElementById("chat-input")?.removeAttribute("disabled");
}
// Parse one raw stream event into a {type, data} message chunk.
// - Typed JSON objects pass through; untyped JSON is wrapped as a 'message' chunk.
// - Non-JSON text (or JSON that fails to parse) is wrapped verbatim as a 'message'.
// Returns undefined for empty/null input so callers' `!chunk` guard short-circuits.
function convertMessageChunkToJson(rawChunk) {
    // Guard: previously a null/undefined chunk crashed at `rawChunk.length`.
    if (!rawChunk) return;
    if (rawChunk.startsWith("{") && rawChunk.endsWith("}")) {
        try {
            let jsonChunk = JSON.parse(rawChunk);
            if (!jsonChunk.type)
                jsonChunk = {type: 'message', data: jsonChunk};
            return jsonChunk;
        } catch (e) {
            // Looked like JSON but wasn't parseable; treat as plain text.
            return {type: 'message', data: rawChunk};
        }
    }
    return {type: 'message', data: rawChunk};
}
// Dispatch one decoded stream event to the appropriate handler, mutating the
// module-global `chatMessageState` as the response accumulates.
// Event types: 'status' (transient progress line), 'start_llm_response' /
// 'end_llm_response' (stream lifecycle), 'references' (source context),
// 'message' (response content, possibly JSON-encoded).
function processMessageChunk(rawChunk) {
const chunk = convertMessageChunkToJson(rawChunk);
console.debug("Chunk:", chunk);
// Empty or untyped chunks carry nothing to render
if (!chunk || !chunk.type) return;
if (chunk.type ==='status') {
console.log(`status: ${chunk.data}`);
const statusMessage = chunk.data;
// replace=false: status lines append below the response rather than replace it
handleStreamResponse(chatMessageState.newResponseTextEl, statusMessage, chatMessageState.rawQuery, chatMessageState.loadingEllipsis, false);
} else if (chunk.type === 'start_llm_response') {
console.log("Started streaming", new Date());
} else if (chunk.type === 'end_llm_response') {
console.log("Stopped streaming", new Date());
// Automatically respond with voice if the subscribed user has sent voice message
// NOTE: "{{ is_active }}" is substituted server-side by the template engine
if (chatMessageState.isVoice && "{{ is_active }}" == "True")
textToSpeech(chatMessageState.rawResponse);
// Append any references after all the data has been streamed
finalizeChatBodyResponse(chatMessageState.references, chatMessageState.newResponseTextEl);
const liveQuery = chatMessageState.rawQuery;
// Reset variables for the next message; keep rawQuery for any late consumers
chatMessageState = {
newResponseTextEl: null,
newResponseEl: null,
loadingEllipsis: null,
references: {},
rawResponse: "",
rawQuery: liveQuery,
isVoice: false,
}
} else if (chunk.type === "references") {
// Stash references; rendered by finalizeChatBodyResponse at end of stream
chatMessageState.references = {"notes": chunk.data.context, "online": chunk.data.onlineContext};
} else if (chunk.type === 'message') {
const chunkData = chunk.data;
if (typeof chunkData === 'object' && chunkData !== null) {
// If chunkData is already a JSON object
handleJsonResponse(chunkData);
} else if (typeof chunkData === 'string' && chunkData.trim()?.startsWith("{") && chunkData.trim()?.endsWith("}")) {
// Try process chunk data as if it is a JSON object
try {
const jsonData = JSON.parse(chunkData.trim());
handleJsonResponse(jsonData);
} catch (e) {
// Not valid JSON after all; fall back to rendering the raw text
chatMessageState.rawResponse += chunkData;
handleStreamResponse(chatMessageState.newResponseTextEl, chatMessageState.rawResponse, chatMessageState.rawQuery, chatMessageState.loadingEllipsis);
}
} else {
// Plain text chunk: append and re-render the accumulated response
chatMessageState.rawResponse += chunkData;
handleStreamResponse(chatMessageState.newResponseTextEl, chatMessageState.rawResponse, chatMessageState.rawQuery, chatMessageState.loadingEllipsis);
}
}
}
function handleJsonResponse(jsonData) {
    // Image and error payloads are converted to markdown and appended;
    // a plain `response` payload replaces the accumulated text outright.
    if (jsonData.image || jsonData.detail) {
        chatMessageState.rawResponse = handleImageResponse(jsonData, chatMessageState.rawResponse);
    } else if (jsonData.response) {
        chatMessageState.rawResponse = jsonData.response;
    }

    // Re-render the accumulated response if a target element exists.
    const responseTextEl = chatMessageState.newResponseTextEl;
    if (responseTextEl) {
        responseTextEl.innerHTML = "";
        responseTextEl.appendChild(formatHTMLMessage(chatMessageState.rawResponse));
    }
}
async function readChatStream(response) {
    // Nothing to stream when the response carries no body.
    if (!response.body) return;

    const streamReader = response.body.getReader();
    const textDecoder = new TextDecoder();
    // Sentinel the server inserts between events in the SSE-like stream.
    const eventDelimiter = '␃🔚␗';
    let pending = '';

    for (;;) {
        const { value, done } = await streamReader.read();
        if (done) {
            // Flush whatever remains in the buffer as the final event.
            processMessageChunk(pending);
            pending = '';
            return;
        }

        // Decode incrementally so multi-byte characters split across
        // network chunks are reassembled correctly.
        const decoded = textDecoder.decode(value, { stream: true });
        console.debug("Raw Chunk:", decoded)
        pending += decoded;

        // Drain every complete (delimiter-terminated) event from the buffer.
        let delimiterAt;
        while ((delimiterAt = pending.indexOf(eventDelimiter)) !== -1) {
            const event = pending.slice(0, delimiterAt);
            pending = pending.slice(delimiterAt + eventDelimiter.length);
            if (event) processMessageChunk(event);
        }
    }
}

View file

@ -346,7 +346,7 @@
inp.focus(); inp.focus();
} }
async function chat() { async function chat(isVoice=false) {
//set chat body to empty //set chat body to empty
let chatBody = document.getElementById("chat-body"); let chatBody = document.getElementById("chat-body");
chatBody.innerHTML = ""; chatBody.innerHTML = "";
@ -375,9 +375,6 @@
chat_body.dataset.conversationId = conversationID; chat_body.dataset.conversationId = conversationID;
} }
// Generate backend API URL to execute query
let chatApi = `${hostURL}/api/chat?q=${encodeURIComponent(query)}&n=${resultsCount}&client=web&stream=true&conversation_id=${conversationID}&region=${region}&city=${city}&country=${countryName}&timezone=${timezone}`;
let newResponseEl = document.createElement("div"); let newResponseEl = document.createElement("div");
newResponseEl.classList.add("chat-message", "khoj"); newResponseEl.classList.add("chat-message", "khoj");
newResponseEl.attributes["data-meta"] = "🏮 Khoj at " + formatDate(new Date()); newResponseEl.attributes["data-meta"] = "🏮 Khoj at " + formatDate(new Date());
@ -388,128 +385,41 @@
newResponseEl.appendChild(newResponseTextEl); newResponseEl.appendChild(newResponseTextEl);
// Temporary status message to indicate that Khoj is thinking // Temporary status message to indicate that Khoj is thinking
let loadingEllipsis = document.createElement("div"); let loadingEllipsis = createLoadingEllipsis();
loadingEllipsis.classList.add("lds-ellipsis");
let firstEllipsis = document.createElement("div");
firstEllipsis.classList.add("lds-ellipsis-item");
let secondEllipsis = document.createElement("div");
secondEllipsis.classList.add("lds-ellipsis-item");
let thirdEllipsis = document.createElement("div");
thirdEllipsis.classList.add("lds-ellipsis-item");
let fourthEllipsis = document.createElement("div");
fourthEllipsis.classList.add("lds-ellipsis-item");
loadingEllipsis.appendChild(firstEllipsis);
loadingEllipsis.appendChild(secondEllipsis);
loadingEllipsis.appendChild(thirdEllipsis);
loadingEllipsis.appendChild(fourthEllipsis);
newResponseTextEl.appendChild(loadingEllipsis);
document.body.scrollTop = document.getElementById("chat-body").scrollHeight; document.body.scrollTop = document.getElementById("chat-body").scrollHeight;
// Call Khoj chat API
let response = await fetch(chatApi, { headers });
let rawResponse = "";
let references = null;
const contentType = response.headers.get("content-type");
toggleLoading(); toggleLoading();
if (contentType === "application/json") {
// Handle JSON response // Setup chat message state
chatMessageState = {
newResponseTextEl,
newResponseEl,
loadingEllipsis,
references: {},
rawResponse: "",
rawQuery: query,
isVoice: isVoice,
}
// Construct API URL to execute chat query
let chatApi = `${hostURL}/api/chat?q=${encodeURIComponent(query)}&conversation_id=${conversationID}&stream=true&client=desktop`;
chatApi += (!!region && !!city && !!countryName && !!timezone)
? `&region=${region}&city=${city}&country=${countryName}&timezone=${timezone}`
: '';
const response = await fetch(chatApi, { headers });
try { try {
const responseAsJson = await response.json(); if (!response.ok) throw new Error(response.statusText);
if (responseAsJson.image) { if (!response.body) throw new Error("Response body is empty");
// If response has image field, response is a generated image. // Stream and render chat response
if (responseAsJson.intentType === "text-to-image") { await readChatStream(response);
rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`; } catch (err) {
} else if (responseAsJson.intentType === "text-to-image2") { console.error(`Khoj chat response failed with\n${err}`);
rawResponse += `![${query}](${responseAsJson.image})`; if (chatMessageState.newResponseEl.getElementsByClassName("lds-ellipsis").length > 0 && chatMessageState.loadingEllipsis)
} else if (responseAsJson.intentType === "text-to-image-v3") { chatMessageState.newResponseTextEl.removeChild(chatMessageState.loadingEllipsis);
rawResponse += `![${query}](data:image/webp;base64,${responseAsJson.image})`; let errorMsg = "Sorry, unable to get response from Khoj backend ❤️‍🩹. Retry or contact developers for help at <a href=mailto:'team@khoj.dev'>team@khoj.dev</a> or <a href='https://discord.gg/BDgyabRM6e'>on Discord</a>";
} newResponseTextEl.textContent = errorMsg;
const inferredQueries = responseAsJson.inferredQueries?.[0];
if (inferredQueries) {
rawResponse += `\n\n**Inferred Query**:\n\n${inferredQueries}`;
}
}
if (responseAsJson.context) {
const rawReferenceAsJson = responseAsJson.context;
references = createReferenceSection(rawReferenceAsJson, createLinkerSection=true);
}
if (responseAsJson.detail) {
// If response has detail field, response is an error message.
rawResponse += responseAsJson.detail;
}
} catch (error) {
// If the chunk is not a JSON object, just display it as is
rawResponse += chunk;
} finally {
newResponseTextEl.innerHTML = "";
newResponseTextEl.appendChild(formatHTMLMessage(rawResponse));
if (references != null) {
newResponseTextEl.appendChild(references);
}
document.body.scrollTop = document.getElementById("chat-body").scrollHeight;
}
} else {
// Handle streamed response of type text/event-stream or text/plain
const reader = response.body.getReader();
const decoder = new TextDecoder();
let references = {};
readStream();
function readStream() {
reader.read().then(({ done, value }) => {
if (done) {
// Append any references after all the data has been streamed
if (references != {}) {
newResponseTextEl.appendChild(createReferenceSection(references, createLinkerSection=true));
}
document.body.scrollTop = document.getElementById("chat-body").scrollHeight;
return;
}
// Decode message chunk from stream
const chunk = decoder.decode(value, { stream: true });
if (chunk.includes("### compiled references:")) {
const additionalResponse = chunk.split("### compiled references:")[0];
rawResponse += additionalResponse;
newResponseTextEl.innerHTML = "";
newResponseTextEl.appendChild(formatHTMLMessage(rawResponse));
const rawReference = chunk.split("### compiled references:")[1];
const rawReferenceAsJson = JSON.parse(rawReference);
if (rawReferenceAsJson instanceof Array) {
references["notes"] = rawReferenceAsJson;
} else if (typeof rawReferenceAsJson === "object" && rawReferenceAsJson !== null) {
references["online"] = rawReferenceAsJson;
}
readStream();
} else {
// Display response from Khoj
if (newResponseTextEl.getElementsByClassName("lds-ellipsis").length > 0) {
newResponseTextEl.removeChild(loadingEllipsis);
}
// If the chunk is not a JSON object, just display it as is
rawResponse += chunk;
newResponseTextEl.innerHTML = "";
newResponseTextEl.appendChild(formatHTMLMessage(rawResponse));
readStream();
}
// Scroll to bottom of chat window as chat response is streamed
document.body.scrollTop = document.getElementById("chat-body").scrollHeight;
});
}
} }
document.body.scrollTop = document.getElementById("chat-body").scrollHeight; document.body.scrollTop = document.getElementById("chat-body").scrollHeight;
} }

View file

@ -34,8 +34,8 @@ function toggleNavMenu() {
document.addEventListener('click', function(event) { document.addEventListener('click', function(event) {
let menu = document.getElementById("khoj-nav-menu"); let menu = document.getElementById("khoj-nav-menu");
let menuContainer = document.getElementById("khoj-nav-menu-container"); let menuContainer = document.getElementById("khoj-nav-menu-container");
let isClickOnMenu = menuContainer.contains(event.target) || menuContainer === event.target; let isClickOnMenu = menuContainer?.contains(event.target) || menuContainer === event.target;
if (isClickOnMenu === false && menu.classList.contains("show")) { if (menu && isClickOnMenu === false && menu.classList.contains("show")) {
menu.classList.remove("show"); menu.classList.remove("show");
} }
}); });

View file

@ -12,6 +12,25 @@ export interface ChatJsonResult {
inferredQueries?: string[]; inferredQueries?: string[];
} }
// Result of splitting a buffered stream into complete JSON objects plus
// the trailing partial text still awaiting more data.
interface ChunkResult {
objects: string[];
remainder: string;
}
// One decoded stream event: a type tag ('status', 'message', 'references',
// 'start_llm_response', 'end_llm_response') and its payload.
interface MessageChunk {
type: string;
data: any;
}
// Mutable state for the chat message currently being streamed.
interface ChatMessageState {
newResponseTextEl: HTMLElement | null;   // element receiving rendered response text
newResponseEl: HTMLElement | null;       // container for the whole chat message
loadingEllipsis: HTMLElement | null;     // "thinking" indicator, removed on first content
references: any;                         // reference context collected from the stream
rawResponse: string;                     // accumulated markdown response
rawQuery: string;                        // user query that triggered this response
isVoice: boolean;                        // whether the query arrived as a voice message
}
interface Location { interface Location {
region: string; region: string;
@ -26,6 +45,7 @@ export class KhojChatView extends KhojPaneView {
waitingForLocation: boolean; waitingForLocation: boolean;
location: Location; location: Location;
keyPressTimeout: NodeJS.Timeout | null = null; keyPressTimeout: NodeJS.Timeout | null = null;
chatMessageState: ChatMessageState;
constructor(leaf: WorkspaceLeaf, setting: KhojSetting) { constructor(leaf: WorkspaceLeaf, setting: KhojSetting) {
super(leaf, setting); super(leaf, setting);
@ -409,16 +429,15 @@ export class KhojChatView extends KhojPaneView {
message = DOMPurify.sanitize(message); message = DOMPurify.sanitize(message);
// Convert the message to html, sanitize the message html and render it to the real DOM // Convert the message to html, sanitize the message html and render it to the real DOM
let chat_message_body_text_el = this.contentEl.createDiv(); let chatMessageBodyTextEl = this.contentEl.createDiv();
chat_message_body_text_el.className = "chat-message-text-response"; chatMessageBodyTextEl.innerHTML = this.markdownTextToSanitizedHtml(message, this);
chat_message_body_text_el.innerHTML = this.markdownTextToSanitizedHtml(message, this);
// Add a copy button to each chat message, if it doesn't already exist // Add a copy button to each chat message, if it doesn't already exist
if (willReplace === true) { if (willReplace === true) {
this.renderActionButtons(message, chat_message_body_text_el); this.renderActionButtons(message, chatMessageBodyTextEl);
} }
return chat_message_body_text_el; return chatMessageBodyTextEl;
} }
markdownTextToSanitizedHtml(markdownText: string, component: ItemView): string { markdownTextToSanitizedHtml(markdownText: string, component: ItemView): string {
@ -502,23 +521,23 @@ export class KhojChatView extends KhojPaneView {
class: `khoj-chat-message ${sender}` class: `khoj-chat-message ${sender}`
}, },
}) })
let chat_message_body_el = chatMessageEl.createDiv(); let chatMessageBodyEl = chatMessageEl.createDiv();
chat_message_body_el.addClasses(["khoj-chat-message-text", sender]); chatMessageBodyEl.addClasses(["khoj-chat-message-text", sender]);
let chat_message_body_text_el = chat_message_body_el.createDiv(); let chatMessageBodyTextEl = chatMessageBodyEl.createDiv();
// Sanitize the markdown to render // Sanitize the markdown to render
message = DOMPurify.sanitize(message); message = DOMPurify.sanitize(message);
if (raw) { if (raw) {
chat_message_body_text_el.innerHTML = message; chatMessageBodyTextEl.innerHTML = message;
} else { } else {
// @ts-ignore // @ts-ignore
chat_message_body_text_el.innerHTML = this.markdownTextToSanitizedHtml(message, this); chatMessageBodyTextEl.innerHTML = this.markdownTextToSanitizedHtml(message, this);
} }
// Add action buttons to each chat message element // Add action buttons to each chat message element
if (willReplace === true) { if (willReplace === true) {
this.renderActionButtons(message, chat_message_body_text_el); this.renderActionButtons(message, chatMessageBodyTextEl);
} }
// Remove user-select: none property to make text selectable // Remove user-select: none property to make text selectable
@ -531,42 +550,38 @@ export class KhojChatView extends KhojPaneView {
} }
createKhojResponseDiv(dt?: Date): HTMLDivElement { createKhojResponseDiv(dt?: Date): HTMLDivElement {
let message_time = this.formatDate(dt ?? new Date()); let messageTime = this.formatDate(dt ?? new Date());
// Append message to conversation history HTML element. // Append message to conversation history HTML element.
// The chat logs should display above the message input box to follow standard UI semantics // The chat logs should display above the message input box to follow standard UI semantics
let chat_body_el = this.contentEl.getElementsByClassName("khoj-chat-body")[0]; let chatBodyEl = this.contentEl.getElementsByClassName("khoj-chat-body")[0];
let chat_message_el = chat_body_el.createDiv({ let chatMessageEl = chatBodyEl.createDiv({
attr: { attr: {
"data-meta": `🏮 Khoj at ${message_time}`, "data-meta": `🏮 Khoj at ${messageTime}`,
class: `khoj-chat-message khoj` class: `khoj-chat-message khoj`
}, },
}).createDiv({ })
attr: {
class: `khoj-chat-message-text khoj`
},
}).createDiv();
// Scroll to bottom after inserting chat messages // Scroll to bottom after inserting chat messages
this.scrollChatToBottom(); this.scrollChatToBottom();
return chat_message_el; return chatMessageEl;
} }
async renderIncrementalMessage(htmlElement: HTMLDivElement, additionalMessage: string) { async renderIncrementalMessage(htmlElement: HTMLDivElement, additionalMessage: string) {
this.result += additionalMessage; this.chatMessageState.rawResponse += additionalMessage;
htmlElement.innerHTML = ""; htmlElement.innerHTML = "";
// Sanitize the markdown to render // Sanitize the markdown to render
this.result = DOMPurify.sanitize(this.result); this.chatMessageState.rawResponse = DOMPurify.sanitize(this.chatMessageState.rawResponse);
// @ts-ignore // @ts-ignore
htmlElement.innerHTML = this.markdownTextToSanitizedHtml(this.result, this); htmlElement.innerHTML = this.markdownTextToSanitizedHtml(this.chatMessageState.rawResponse, this);
// Render action buttons for the message // Render action buttons for the message
this.renderActionButtons(this.result, htmlElement); this.renderActionButtons(this.chatMessageState.rawResponse, htmlElement);
// Scroll to bottom of modal, till the send message input box // Scroll to bottom of modal, till the send message input box
this.scrollChatToBottom(); this.scrollChatToBottom();
} }
renderActionButtons(message: string, chat_message_body_text_el: HTMLElement) { renderActionButtons(message: string, chatMessageBodyTextEl: HTMLElement) {
let copyButton = this.contentEl.createEl('button'); let copyButton = this.contentEl.createEl('button');
copyButton.classList.add("chat-action-button"); copyButton.classList.add("chat-action-button");
copyButton.title = "Copy Message to Clipboard"; copyButton.title = "Copy Message to Clipboard";
@ -593,10 +608,10 @@ export class KhojChatView extends KhojPaneView {
} }
// Append buttons to parent element // Append buttons to parent element
chat_message_body_text_el.append(copyButton, pasteToFile); chatMessageBodyTextEl.append(copyButton, pasteToFile);
if (speechButton) { if (speechButton) {
chat_message_body_text_el.append(speechButton); chatMessageBodyTextEl.append(speechButton);
} }
} }
@ -854,35 +869,122 @@ export class KhojChatView extends KhojPaneView {
return true; return true;
} }
async readChatStream(response: Response, responseElement: HTMLDivElement, isVoice: boolean = false): Promise<void> { convertMessageChunkToJson(rawChunk: string): MessageChunk {
if (rawChunk?.startsWith("{") && rawChunk?.endsWith("}")) {
try {
let jsonChunk = JSON.parse(rawChunk);
if (!jsonChunk.type)
jsonChunk = {type: 'message', data: jsonChunk};
return jsonChunk;
} catch (e) {
return {type: 'message', data: rawChunk};
}
} else if (rawChunk.length > 0) {
return {type: 'message', data: rawChunk};
}
return {type: '', data: ''};
}
// Dispatch one decoded stream event to the appropriate handler, mutating
// `this.chatMessageState` as the response accumulates.
processMessageChunk(rawChunk: string): void {
const chunk = this.convertMessageChunkToJson(rawChunk);
console.debug("Chunk:", chunk);
// Empty or untyped chunks carry nothing to render
if (!chunk || !chunk.type) return;
if (chunk.type === 'status') {
console.log(`status: ${chunk.data}`);
const statusMessage = chunk.data;
// replace=false: status lines append below the response rather than replace it
this.handleStreamResponse(this.chatMessageState.newResponseTextEl, statusMessage, this.chatMessageState.loadingEllipsis, false);
} else if (chunk.type === 'start_llm_response') {
console.log("Started streaming", new Date());
} else if (chunk.type === 'end_llm_response') {
console.log("Stopped streaming", new Date());
// Automatically respond with voice if the subscribed user has sent voice message
if (this.chatMessageState.isVoice && this.setting.userInfo?.is_active)
this.textToSpeech(this.chatMessageState.rawResponse);
// Append any references after all the data has been streamed
this.finalizeChatBodyResponse(this.chatMessageState.references, this.chatMessageState.newResponseTextEl);
const liveQuery = this.chatMessageState.rawQuery;
// Reset variables for the next message; keep rawQuery for any late consumers
this.chatMessageState = {
newResponseTextEl: null,
newResponseEl: null,
loadingEllipsis: null,
references: {},
rawResponse: "",
rawQuery: liveQuery,
isVoice: false,
};
} else if (chunk.type === "references") {
// Stash references; rendered by finalizeChatBodyResponse at end of stream
this.chatMessageState.references = {"notes": chunk.data.context, "online": chunk.data.onlineContext};
} else if (chunk.type === 'message') {
const chunkData = chunk.data;
if (typeof chunkData === 'object' && chunkData !== null) {
// If chunkData is already a JSON object
this.handleJsonResponse(chunkData);
} else if (typeof chunkData === 'string' && chunkData.trim()?.startsWith("{") && chunkData.trim()?.endsWith("}")) {
// Try process chunk data as if it is a JSON object
try {
const jsonData = JSON.parse(chunkData.trim());
this.handleJsonResponse(jsonData);
} catch (e) {
// Not valid JSON after all; fall back to rendering the raw text
this.chatMessageState.rawResponse += chunkData;
this.handleStreamResponse(this.chatMessageState.newResponseTextEl, this.chatMessageState.rawResponse, this.chatMessageState.loadingEllipsis);
}
} else {
// Plain text chunk: append and re-render the accumulated response
this.chatMessageState.rawResponse += chunkData;
this.handleStreamResponse(this.chatMessageState.newResponseTextEl, this.chatMessageState.rawResponse, this.chatMessageState.loadingEllipsis);
}
}
}
handleJsonResponse(jsonData: any): void {
if (jsonData.image || jsonData.detail) {
this.chatMessageState.rawResponse = this.handleImageResponse(jsonData, this.chatMessageState.rawResponse);
} else if (jsonData.response) {
this.chatMessageState.rawResponse = jsonData.response;
}
if (this.chatMessageState.newResponseTextEl) {
this.chatMessageState.newResponseTextEl.innerHTML = "";
this.chatMessageState.newResponseTextEl.appendChild(this.formatHTMLMessage(this.chatMessageState.rawResponse));
}
}
async readChatStream(response: Response): Promise<void> {
// Exit if response body is empty // Exit if response body is empty
if (response.body == null) return; if (response.body == null) return;
const reader = response.body.getReader(); const reader = response.body.getReader();
const decoder = new TextDecoder(); const decoder = new TextDecoder();
const eventDelimiter = '␃🔚␗';
let buffer = '';
while (true) { while (true) {
const { value, done } = await reader.read(); const { value, done } = await reader.read();
if (done) { if (done) {
// Automatically respond with voice if the subscribed user has sent voice message this.processMessageChunk(buffer);
if (isVoice && this.setting.userInfo?.is_active) this.textToSpeech(this.result); buffer = '';
// Break if the stream is done // Break if the stream is done
break; break;
} }
let responseText = decoder.decode(value); const chunk = decoder.decode(value, { stream: true });
if (responseText.includes("### compiled references:")) { console.debug("Raw Chunk:", chunk)
// Render any references used to generate the response // Start buffering chunks until complete event is received
const [additionalResponse, rawReference] = responseText.split("### compiled references:", 2); buffer += chunk;
await this.renderIncrementalMessage(responseElement, additionalResponse);
const rawReferenceAsJson = JSON.parse(rawReference); // Once the buffer contains a complete event
let references = this.extractReferences(rawReferenceAsJson); let newEventIndex;
responseElement.appendChild(this.createReferenceSection(references)); while ((newEventIndex = buffer.indexOf(eventDelimiter)) !== -1) {
} else { // Extract the event from the buffer
// Render incremental chat response const event = buffer.slice(0, newEventIndex);
await this.renderIncrementalMessage(responseElement, responseText); buffer = buffer.slice(newEventIndex + eventDelimiter.length);
// Process the event
if (event) this.processMessageChunk(event);
} }
} }
} }
@ -895,83 +997,59 @@ export class KhojChatView extends KhojPaneView {
let chatBodyEl = this.contentEl.getElementsByClassName("khoj-chat-body")[0] as HTMLElement; let chatBodyEl = this.contentEl.getElementsByClassName("khoj-chat-body")[0] as HTMLElement;
this.renderMessage(chatBodyEl, query, "you"); this.renderMessage(chatBodyEl, query, "you");
let conversationID = chatBodyEl.dataset.conversationId; let conversationId = chatBodyEl.dataset.conversationId;
if (!conversationID) { if (!conversationId) {
let chatUrl = `${this.setting.khojUrl}/api/chat/sessions?client=obsidian`; let chatUrl = `${this.setting.khojUrl}/api/chat/sessions?client=obsidian`;
let response = await fetch(chatUrl, { let response = await fetch(chatUrl, {
method: "POST", method: "POST",
headers: { "Authorization": `Bearer ${this.setting.khojApiKey}` }, headers: { "Authorization": `Bearer ${this.setting.khojApiKey}` },
}); });
let data = await response.json(); let data = await response.json();
conversationID = data.conversation_id; conversationId = data.conversation_id;
chatBodyEl.dataset.conversationId = conversationID; chatBodyEl.dataset.conversationId = conversationId;
} }
// Get chat response from Khoj backend // Get chat response from Khoj backend
let encodedQuery = encodeURIComponent(query); let encodedQuery = encodeURIComponent(query);
let chatUrl = `${this.setting.khojUrl}/api/chat?q=${encodedQuery}&n=${this.setting.resultsCount}&client=obsidian&stream=true&region=${this.location.region}&city=${this.location.city}&country=${this.location.countryName}&timezone=${this.location.timezone}`; let chatUrl = `${this.setting.khojUrl}/api/chat?q=${encodedQuery}&conversation_id=${conversationId}&n=${this.setting.resultsCount}&stream=true&client=obsidian`;
let responseElement = this.createKhojResponseDiv(); if (!!this.location) chatUrl += `&region=${this.location.region}&city=${this.location.city}&country=${this.location.countryName}&timezone=${this.location.timezone}`;
let newResponseEl = this.createKhojResponseDiv();
let newResponseTextEl = newResponseEl.createDiv();
newResponseTextEl.classList.add("khoj-chat-message-text", "khoj");
// Temporary status message to indicate that Khoj is thinking // Temporary status message to indicate that Khoj is thinking
this.result = "";
let loadingEllipsis = this.createLoadingEllipse(); let loadingEllipsis = this.createLoadingEllipse();
responseElement.appendChild(loadingEllipsis); newResponseTextEl.appendChild(loadingEllipsis);
// Set chat message state
this.chatMessageState = {
newResponseEl: newResponseEl,
newResponseTextEl: newResponseTextEl,
loadingEllipsis: loadingEllipsis,
references: {},
rawQuery: query,
rawResponse: "",
isVoice: isVoice,
};
let response = await fetch(chatUrl, { let response = await fetch(chatUrl, {
method: "GET", method: "GET",
headers: { headers: {
"Content-Type": "text/event-stream", "Content-Type": "text/plain",
"Authorization": `Bearer ${this.setting.khojApiKey}`, "Authorization": `Bearer ${this.setting.khojApiKey}`,
}, },
}) })
try { try {
if (response.body === null) { if (response.body === null) throw new Error("Response body is null");
throw new Error("Response body is null");
}
// Clear loading status message
if (responseElement.getElementsByClassName("lds-ellipsis").length > 0 && loadingEllipsis) {
responseElement.removeChild(loadingEllipsis);
}
// Reset collated chat result to empty string
this.result = "";
responseElement.innerHTML = "";
if (response.headers.get("content-type") === "application/json") {
let responseText = ""
try {
const responseAsJson = await response.json() as ChatJsonResult;
if (responseAsJson.image) {
// If response has image field, response is a generated image.
if (responseAsJson.intentType === "text-to-image") {
responseText += `![${query}](data:image/png;base64,${responseAsJson.image})`;
} else if (responseAsJson.intentType === "text-to-image2") {
responseText += `![${query}](${responseAsJson.image})`;
} else if (responseAsJson.intentType === "text-to-image-v3") {
responseText += `![${query}](data:image/webp;base64,${responseAsJson.image})`;
}
const inferredQuery = responseAsJson.inferredQueries?.[0];
if (inferredQuery) {
responseText += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
}
} else if (responseAsJson.detail) {
responseText = responseAsJson.detail;
}
} catch (error) {
// If the chunk is not a JSON object, just display it as is
responseText = await response.text();
} finally {
await this.renderIncrementalMessage(responseElement, responseText);
}
} else {
// Stream and render chat response // Stream and render chat response
await this.readChatStream(response, responseElement, isVoice); await this.readChatStream(response);
}
} catch (err) { } catch (err) {
console.log(`Khoj chat response failed with\n${err}`); console.error(`Khoj chat response failed with\n${err}`);
let errorMsg = "Sorry, unable to get response from Khoj backend ❤️‍🩹. Retry or contact developers for help at <a href=mailto:'team@khoj.dev'>team@khoj.dev</a> or <a href='https://discord.gg/BDgyabRM6e'>on Discord</a>"; let errorMsg = "Sorry, unable to get response from Khoj backend ❤️‍🩹. Retry or contact developers for help at <a href=mailto:'team@khoj.dev'>team@khoj.dev</a> or <a href='https://discord.gg/BDgyabRM6e'>on Discord</a>";
responseElement.innerHTML = errorMsg newResponseTextEl.textContent = errorMsg;
} }
} }
@ -1196,30 +1274,21 @@ export class KhojChatView extends KhojPaneView {
handleStreamResponse(newResponseElement: HTMLElement | null, rawResponse: string, loadingEllipsis: HTMLElement | null, replace = true) { handleStreamResponse(newResponseElement: HTMLElement | null, rawResponse: string, loadingEllipsis: HTMLElement | null, replace = true) {
if (!newResponseElement) return; if (!newResponseElement) return;
if (newResponseElement.getElementsByClassName("lds-ellipsis").length > 0 && loadingEllipsis) { // Remove loading ellipsis if it exists
if (newResponseElement.getElementsByClassName("lds-ellipsis").length > 0 && loadingEllipsis)
newResponseElement.removeChild(loadingEllipsis); newResponseElement.removeChild(loadingEllipsis);
} // Clear the response element if replace is true
if (replace) { if (replace) newResponseElement.innerHTML = "";
newResponseElement.innerHTML = "";
} // Append response to the response element
newResponseElement.appendChild(this.formatHTMLMessage(rawResponse, false, replace)); newResponseElement.appendChild(this.formatHTMLMessage(rawResponse, false, replace));
// Append loading ellipsis if it exists
if (!replace && loadingEllipsis) newResponseElement.appendChild(loadingEllipsis);
// Scroll to bottom of chat view
this.scrollChatToBottom(); this.scrollChatToBottom();
} }
handleCompiledReferences(rawResponseElement: HTMLElement | null, chunk: string, references: any, rawResponse: string) {
if (!rawResponseElement || !chunk) return { rawResponse, references };
const [additionalResponse, rawReference] = chunk.split("### compiled references:", 2);
rawResponse += additionalResponse;
rawResponseElement.innerHTML = "";
rawResponseElement.appendChild(this.formatHTMLMessage(rawResponse));
const rawReferenceAsJson = JSON.parse(rawReference);
references = this.extractReferences(rawReferenceAsJson);
return { rawResponse, references };
}
handleImageResponse(imageJson: any, rawResponse: string) { handleImageResponse(imageJson: any, rawResponse: string) {
if (imageJson.image) { if (imageJson.image) {
const inferredQuery = imageJson.inferredQueries?.[0] ?? "generated image"; const inferredQuery = imageJson.inferredQueries?.[0] ?? "generated image";
@ -1236,33 +1305,10 @@ export class KhojChatView extends KhojPaneView {
rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`; rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
} }
} }
let references = {};
if (imageJson.context && imageJson.context.length > 0) {
references = this.extractReferences(imageJson.context);
}
if (imageJson.detail) {
// If response has detail field, response is an error message. // If response has detail field, response is an error message.
rawResponse += imageJson.detail; if (imageJson.detail) rawResponse += imageJson.detail;
}
return { rawResponse, references };
}
extractReferences(rawReferenceAsJson: any): object { return rawResponse;
let references: any = {};
if (rawReferenceAsJson instanceof Array) {
references["notes"] = rawReferenceAsJson;
} else if (typeof rawReferenceAsJson === "object" && rawReferenceAsJson !== null) {
references["online"] = rawReferenceAsJson;
}
return references;
}
addMessageToChatBody(rawResponse: string, newResponseElement: HTMLElement | null, references: any) {
if (!newResponseElement) return;
newResponseElement.innerHTML = "";
newResponseElement.appendChild(this.formatHTMLMessage(rawResponse));
this.finalizeChatBodyResponse(references, newResponseElement);
} }
finalizeChatBodyResponse(references: object, newResponseElement: HTMLElement | null) { finalizeChatBodyResponse(references: object, newResponseElement: HTMLElement | null) {

View file

@ -85,6 +85,12 @@ If your plugin does not need CSS, delete this file.
margin-left: auto; margin-left: auto;
white-space: pre-line; white-space: pre-line;
} }
/* Lists render their own line breaks; without this override, ul/ol/li inside
   a khoj chat bubble would inherit `white-space: pre-line` (set on the chat
   message text above) and show stray blank lines between items. */
.khoj-chat-message-text.khoj ul,
.khoj-chat-message-text.khoj ol,
.khoj-chat-message-text.khoj li {
    white-space: normal;
}
/* add left protrusion to khoj chat bubble */ /* add left protrusion to khoj chat bubble */
.khoj-chat-message-text.khoj:after { .khoj-chat-message-text.khoj:after {
content: ''; content: '';

View file

@ -680,34 +680,18 @@ class ConversationAdapters:
async def aget_conversation_by_user( async def aget_conversation_by_user(
user: KhojUser, client_application: ClientApplication = None, conversation_id: int = None, title: str = None user: KhojUser, client_application: ClientApplication = None, conversation_id: int = None, title: str = None
) -> Optional[Conversation]: ) -> Optional[Conversation]:
query = Conversation.objects.filter(user=user, client=client_application).prefetch_related("agent")
if conversation_id: if conversation_id:
return ( return await query.filter(id=conversation_id).afirst()
await Conversation.objects.filter(user=user, client=client_application, id=conversation_id)
.prefetch_related("agent")
.afirst()
)
elif title: elif title:
return ( return await query.filter(title=title).afirst()
await Conversation.objects.filter(user=user, client=client_application, title=title)
.prefetch_related("agent")
.afirst()
)
else:
conversation = (
Conversation.objects.filter(user=user, client=client_application)
.prefetch_related("agent")
.order_by("-updated_at")
)
if await conversation.aexists(): conversation = await query.order_by("-updated_at").afirst()
return await conversation.prefetch_related("agent").afirst()
return await ( return conversation or await Conversation.objects.prefetch_related("agent").acreate(
Conversation.objects.filter(user=user, client=client_application) user=user, client=client_application
.prefetch_related("agent") )
.order_by("-updated_at")
.afirst()
) or await Conversation.objects.prefetch_related("agent").acreate(user=user, client=client_application)
@staticmethod @staticmethod
async def adelete_conversation_by_user( async def adelete_conversation_by_user(

View file

@ -74,14 +74,13 @@ To get started, just start typing below. You can also type / to see a list of co
}, 1000); }, 1000);
}); });
} }
var websocket = null;
let region = null; let region = null;
let city = null; let city = null;
let countryName = null; let countryName = null;
let timezone = null; let timezone = null;
let waitingForLocation = true; let waitingForLocation = true;
let chatMessageState = {
let websocketState = {
newResponseTextEl: null, newResponseTextEl: null,
newResponseEl: null, newResponseEl: null,
loadingEllipsis: null, loadingEllipsis: null,
@ -105,7 +104,7 @@ To get started, just start typing below. You can also type / to see a list of co
.finally(() => { .finally(() => {
console.debug("Region:", region, "City:", city, "Country:", countryName, "Timezone:", timezone); console.debug("Region:", region, "City:", city, "Country:", countryName, "Timezone:", timezone);
waitingForLocation = false; waitingForLocation = false;
setupWebSocket(); initMessageState();
}); });
function formatDate(date) { function formatDate(date) {
@ -599,13 +598,8 @@ To get started, just start typing below. You can also type / to see a list of co
} }
async function chat(isVoice=false) { async function chat(isVoice=false) {
if (websocket) { // Extract chat message from chat input form
sendMessageViaWebSocket(isVoice); var query = document.getElementById("chat-input").value.trim();
return;
}
let query = document.getElementById("chat-input").value.trim();
let resultsCount = localStorage.getItem("khojResultsCount") || 5;
console.log(`Query: ${query}`); console.log(`Query: ${query}`);
// Short circuit on empty query // Short circuit on empty query
@ -624,31 +618,30 @@ To get started, just start typing below. You can also type / to see a list of co
document.getElementById("chat-input").value = ""; document.getElementById("chat-input").value = "";
autoResize(); autoResize();
document.getElementById("chat-input").setAttribute("disabled", "disabled"); document.getElementById("chat-input").setAttribute("disabled", "disabled");
let chat_body = document.getElementById("chat-body");
let conversationID = chat_body.dataset.conversationId;
let chatBody = document.getElementById("chat-body");
let conversationID = chatBody.dataset.conversationId;
if (!conversationID) { if (!conversationID) {
let response = await fetch('/api/chat/sessions', { method: "POST" }); let response = await fetch(`${hostURL}/api/chat/sessions`, { method: "POST" });
let data = await response.json(); let data = await response.json();
conversationID = data.conversation_id; conversationID = data.conversation_id;
chat_body.dataset.conversationId = conversationID; chatBody.dataset.conversationId = conversationID;
refreshChatSessionsPanel(); await refreshChatSessionsPanel();
} }
let new_response = document.createElement("div"); let newResponseEl = document.createElement("div");
new_response.classList.add("chat-message", "khoj"); newResponseEl.classList.add("chat-message", "khoj");
new_response.attributes["data-meta"] = "🏮 Khoj at " + formatDate(new Date()); newResponseEl.attributes["data-meta"] = "🏮 Khoj at " + formatDate(new Date());
chat_body.appendChild(new_response); chatBody.appendChild(newResponseEl);
let newResponseText = document.createElement("div"); let newResponseTextEl = document.createElement("div");
newResponseText.classList.add("chat-message-text", "khoj"); newResponseTextEl.classList.add("chat-message-text", "khoj");
new_response.appendChild(newResponseText); newResponseEl.appendChild(newResponseTextEl);
// Temporary status message to indicate that Khoj is thinking // Temporary status message to indicate that Khoj is thinking
let loadingEllipsis = createLoadingEllipse(); let loadingEllipsis = createLoadingEllipse();
newResponseText.appendChild(loadingEllipsis); newResponseTextEl.appendChild(loadingEllipsis);
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
let chatTooltip = document.getElementById("chat-tooltip"); let chatTooltip = document.getElementById("chat-tooltip");
@ -657,65 +650,38 @@ To get started, just start typing below. You can also type / to see a list of co
let chatInput = document.getElementById("chat-input"); let chatInput = document.getElementById("chat-input");
chatInput.classList.remove("option-enabled"); chatInput.classList.remove("option-enabled");
// Generate backend API URL to execute query // Setup chat message state
let url = `/api/chat?q=${encodeURIComponent(query)}&n=${resultsCount}&client=web&stream=true&conversation_id=${conversationID}&region=${region}&city=${city}&country=${countryName}&timezone=${timezone}`; chatMessageState = {
newResponseTextEl,
newResponseEl,
loadingEllipsis,
references: {},
rawResponse: "",
rawQuery: query,
isVoice: isVoice,
}
// Call specified Khoj API // Call Khoj chat API
let response = await fetch(url); let chatApi = `/api/chat?q=${encodeURIComponent(query)}&conversation_id=${conversationID}&stream=true&client=web`;
let rawResponse = ""; chatApi += (!!region && !!city && !!countryName && !!timezone)
let references = null; ? `&region=${region}&city=${city}&country=${countryName}&timezone=${timezone}`
const contentType = response.headers.get("content-type"); : '';
const response = await fetch(chatApi);
if (contentType === "application/json") {
// Handle JSON response
try { try {
const responseAsJson = await response.json(); if (!response.ok) throw new Error(response.statusText);
if (responseAsJson.image || responseAsJson.detail) { if (!response.body) throw new Error("Response body is empty");
({rawResponse, references } = handleImageResponse(responseAsJson, rawResponse)); // Stream and render chat response
} else { await readChatStream(response);
rawResponse = responseAsJson.response; } catch (err) {
console.error(`Khoj chat response failed with\n${err}`);
if (chatMessageState.newResponseEl.getElementsByClassName("lds-ellipsis").length > 0 && chatMessageState.loadingEllipsis)
chatMessageState.newResponseTextEl.removeChild(chatMessageState.loadingEllipsis);
let errorMsg = "Sorry, unable to get response from Khoj backend ❤️‍🩹. Retry or contact developers for help at <a href=mailto:'team@khoj.dev'>team@khoj.dev</a> or <a href='https://discord.gg/BDgyabRM6e'>on Discord</a>";
newResponseTextEl.innerHTML = errorMsg;
} }
} catch (error) {
// If the chunk is not a JSON object, just display it as is
rawResponse += chunk;
} finally {
addMessageToChatBody(rawResponse, newResponseText, references);
} }
} else {
// Handle streamed response of type text/event-stream or text/plain
const reader = response.body.getReader();
const decoder = new TextDecoder();
let references = {};
readStream();
function readStream() {
reader.read().then(({ done, value }) => {
if (done) {
// Append any references after all the data has been streamed
finalizeChatBodyResponse(references, newResponseText);
return;
}
// Decode message chunk from stream
const chunk = decoder.decode(value, { stream: true });
if (chunk.includes("### compiled references:")) {
({ rawResponse, references } = handleCompiledReferences(newResponseText, chunk, references, rawResponse));
readStream();
} else {
// If the chunk is not a JSON object, just display it as is
rawResponse += chunk;
handleStreamResponse(newResponseText, rawResponse, query, loadingEllipsis);
readStream();
}
});
// Scroll to bottom of chat window as chat response is streamed
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
};
}
};
function createLoadingEllipse() { function createLoadingEllipse() {
// Temporary status message to indicate that Khoj is thinking // Temporary status message to indicate that Khoj is thinking
@ -743,32 +709,22 @@ To get started, just start typing below. You can also type / to see a list of co
} }
function handleStreamResponse(newResponseElement, rawResponse, rawQuery, loadingEllipsis, replace=true) { function handleStreamResponse(newResponseElement, rawResponse, rawQuery, loadingEllipsis, replace=true) {
if (newResponseElement.getElementsByClassName("lds-ellipsis").length > 0 && loadingEllipsis) { if (!newResponseElement) return;
// Remove loading ellipsis if it exists
if (newResponseElement.getElementsByClassName("lds-ellipsis").length > 0 && loadingEllipsis)
newResponseElement.removeChild(loadingEllipsis); newResponseElement.removeChild(loadingEllipsis);
} // Clear the response element if replace is true
if (replace) { if (replace) newResponseElement.innerHTML = "";
newResponseElement.innerHTML = "";
} // Append response to the response element
newResponseElement.appendChild(formatHTMLMessage(rawResponse, false, replace, rawQuery)); newResponseElement.appendChild(formatHTMLMessage(rawResponse, false, replace, rawQuery));
// Append loading ellipsis if it exists
if (!replace && loadingEllipsis) newResponseElement.appendChild(loadingEllipsis);
// Scroll to bottom of chat view
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
} }
function handleCompiledReferences(rawResponseElement, chunk, references, rawResponse) {
const additionalResponse = chunk.split("### compiled references:")[0];
rawResponse += additionalResponse;
rawResponseElement.innerHTML = "";
rawResponseElement.appendChild(formatHTMLMessage(rawResponse));
const rawReference = chunk.split("### compiled references:")[1];
const rawReferenceAsJson = JSON.parse(rawReference);
if (rawReferenceAsJson instanceof Array) {
references["notes"] = rawReferenceAsJson;
} else if (typeof rawReferenceAsJson === "object" && rawReferenceAsJson !== null) {
references["online"] = rawReferenceAsJson;
}
return { rawResponse, references };
}
function handleImageResponse(imageJson, rawResponse) { function handleImageResponse(imageJson, rawResponse) {
if (imageJson.image) { if (imageJson.image) {
const inferredQuery = imageJson.inferredQueries?.[0] ?? "generated image"; const inferredQuery = imageJson.inferredQueries?.[0] ?? "generated image";
@ -785,35 +741,139 @@ To get started, just start typing below. You can also type / to see a list of co
rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`; rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
} }
} }
let references = {};
if (imageJson.context && imageJson.context.length > 0) {
const rawReferenceAsJson = imageJson.context;
if (rawReferenceAsJson instanceof Array) {
references["notes"] = rawReferenceAsJson;
} else if (typeof rawReferenceAsJson === "object" && rawReferenceAsJson !== null) {
references["online"] = rawReferenceAsJson;
}
}
if (imageJson.detail) {
// If response has detail field, response is an error message. // If response has detail field, response is an error message.
rawResponse += imageJson.detail; if (imageJson.detail) rawResponse += imageJson.detail;
}
return { rawResponse, references };
}
function addMessageToChatBody(rawResponse, newResponseElement, references) { return rawResponse;
newResponseElement.innerHTML = "";
newResponseElement.appendChild(formatHTMLMessage(rawResponse));
finalizeChatBodyResponse(references, newResponseElement);
} }
function finalizeChatBodyResponse(references, newResponseElement) { function finalizeChatBodyResponse(references, newResponseElement) {
if (references != null && Object.keys(references).length > 0) { if (!!newResponseElement && references != null && Object.keys(references).length > 0) {
newResponseElement.appendChild(createReferenceSection(references)); newResponseElement.appendChild(createReferenceSection(references));
} }
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
document.getElementById("chat-input").removeAttribute("disabled"); document.getElementById("chat-input")?.removeAttribute("disabled");
}
function convertMessageChunkToJson(rawChunk) {
// Split the chunk into lines
console.debug("Raw Event:", rawChunk);
if (rawChunk?.startsWith("{") && rawChunk?.endsWith("}")) {
try {
let jsonChunk = JSON.parse(rawChunk);
if (!jsonChunk.type)
jsonChunk = {type: 'message', data: jsonChunk};
return jsonChunk;
} catch (e) {
return {type: 'message', data: rawChunk};
}
} else if (rawChunk.length > 0) {
return {type: 'message', data: rawChunk};
}
}
function processMessageChunk(rawChunk) {
const chunk = convertMessageChunkToJson(rawChunk);
console.debug("Json Event:", chunk);
if (!chunk || !chunk.type) return;
if (chunk.type ==='status') {
console.log(`status: ${chunk.data}`);
const statusMessage = chunk.data;
handleStreamResponse(chatMessageState.newResponseTextEl, statusMessage, chatMessageState.rawQuery, chatMessageState.loadingEllipsis, false);
} else if (chunk.type === 'start_llm_response') {
console.log("Started streaming", new Date());
} else if (chunk.type === 'end_llm_response') {
console.log("Stopped streaming", new Date());
// Automatically respond with voice if the subscribed user has sent voice message
if (chatMessageState.isVoice && "{{ is_active }}" == "True")
textToSpeech(chatMessageState.rawResponse);
// Append any references after all the data has been streamed
finalizeChatBodyResponse(chatMessageState.references, chatMessageState.newResponseTextEl);
const liveQuery = chatMessageState.rawQuery;
// Reset variables
chatMessageState = {
newResponseTextEl: null,
newResponseEl: null,
loadingEllipsis: null,
references: {},
rawResponse: "",
rawQuery: liveQuery,
isVoice: false,
}
} else if (chunk.type === "references") {
chatMessageState.references = {"notes": chunk.data.context, "online": chunk.data.onlineContext};
} else if (chunk.type === 'message') {
const chunkData = chunk.data;
if (typeof chunkData === 'object' && chunkData !== null) {
// If chunkData is already a JSON object
handleJsonResponse(chunkData);
} else if (typeof chunkData === 'string' && chunkData.trim()?.startsWith("{") && chunkData.trim()?.endsWith("}")) {
// Try process chunk data as if it is a JSON object
try {
const jsonData = JSON.parse(chunkData.trim());
handleJsonResponse(jsonData);
} catch (e) {
chatMessageState.rawResponse += chunkData;
handleStreamResponse(chatMessageState.newResponseTextEl, chatMessageState.rawResponse, chatMessageState.rawQuery, chatMessageState.loadingEllipsis);
}
} else {
chatMessageState.rawResponse += chunkData;
handleStreamResponse(chatMessageState.newResponseTextEl, chatMessageState.rawResponse, chatMessageState.rawQuery, chatMessageState.loadingEllipsis);
}
}
}
function handleJsonResponse(jsonData) {
if (jsonData.image || jsonData.detail) {
chatMessageState.rawResponse = handleImageResponse(jsonData, chatMessageState.rawResponse);
} else if (jsonData.response) {
chatMessageState.rawResponse = jsonData.response;
}
if (chatMessageState.newResponseTextEl) {
chatMessageState.newResponseTextEl.innerHTML = "";
chatMessageState.newResponseTextEl.appendChild(formatHTMLMessage(chatMessageState.rawResponse));
}
}
async function readChatStream(response) {
if (!response.body) return;
const reader = response.body.getReader();
const decoder = new TextDecoder();
const eventDelimiter = '␃🔚␗';
let buffer = '';
while (true) {
const { value, done } = await reader.read();
// If the stream is done
if (done) {
// Process the last chunk
processMessageChunk(buffer);
buffer = '';
break;
}
// Read chunk from stream and append it to the buffer
const chunk = decoder.decode(value, { stream: true });
console.debug("Raw Chunk:", chunk)
// Start buffering chunks until complete event is received
buffer += chunk;
// Once the buffer contains a complete event
let newEventIndex;
while ((newEventIndex = buffer.indexOf(eventDelimiter)) !== -1) {
// Extract the event from the buffer
const event = buffer.slice(0, newEventIndex);
buffer = buffer.slice(newEventIndex + eventDelimiter.length);
// Process the event
if (event) processMessageChunk(event);
}
}
} }
function incrementalChat(event) { function incrementalChat(event) {
@ -1069,17 +1129,13 @@ To get started, just start typing below. You can also type / to see a list of co
window.onload = loadChat; window.onload = loadChat;
function setupWebSocket(isVoice=false) { function initMessageState(isVoice=false) {
let chatBody = document.getElementById("chat-body");
let wsProtocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
let webSocketUrl = `${wsProtocol}//${window.location.host}/api/chat/ws`;
if (waitingForLocation) { if (waitingForLocation) {
console.debug("Waiting for location data to be fetched. Will setup WebSocket once location data is available."); console.debug("Waiting for location data to be fetched. Will setup WebSocket once location data is available.");
return; return;
} }
websocketState = { chatMessageState = {
newResponseTextEl: null, newResponseTextEl: null,
newResponseEl: null, newResponseEl: null,
loadingEllipsis: null, loadingEllipsis: null,
@ -1088,174 +1144,8 @@ To get started, just start typing below. You can also type / to see a list of co
rawQuery: "", rawQuery: "",
isVoice: isVoice, isVoice: isVoice,
} }
if (chatBody.dataset.conversationId) {
webSocketUrl += `?conversation_id=${chatBody.dataset.conversationId}`;
webSocketUrl += (!!region && !!city && !!countryName) && !!timezone ? `&region=${region}&city=${city}&country=${countryName}&timezone=${timezone}` : '';
websocket = new WebSocket(webSocketUrl);
websocket.onmessage = function(event) {
// Get the last element in the chat-body
let chunk = event.data;
if (chunk == "start_llm_response") {
console.log("Started streaming", new Date());
} else if (chunk == "end_llm_response") {
console.log("Stopped streaming", new Date());
// Automatically respond with voice if the subscribed user has sent voice message
if (websocketState.isVoice && "{{ is_active }}" == "True")
textToSpeech(websocketState.rawResponse);
// Append any references after all the data has been streamed
finalizeChatBodyResponse(websocketState.references, websocketState.newResponseTextEl);
const liveQuery = websocketState.rawQuery;
// Reset variables
websocketState = {
newResponseTextEl: null,
newResponseEl: null,
loadingEllipsis: null,
references: {},
rawResponse: "",
rawQuery: liveQuery,
isVoice: false,
}
} else {
try {
if (chunk.includes("application/json"))
{
chunk = JSON.parse(chunk);
}
} catch (error) {
// If the chunk is not a JSON object, continue.
} }
const contentType = chunk["content-type"]
if (contentType === "application/json") {
// Handle JSON response
try {
if (chunk.image || chunk.detail) {
({rawResponse, references } = handleImageResponse(chunk, websocketState.rawResponse));
websocketState.rawResponse = rawResponse;
websocketState.references = references;
} else if (chunk.type == "status") {
handleStreamResponse(websocketState.newResponseTextEl, chunk.message, websocketState.rawQuery, null, false);
} else if (chunk.type == "rate_limit") {
handleStreamResponse(websocketState.newResponseTextEl, chunk.message, websocketState.rawQuery, websocketState.loadingEllipsis, true);
} else {
rawResponse = chunk.response;
}
} catch (error) {
// If the chunk is not a JSON object, just display it as is
websocketState.rawResponse += chunk;
} finally {
if (chunk.type != "status" && chunk.type != "rate_limit") {
addMessageToChatBody(websocketState.rawResponse, websocketState.newResponseTextEl, websocketState.references);
}
}
} else {
// Handle streamed response of type text/event-stream or text/plain
if (chunk && chunk.includes("### compiled references:")) {
({ rawResponse, references } = handleCompiledReferences(websocketState.newResponseTextEl, chunk, websocketState.references, websocketState.rawResponse));
websocketState.rawResponse = rawResponse;
websocketState.references = references;
} else {
// If the chunk is not a JSON object, just display it as is
websocketState.rawResponse += chunk;
if (websocketState.newResponseTextEl) {
handleStreamResponse(websocketState.newResponseTextEl, websocketState.rawResponse, websocketState.rawQuery, websocketState.loadingEllipsis);
}
}
// Scroll to bottom of chat window as chat response is streamed
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
};
}
}
};
websocket.onclose = function(event) {
websocket = null;
console.log("WebSocket is closed now.");
let setupWebSocketButton = document.createElement("button");
setupWebSocketButton.textContent = "Reconnect to Server";
setupWebSocketButton.onclick = setupWebSocket;
let statusDotIcon = document.getElementById("connection-status-icon");
statusDotIcon.style.backgroundColor = "red";
let statusDotText = document.getElementById("connection-status-text");
statusDotText.innerHTML = "";
statusDotText.style.marginTop = "5px";
statusDotText.appendChild(setupWebSocketButton);
}
websocket.onerror = function(event) {
console.log("WebSocket error observed:", event);
}
websocket.onopen = function(event) {
console.log("WebSocket is open now.")
let statusDotIcon = document.getElementById("connection-status-icon");
statusDotIcon.style.backgroundColor = "green";
let statusDotText = document.getElementById("connection-status-text");
statusDotText.textContent = "Connected to Server";
}
}
function sendMessageViaWebSocket(isVoice=false) {
let chatBody = document.getElementById("chat-body");
var query = document.getElementById("chat-input").value.trim();
console.log(`Query: ${query}`);
if (userMessages.length >= 10) {
userMessages.shift();
}
userMessages.push(query);
resetUserMessageIndex();
// Add message by user to chat body
renderMessage(query, "you");
document.getElementById("chat-input").value = "";
autoResize();
document.getElementById("chat-input").setAttribute("disabled", "disabled");
let newResponseEl = document.createElement("div");
newResponseEl.classList.add("chat-message", "khoj");
newResponseEl.attributes["data-meta"] = "🏮 Khoj at " + formatDate(new Date());
chatBody.appendChild(newResponseEl);
let newResponseTextEl = document.createElement("div");
newResponseTextEl.classList.add("chat-message-text", "khoj");
newResponseEl.appendChild(newResponseTextEl);
// Temporary status message to indicate that Khoj is thinking
let loadingEllipsis = createLoadingEllipse();
newResponseTextEl.appendChild(loadingEllipsis);
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
let chatTooltip = document.getElementById("chat-tooltip");
chatTooltip.style.display = "none";
let chatInput = document.getElementById("chat-input");
chatInput.classList.remove("option-enabled");
// Call specified Khoj API
websocket.send(query);
let rawResponse = "";
let references = {};
websocketState = {
newResponseTextEl,
newResponseEl,
loadingEllipsis,
references,
rawResponse,
rawQuery: query,
isVoice: isVoice,
}
}
var userMessages = []; var userMessages = [];
var userMessageIndex = -1; var userMessageIndex = -1;
function loadChat() { function loadChat() {
@ -1265,7 +1155,7 @@ To get started, just start typing below. You can also type / to see a list of co
let chatHistoryUrl = `/api/chat/history?client=web`; let chatHistoryUrl = `/api/chat/history?client=web`;
if (chatBody.dataset.conversationId) { if (chatBody.dataset.conversationId) {
chatHistoryUrl += `&conversation_id=${chatBody.dataset.conversationId}`; chatHistoryUrl += `&conversation_id=${chatBody.dataset.conversationId}`;
setupWebSocket(); initMessageState();
loadFileFiltersFromConversation(); loadFileFiltersFromConversation();
} }
@ -1305,7 +1195,7 @@ To get started, just start typing below. You can also type / to see a list of co
let chatBody = document.getElementById("chat-body"); let chatBody = document.getElementById("chat-body");
chatBody.dataset.conversationId = response.conversation_id; chatBody.dataset.conversationId = response.conversation_id;
loadFileFiltersFromConversation(); loadFileFiltersFromConversation();
setupWebSocket(); initMessageState();
chatBody.dataset.conversationTitle = response.slug || `New conversation 🌱`; chatBody.dataset.conversationTitle = response.slug || `New conversation 🌱`;
let agentMetadata = response.agent; let agentMetadata = response.agent;

View file

@ -62,10 +62,6 @@ class ThreadedGenerator:
self.queue.put(data) self.queue.put(data)
def close(self): def close(self):
if self.compiled_references and len(self.compiled_references) > 0:
self.queue.put(f"### compiled references:{json.dumps(self.compiled_references)}")
if self.online_results and len(self.online_results) > 0:
self.queue.put(f"### compiled references:{json.dumps(self.online_results)}")
self.queue.put(StopIteration) self.queue.put(StopIteration)

View file

@ -11,6 +11,7 @@ from bs4 import BeautifulSoup
from markdownify import markdownify from markdownify import markdownify
from khoj.routers.helpers import ( from khoj.routers.helpers import (
ChatEvent,
extract_relevant_info, extract_relevant_info,
generate_online_subqueries, generate_online_subqueries,
infer_webpage_urls, infer_webpage_urls,
@ -56,7 +57,8 @@ async def search_online(
query += " ".join(custom_filters) query += " ".join(custom_filters)
if not is_internet_connected(): if not is_internet_connected():
logger.warn("Cannot search online as not connected to internet") logger.warn("Cannot search online as not connected to internet")
return {} yield {}
return
# Breakdown the query into subqueries to get the correct answer # Breakdown the query into subqueries to get the correct answer
subqueries = await generate_online_subqueries(query, conversation_history, location) subqueries = await generate_online_subqueries(query, conversation_history, location)
@ -66,7 +68,8 @@ async def search_online(
logger.info(f"🌐 Searching the Internet for {list(subqueries)}") logger.info(f"🌐 Searching the Internet for {list(subqueries)}")
if send_status_func: if send_status_func:
subqueries_str = "\n- " + "\n- ".join(list(subqueries)) subqueries_str = "\n- " + "\n- ".join(list(subqueries))
await send_status_func(f"**🌐 Searching the Internet for**: {subqueries_str}") async for event in send_status_func(f"**🌐 Searching the Internet for**: {subqueries_str}"):
yield {ChatEvent.STATUS: event}
with timer(f"Internet searches for {list(subqueries)} took", logger): with timer(f"Internet searches for {list(subqueries)} took", logger):
search_func = search_with_google if SERPER_DEV_API_KEY else search_with_jina search_func = search_with_google if SERPER_DEV_API_KEY else search_with_jina
@ -89,7 +92,8 @@ async def search_online(
logger.info(f"🌐👀 Reading web pages at: {list(webpage_links)}") logger.info(f"🌐👀 Reading web pages at: {list(webpage_links)}")
if send_status_func: if send_status_func:
webpage_links_str = "\n- " + "\n- ".join(list(webpage_links)) webpage_links_str = "\n- " + "\n- ".join(list(webpage_links))
await send_status_func(f"**📖 Reading web pages**: {webpage_links_str}") async for event in send_status_func(f"**📖 Reading web pages**: {webpage_links_str}"):
yield {ChatEvent.STATUS: event}
tasks = [read_webpage_and_extract_content(subquery, link, content) for link, subquery, content in webpages] tasks = [read_webpage_and_extract_content(subquery, link, content) for link, subquery, content in webpages]
results = await asyncio.gather(*tasks) results = await asyncio.gather(*tasks)
@ -98,7 +102,7 @@ async def search_online(
if webpage_extract is not None: if webpage_extract is not None:
response_dict[subquery]["webpages"] = {"link": url, "snippet": webpage_extract} response_dict[subquery]["webpages"] = {"link": url, "snippet": webpage_extract}
return response_dict yield response_dict
async def search_with_google(query: str) -> Tuple[str, Dict[str, List[Dict]]]: async def search_with_google(query: str) -> Tuple[str, Dict[str, List[Dict]]]:
@ -127,13 +131,15 @@ async def read_webpages(
"Infer web pages to read from the query and extract relevant information from them" "Infer web pages to read from the query and extract relevant information from them"
logger.info(f"Inferring web pages to read") logger.info(f"Inferring web pages to read")
if send_status_func: if send_status_func:
await send_status_func(f"**🧐 Inferring web pages to read**") async for event in send_status_func(f"**🧐 Inferring web pages to read**"):
yield {ChatEvent.STATUS: event}
urls = await infer_webpage_urls(query, conversation_history, location) urls = await infer_webpage_urls(query, conversation_history, location)
logger.info(f"Reading web pages at: {urls}") logger.info(f"Reading web pages at: {urls}")
if send_status_func: if send_status_func:
webpage_links_str = "\n- " + "\n- ".join(list(urls)) webpage_links_str = "\n- " + "\n- ".join(list(urls))
await send_status_func(f"**📖 Reading web pages**: {webpage_links_str}") async for event in send_status_func(f"**📖 Reading web pages**: {webpage_links_str}"):
yield {ChatEvent.STATUS: event}
tasks = [read_webpage_and_extract_content(query, url) for url in urls] tasks = [read_webpage_and_extract_content(query, url) for url in urls]
results = await asyncio.gather(*tasks) results = await asyncio.gather(*tasks)
@ -141,7 +147,7 @@ async def read_webpages(
response[query]["webpages"] = [ response[query]["webpages"] = [
{"query": q, "link": url, "snippet": web_extract} for q, web_extract, url in results if web_extract is not None {"query": q, "link": url, "snippet": web_extract} for q, web_extract, url in results if web_extract is not None
] ]
return response yield response
async def read_webpage_and_extract_content( async def read_webpage_and_extract_content(

View file

@ -6,7 +6,6 @@ import os
import threading import threading
import time import time
import uuid import uuid
from random import random
from typing import Any, Callable, List, Optional, Union from typing import Any, Callable, List, Optional, Union
import cron_descriptor import cron_descriptor
@ -37,6 +36,7 @@ from khoj.processor.conversation.openai.gpt import extract_questions
from khoj.processor.conversation.openai.whisper import transcribe_audio from khoj.processor.conversation.openai.whisper import transcribe_audio
from khoj.routers.helpers import ( from khoj.routers.helpers import (
ApiUserRateLimiter, ApiUserRateLimiter,
ChatEvent,
CommonQueryParams, CommonQueryParams,
ConversationCommandRateLimiter, ConversationCommandRateLimiter,
acreate_title_from_query, acreate_title_from_query,
@ -298,11 +298,13 @@ async def extract_references_and_questions(
not ConversationCommand.Notes in conversation_commands not ConversationCommand.Notes in conversation_commands
and not ConversationCommand.Default in conversation_commands and not ConversationCommand.Default in conversation_commands
): ):
return compiled_references, inferred_queries, q yield compiled_references, inferred_queries, q
return
if not await sync_to_async(EntryAdapters.user_has_entries)(user=user): if not await sync_to_async(EntryAdapters.user_has_entries)(user=user):
logger.debug("No documents in knowledge base. Use a Khoj client to sync and chat with your docs.") logger.debug("No documents in knowledge base. Use a Khoj client to sync and chat with your docs.")
return compiled_references, inferred_queries, q yield compiled_references, inferred_queries, q
return
# Extract filter terms from user message # Extract filter terms from user message
defiltered_query = q defiltered_query = q
@ -313,7 +315,8 @@ async def extract_references_and_questions(
if not conversation: if not conversation:
logger.error(f"Conversation with id {conversation_id} not found.") logger.error(f"Conversation with id {conversation_id} not found.")
return compiled_references, inferred_queries, defiltered_query yield compiled_references, inferred_queries, defiltered_query
return
filters_in_query += " ".join([f'file:"{filter}"' for filter in conversation.file_filters]) filters_in_query += " ".join([f'file:"{filter}"' for filter in conversation.file_filters])
using_offline_chat = False using_offline_chat = False
@ -373,7 +376,8 @@ async def extract_references_and_questions(
logger.info(f"🔍 Searching knowledge base with queries: {inferred_queries}") logger.info(f"🔍 Searching knowledge base with queries: {inferred_queries}")
if send_status_func: if send_status_func:
inferred_queries_str = "\n- " + "\n- ".join(inferred_queries) inferred_queries_str = "\n- " + "\n- ".join(inferred_queries)
await send_status_func(f"**🔍 Searching Documents for:** {inferred_queries_str}") async for event in send_status_func(f"**🔍 Searching Documents for:** {inferred_queries_str}"):
yield {ChatEvent.STATUS: event}
for query in inferred_queries: for query in inferred_queries:
n_items = min(n, 3) if using_offline_chat else n n_items = min(n, 3) if using_offline_chat else n
search_results.extend( search_results.extend(
@ -392,7 +396,7 @@ async def extract_references_and_questions(
{"compiled": item.additional["compiled"], "file": item.additional["file"]} for item in search_results {"compiled": item.additional["compiled"], "file": item.additional["file"]} for item in search_results
] ]
return compiled_references, inferred_queries, defiltered_query yield compiled_references, inferred_queries, defiltered_query
@api.get("/health", response_class=Response) @api.get("/health", response_class=Response)

View file

@ -1,17 +1,17 @@
import asyncio
import json import json
import logging import logging
import math import time
from datetime import datetime from datetime import datetime
from functools import partial
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from urllib.parse import unquote from urllib.parse import unquote
from asgiref.sync import sync_to_async from asgiref.sync import sync_to_async
from fastapi import APIRouter, Depends, HTTPException, Request, WebSocket from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.requests import Request from fastapi.requests import Request
from fastapi.responses import Response, StreamingResponse from fastapi.responses import Response, StreamingResponse
from starlette.authentication import requires from starlette.authentication import requires
from starlette.websockets import WebSocketDisconnect
from websockets import ConnectionClosedOK
from khoj.app.settings import ALLOWED_HOSTS from khoj.app.settings import ALLOWED_HOSTS
from khoj.database.adapters import ( from khoj.database.adapters import (
@ -23,19 +23,15 @@ from khoj.database.adapters import (
aget_user_name, aget_user_name,
) )
from khoj.database.models import KhojUser from khoj.database.models import KhojUser
from khoj.processor.conversation.prompts import ( from khoj.processor.conversation.prompts import help_message, no_entries_found
help_message,
no_entries_found,
no_notes_found,
)
from khoj.processor.conversation.utils import save_to_conversation_log from khoj.processor.conversation.utils import save_to_conversation_log
from khoj.processor.speech.text_to_speech import generate_text_to_speech from khoj.processor.speech.text_to_speech import generate_text_to_speech
from khoj.processor.tools.online_search import read_webpages, search_online from khoj.processor.tools.online_search import read_webpages, search_online
from khoj.routers.api import extract_references_and_questions from khoj.routers.api import extract_references_and_questions
from khoj.routers.helpers import ( from khoj.routers.helpers import (
ApiUserRateLimiter, ApiUserRateLimiter,
ChatEvent,
CommonQueryParams, CommonQueryParams,
CommonQueryParamsClass,
ConversationCommandRateLimiter, ConversationCommandRateLimiter,
agenerate_chat_response, agenerate_chat_response,
aget_relevant_information_sources, aget_relevant_information_sources,
@ -526,141 +522,142 @@ async def set_conversation_title(
) )
@api_chat.websocket("/ws") @api_chat.get("")
async def websocket_endpoint( async def chat(
websocket: WebSocket, request: Request,
conversation_id: int, common: CommonQueryParams,
q: str,
n: int = 7,
d: float = 0.18,
stream: Optional[bool] = False,
title: Optional[str] = None,
conversation_id: Optional[int] = None,
city: Optional[str] = None, city: Optional[str] = None,
region: Optional[str] = None, region: Optional[str] = None,
country: Optional[str] = None, country: Optional[str] = None,
timezone: Optional[str] = None, timezone: Optional[str] = None,
rate_limiter_per_minute=Depends(
ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute")
),
rate_limiter_per_day=Depends(
ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
),
): ):
async def event_generator(q: str):
start_time = time.perf_counter()
ttft = None
chat_metadata: dict = {}
connection_alive = True connection_alive = True
user: KhojUser = request.user.object
event_delimiter = "␃🔚␗"
q = unquote(q)
async def send_status_update(message: str): async def send_event(event_type: ChatEvent, data: str | dict):
nonlocal connection_alive nonlocal connection_alive, ttft
if not connection_alive: if not connection_alive or await request.is_disconnected():
return
status_packet = {
"type": "status",
"message": message,
"content-type": "application/json",
}
try:
await websocket.send_text(json.dumps(status_packet))
except ConnectionClosedOK:
connection_alive = False connection_alive = False
logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread") logger.warn(f"User {user} disconnected from {common.client} client")
async def send_complete_llm_response(llm_response: str):
nonlocal connection_alive
if not connection_alive:
return return
try: try:
await websocket.send_text("start_llm_response") if event_type == ChatEvent.END_LLM_RESPONSE:
await websocket.send_text(llm_response) collect_telemetry()
await websocket.send_text("end_llm_response") if event_type == ChatEvent.START_LLM_RESPONSE:
except ConnectionClosedOK: ttft = time.perf_counter() - start_time
if event_type == ChatEvent.MESSAGE:
yield data
elif event_type == ChatEvent.REFERENCES or stream:
yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False)
except asyncio.CancelledError as e:
connection_alive = False connection_alive = False
logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread") logger.warn(f"User {user} disconnected from {common.client} client: {e}")
async def send_message(message: str):
nonlocal connection_alive
if not connection_alive:
return return
try: except Exception as e:
await websocket.send_text(message)
except ConnectionClosedOK:
connection_alive = False connection_alive = False
logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread") logger.error(f"Failed to stream chat API response to {user} on {common.client}: {e}", exc_info=True)
async def send_rate_limit_message(message: str):
nonlocal connection_alive
if not connection_alive:
return return
finally:
if stream:
yield event_delimiter
status_packet = { async def send_llm_response(response: str):
"type": "rate_limit", async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
"message": message, yield result
"content-type": "application/json", async for result in send_event(ChatEvent.MESSAGE, response):
} yield result
try: async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
await websocket.send_text(json.dumps(status_packet)) yield result
except ConnectionClosedOK:
connection_alive = False
logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")
user: KhojUser = websocket.user.object def collect_telemetry():
conversation = await ConversationAdapters.aget_conversation_by_user( # Gather chat response telemetry
user, client_application=websocket.user.client_app, conversation_id=conversation_id nonlocal chat_metadata
latency = time.perf_counter() - start_time
cmd_set = set([cmd.value for cmd in conversation_commands])
chat_metadata = chat_metadata or {}
chat_metadata["conversation_command"] = cmd_set
chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None
chat_metadata["latency"] = f"{latency:.3f}"
chat_metadata["ttft_latency"] = f"{ttft:.3f}"
logger.info(f"Chat response time to first token: {ttft:.3f} seconds")
logger.info(f"Chat response total time: {latency:.3f} seconds")
update_telemetry_state(
request=request,
telemetry_type="api",
api="chat",
client=request.user.client_app,
user_agent=request.headers.get("user-agent"),
host=request.headers.get("host"),
metadata=chat_metadata,
) )
hourly_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute") conversation = await ConversationAdapters.aget_conversation_by_user(
user, client_application=request.user.client_app, conversation_id=conversation_id, title=title
daily_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day") )
if not conversation:
async for result in send_llm_response(f"Conversation {conversation_id} not found"):
yield result
return
await is_ready_to_chat(user) await is_ready_to_chat(user)
user_name = await aget_user_name(user) user_name = await aget_user_name(user)
location = None location = None
if city or region or country: if city or region or country:
location = LocationData(city=city, region=region, country=country) location = LocationData(city=city, region=region, country=country)
await websocket.accept()
while connection_alive:
try:
if conversation:
await sync_to_async(conversation.refresh_from_db)(fields=["conversation_log"])
q = await websocket.receive_text()
# Refresh these because the connection to the database might have been closed
await conversation.arefresh_from_db()
except WebSocketDisconnect:
logger.debug(f"User {user} disconnected web socket")
break
try:
await sync_to_async(hourly_limiter)(websocket)
await sync_to_async(daily_limiter)(websocket)
except HTTPException as e:
await send_rate_limit_message(e.detail)
break
if is_query_empty(q): if is_query_empty(q):
await send_message("start_llm_response") async for result in send_llm_response("Please ask your query to get started."):
await send_message( yield result
"It seems like your query is incomplete. Could you please provide more details or specify what you need help with?" return
)
await send_message("end_llm_response")
continue
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
conversation_commands = [get_conversation_command(query=q, any_references=True)] conversation_commands = [get_conversation_command(query=q, any_references=True)]
await send_status_update(f"**👀 Understanding Query**: {q}") async for result in send_event(ChatEvent.STATUS, f"**👀 Understanding Query**: {q}"):
yield result
meta_log = conversation.conversation_log meta_log = conversation.conversation_log
is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask] is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask]
used_slash_summarize = conversation_commands == [ConversationCommand.Summarize]
if conversation_commands == [ConversationCommand.Default] or is_automated_task: if conversation_commands == [ConversationCommand.Default] or is_automated_task:
conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task) conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task)
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands]) conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
await send_status_update(f"**🗃️ Chose Data Sources to Search:** {conversation_commands_str}") async for result in send_event(
ChatEvent.STATUS, f"**🗃️ Chose Data Sources to Search:** {conversation_commands_str}"
):
yield result
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task) mode = await aget_relevant_output_modes(q, meta_log, is_automated_task)
await send_status_update(f"**🧑🏾‍💻 Decided Response Mode:** {mode.value}") async for result in send_event(ChatEvent.STATUS, f"**🧑🏾‍💻 Decided Response Mode:** {mode.value}"):
yield result
if mode not in conversation_commands: if mode not in conversation_commands:
conversation_commands.append(mode) conversation_commands.append(mode)
for cmd in conversation_commands: for cmd in conversation_commands:
await conversation_command_rate_limiter.update_and_check_if_valid(websocket, cmd) await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
q = q.replace(f"/{cmd.value}", "").strip() q = q.replace(f"/{cmd.value}", "").strip()
used_slash_summarize = conversation_commands == [ConversationCommand.Summarize]
file_filters = conversation.file_filters if conversation else [] file_filters = conversation.file_filters if conversation else []
# Skip trying to summarize if # Skip trying to summarize if
if ( if (
@ -676,28 +673,37 @@ async def websocket_endpoint(
response_log = "" response_log = ""
if len(file_filters) == 0: if len(file_filters) == 0:
response_log = "No files selected for summarization. Please add files using the section on the left." response_log = "No files selected for summarization. Please add files using the section on the left."
await send_complete_llm_response(response_log) async for result in send_llm_response(response_log):
yield result
elif len(file_filters) > 1: elif len(file_filters) > 1:
response_log = "Only one file can be selected for summarization." response_log = "Only one file can be selected for summarization."
await send_complete_llm_response(response_log) async for result in send_llm_response(response_log):
yield result
else: else:
try: try:
file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0]) file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
if len(file_object) == 0: if len(file_object) == 0:
response_log = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again." response_log = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again."
await send_complete_llm_response(response_log) async for result in send_llm_response(response_log):
continue yield result
return
contextual_data = " ".join([file.raw_text for file in file_object]) contextual_data = " ".join([file.raw_text for file in file_object])
if not q: if not q:
q = "Create a general summary of the file" q = "Create a general summary of the file"
await send_status_update(f"**🧑🏾‍💻 Constructing Summary Using:** {file_object[0].file_name}") async for result in send_event(
ChatEvent.STATUS, f"**🧑🏾‍💻 Constructing Summary Using:** {file_object[0].file_name}"
):
yield result
response = await extract_relevant_summary(q, contextual_data) response = await extract_relevant_summary(q, contextual_data)
response_log = str(response) response_log = str(response)
await send_complete_llm_response(response_log) async for result in send_llm_response(response_log):
yield result
except Exception as e: except Exception as e:
response_log = "Error summarizing file." response_log = "Error summarizing file."
logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True) logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)
await send_complete_llm_response(response_log) async for result in send_llm_response(response_log):
yield result
await sync_to_async(save_to_conversation_log)( await sync_to_async(save_to_conversation_log)(
q, q,
response_log, response_log,
@ -705,16 +711,10 @@ async def websocket_endpoint(
meta_log, meta_log,
user_message_time, user_message_time,
intent_type="summarize", intent_type="summarize",
client_application=websocket.user.client_app, client_application=request.user.client_app,
conversation_id=conversation_id, conversation_id=conversation_id,
) )
update_telemetry_state( return
request=websocket,
telemetry_type="api",
api="chat",
metadata={"conversation_command": conversation_commands[0].value},
)
continue
custom_filters = [] custom_filters = []
if conversation_commands == [ConversationCommand.Help]: if conversation_commands == [ConversationCommand.Help]:
@ -724,8 +724,9 @@ async def websocket_endpoint(
conversation_config = await ConversationAdapters.aget_default_conversation_config() conversation_config = await ConversationAdapters.aget_default_conversation_config()
model_type = conversation_config.model_type model_type = conversation_config.model_type
formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device()) formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
await send_complete_llm_response(formatted_help) async for result in send_llm_response(formatted_help):
continue yield result
return
# Adding specification to search online specifically on khoj.dev pages. # Adding specification to search online specifically on khoj.dev pages.
custom_filters.append("site:khoj.dev") custom_filters.append("site:khoj.dev")
conversation_commands.append(ConversationCommand.Online) conversation_commands.append(ConversationCommand.Online)
@ -733,14 +734,14 @@ async def websocket_endpoint(
if ConversationCommand.Automation in conversation_commands: if ConversationCommand.Automation in conversation_commands:
try: try:
automation, crontime, query_to_run, subject = await create_automation( automation, crontime, query_to_run, subject = await create_automation(
q, timezone, user, websocket.url, meta_log q, timezone, user, request.url, meta_log
) )
except Exception as e: except Exception as e:
logger.error(f"Error scheduling task {q} for {user.email}: {e}") logger.error(f"Error scheduling task {q} for {user.email}: {e}")
await send_complete_llm_response( error_message = f"Unable to create automation. Ensure the automation doesn't already exist."
f"Unable to create automation. Ensure the automation doesn't already exist." async for result in send_llm_response(error_message):
) yield result
continue return
llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject) llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject)
await sync_to_async(save_to_conversation_log)( await sync_to_async(save_to_conversation_log)(
@ -750,57 +751,78 @@ async def websocket_endpoint(
meta_log, meta_log,
user_message_time, user_message_time,
intent_type="automation", intent_type="automation",
client_application=websocket.user.client_app, client_application=request.user.client_app,
conversation_id=conversation_id, conversation_id=conversation_id,
inferred_queries=[query_to_run], inferred_queries=[query_to_run],
automation_id=automation.id, automation_id=automation.id,
) )
common = CommonQueryParamsClass( async for result in send_llm_response(llm_response):
client=websocket.user.client_app, yield result
user_agent=websocket.headers.get("user-agent"), return
host=websocket.headers.get("host"),
)
update_telemetry_state(
request=websocket,
telemetry_type="api",
api="chat",
**common.__dict__,
)
await send_complete_llm_response(llm_response)
continue
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions( # Gather Context
websocket, meta_log, q, 7, 0.18, conversation_id, conversation_commands, location, send_status_update ## Extract Document References
) compiled_references, inferred_queries, defiltered_query = [], [], None
async for result in extract_references_and_questions(
request,
meta_log,
q,
(n or 7),
(d or 0.18),
conversation_id,
conversation_commands,
location,
partial(send_event, ChatEvent.STATUS),
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
compiled_references.extend(result[0])
inferred_queries.extend(result[1])
defiltered_query = result[2]
if compiled_references: if not is_none_or_empty(compiled_references):
headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references])) headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references]))
await send_status_update(f"**📜 Found Relevant Notes**: {headings}") async for result in send_event(ChatEvent.STATUS, f"**📜 Found Relevant Notes**: {headings}"):
yield result
online_results: Dict = dict() online_results: Dict = dict()
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user): if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
await send_complete_llm_response(f"{no_entries_found.format()}") async for result in send_llm_response(f"{no_entries_found.format()}"):
continue yield result
return
if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references): if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
conversation_commands.remove(ConversationCommand.Notes) conversation_commands.remove(ConversationCommand.Notes)
## Gather Online References
if ConversationCommand.Online in conversation_commands: if ConversationCommand.Online in conversation_commands:
try: try:
online_results = await search_online( async for result in search_online(
defiltered_query, meta_log, location, send_status_update, custom_filters defiltered_query, meta_log, location, partial(send_event, ChatEvent.STATUS), custom_filters
) ):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
online_results = result
except ValueError as e: except ValueError as e:
logger.warning(f"Error searching online: {e}. Attempting to respond without online results") error_message = f"Error searching online: {e}. Attempting to respond without online results"
await send_complete_llm_response( logger.warning(error_message)
f"Error searching online: {e}. Attempting to respond without online results" async for result in send_llm_response(error_message):
) yield result
continue return
## Gather Webpage References
if ConversationCommand.Webpage in conversation_commands: if ConversationCommand.Webpage in conversation_commands:
try: try:
direct_web_pages = await read_webpages(defiltered_query, meta_log, location, send_status_update) async for result in read_webpages(
defiltered_query, meta_log, location, partial(send_event, ChatEvent.STATUS)
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
direct_web_pages = result
webpages = [] webpages = []
for query in direct_web_pages: for query in direct_web_pages:
if online_results.get(query): if online_results.get(query):
@ -810,304 +832,52 @@ async def websocket_endpoint(
for webpage in direct_web_pages[query]["webpages"]: for webpage in direct_web_pages[query]["webpages"]:
webpages.append(webpage["link"]) webpages.append(webpage["link"])
async for result in send_event(ChatEvent.STATUS, f"**📚 Read web pages**: {webpages}"):
await send_status_update(f"**📚 Read web pages**: {webpages}") yield result
except ValueError as e: except ValueError as e:
logger.warning( logger.warning(
f"Error directly reading webpages: {e}. Attempting to respond without online results", exc_info=True f"Error directly reading webpages: {e}. Attempting to respond without online results",
exc_info=True,
) )
## Send Gathered References
async for result in send_event(
ChatEvent.REFERENCES,
{
"inferredQueries": inferred_queries,
"context": compiled_references,
"onlineContext": online_results,
},
):
yield result
# Generate Output
## Generate Image Output
if ConversationCommand.Image in conversation_commands: if ConversationCommand.Image in conversation_commands:
update_telemetry_state( async for result in text_to_image(
request=websocket,
telemetry_type="api",
api="chat",
metadata={"conversation_command": conversation_commands[0].value},
)
image, status_code, improved_image_prompt, intent_type = await text_to_image(
q, q,
user, user,
meta_log, meta_log,
location_data=location, location_data=location,
references=compiled_references, references=compiled_references,
online_results=online_results, online_results=online_results,
send_status_func=send_status_update, send_status_func=partial(send_event, ChatEvent.STATUS),
) ):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
image, status_code, improved_image_prompt, intent_type = result
if image is None or status_code != 200: if image is None or status_code != 200:
content_obj = { content_obj = {
"image": image, "content-type": "application/json",
"intentType": intent_type, "intentType": intent_type,
"detail": improved_image_prompt, "detail": improved_image_prompt,
"content-type": "application/json", "image": image,
} }
await send_complete_llm_response(json.dumps(content_obj)) async for result in send_llm_response(json.dumps(content_obj)):
continue yield result
return
await sync_to_async(save_to_conversation_log)(
q,
image,
user,
meta_log,
user_message_time,
intent_type=intent_type,
inferred_queries=[improved_image_prompt],
client_application=websocket.user.client_app,
conversation_id=conversation_id,
compiled_references=compiled_references,
online_results=online_results,
)
content_obj = {"image": image, "intentType": intent_type, "inferredQueries": [improved_image_prompt], "context": compiled_references, "content-type": "application/json", "online_results": online_results} # type: ignore
await send_complete_llm_response(json.dumps(content_obj))
continue
await send_status_update(f"**💭 Generating a well-informed response**")
llm_response, chat_metadata = await agenerate_chat_response(
defiltered_query,
meta_log,
conversation,
compiled_references,
online_results,
inferred_queries,
conversation_commands,
user,
websocket.user.client_app,
conversation_id,
location,
user_name,
)
chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None
update_telemetry_state(
request=websocket,
telemetry_type="api",
api="chat",
metadata=chat_metadata,
)
iterator = AsyncIteratorWrapper(llm_response)
await send_message("start_llm_response")
async for item in iterator:
if item is None:
break
if connection_alive:
try:
await send_message(f"{item}")
except ConnectionClosedOK:
connection_alive = False
logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")
await send_message("end_llm_response")
@api_chat.get("", response_class=Response)
@requires(["authenticated"])
async def chat(
request: Request,
common: CommonQueryParams,
q: str,
n: Optional[int] = 5,
d: Optional[float] = 0.22,
stream: Optional[bool] = False,
title: Optional[str] = None,
conversation_id: Optional[int] = None,
city: Optional[str] = None,
region: Optional[str] = None,
country: Optional[str] = None,
timezone: Optional[str] = None,
rate_limiter_per_minute=Depends(
ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute")
),
rate_limiter_per_day=Depends(
ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
),
) -> Response:
user: KhojUser = request.user.object
q = unquote(q)
if is_query_empty(q):
return Response(
content="It seems like your query is incomplete. Could you please provide more details or specify what you need help with?",
media_type="text/plain",
status_code=400,
)
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
logger.info(f"Chat request by {user.username}: {q}")
await is_ready_to_chat(user)
conversation_commands = [get_conversation_command(query=q, any_references=True)]
_custom_filters = []
if conversation_commands == [ConversationCommand.Help]:
help_str = "/" + ConversationCommand.Help
if q.strip() == help_str:
conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
if conversation_config == None:
conversation_config = await ConversationAdapters.aget_default_conversation_config()
model_type = conversation_config.model_type
formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
return StreamingResponse(iter([formatted_help]), media_type="text/event-stream", status_code=200)
# Adding specification to search online specifically on khoj.dev pages.
_custom_filters.append("site:khoj.dev")
conversation_commands.append(ConversationCommand.Online)
conversation = await ConversationAdapters.aget_conversation_by_user(
user, request.user.client_app, conversation_id, title
)
conversation_id = conversation.id if conversation else None
if not conversation:
return Response(
content=f"No conversation found with requested id, title", media_type="text/plain", status_code=400
)
else:
meta_log = conversation.conversation_log
if ConversationCommand.Summarize in conversation_commands:
file_filters = conversation.file_filters
llm_response = ""
if len(file_filters) == 0:
llm_response = "No files selected for summarization. Please add files using the section on the left."
elif len(file_filters) > 1:
llm_response = "Only one file can be selected for summarization."
else:
try:
file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
if len(file_object) == 0:
llm_response = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again."
return StreamingResponse(content=llm_response, media_type="text/event-stream", status_code=200)
contextual_data = " ".join([file.raw_text for file in file_object])
summarizeStr = "/" + ConversationCommand.Summarize
if q.strip() == summarizeStr:
q = "Create a general summary of the file"
response = await extract_relevant_summary(q, contextual_data)
llm_response = str(response)
except Exception as e:
logger.error(f"Error summarizing file for {user.email}: {e}")
llm_response = "Error summarizing file."
await sync_to_async(save_to_conversation_log)(
q,
llm_response,
user,
conversation.conversation_log,
user_message_time,
intent_type="summarize",
client_application=request.user.client_app,
conversation_id=conversation_id,
)
update_telemetry_state(
request=request,
telemetry_type="api",
api="chat",
metadata={"conversation_command": conversation_commands[0].value},
**common.__dict__,
)
return StreamingResponse(content=llm_response, media_type="text/event-stream", status_code=200)
is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask]
if conversation_commands == [ConversationCommand.Default] or is_automated_task:
conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task)
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task)
if mode not in conversation_commands:
conversation_commands.append(mode)
for cmd in conversation_commands:
await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
q = q.replace(f"/{cmd.value}", "").strip()
location = None
if city or region or country:
location = LocationData(city=city, region=region, country=country)
user_name = await aget_user_name(user)
if ConversationCommand.Automation in conversation_commands:
try:
automation, crontime, query_to_run, subject = await create_automation(
q, timezone, user, request.url, meta_log
)
except Exception as e:
logger.error(f"Error creating automation {q} for {user.email}: {e}", exc_info=True)
return Response(
content=f"Unable to create automation. Ensure the automation doesn't already exist.",
media_type="text/plain",
status_code=500,
)
llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject)
await sync_to_async(save_to_conversation_log)(
q,
llm_response,
user,
meta_log,
user_message_time,
intent_type="automation",
client_application=request.user.client_app,
conversation_id=conversation_id,
inferred_queries=[query_to_run],
automation_id=automation.id,
)
if stream:
return StreamingResponse(llm_response, media_type="text/event-stream", status_code=200)
else:
return Response(content=llm_response, media_type="text/plain", status_code=200)
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
request, meta_log, q, (n or 5), (d or math.inf), conversation_id, conversation_commands, location
)
online_results: Dict[str, Dict] = {}
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
no_entries_found_format = no_entries_found.format()
if stream:
return StreamingResponse(iter([no_entries_found_format]), media_type="text/event-stream", status_code=200)
else:
response_obj = {"response": no_entries_found_format}
return Response(content=json.dumps(response_obj), media_type="text/plain", status_code=200)
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references):
no_notes_found_format = no_notes_found.format()
if stream:
return StreamingResponse(iter([no_notes_found_format]), media_type="text/event-stream", status_code=200)
else:
response_obj = {"response": no_notes_found_format}
return Response(content=json.dumps(response_obj), media_type="text/plain", status_code=200)
if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
conversation_commands.remove(ConversationCommand.Notes)
if ConversationCommand.Online in conversation_commands:
try:
online_results = await search_online(defiltered_query, meta_log, location, custom_filters=_custom_filters)
except ValueError as e:
logger.warning(f"Error searching online: {e}. Attempting to respond without online results")
if ConversationCommand.Webpage in conversation_commands:
try:
online_results = await read_webpages(defiltered_query, meta_log, location)
except ValueError as e:
logger.warning(
f"Error directly reading webpages: {e}. Attempting to respond without online results", exc_info=True
)
if ConversationCommand.Image in conversation_commands:
update_telemetry_state(
request=request,
telemetry_type="api",
api="chat",
metadata={"conversation_command": conversation_commands[0].value},
**common.__dict__,
)
image, status_code, improved_image_prompt, intent_type = await text_to_image(
q, user, meta_log, location_data=location, references=compiled_references, online_results=online_results
)
if image is None:
content_obj = {"image": image, "intentType": intent_type, "detail": improved_image_prompt}
return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
await sync_to_async(save_to_conversation_log)( await sync_to_async(save_to_conversation_log)(
q, q,
@ -1118,14 +888,22 @@ async def chat(
intent_type=intent_type, intent_type=intent_type,
inferred_queries=[improved_image_prompt], inferred_queries=[improved_image_prompt],
client_application=request.user.client_app, client_application=request.user.client_app,
conversation_id=conversation.id, conversation_id=conversation_id,
compiled_references=compiled_references, compiled_references=compiled_references,
online_results=online_results, online_results=online_results,
) )
content_obj = {"image": image, "intentType": intent_type, "inferredQueries": [improved_image_prompt], "context": compiled_references, "online_results": online_results} # type: ignore content_obj = {
return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code) "intentType": intent_type,
"inferredQueries": [improved_image_prompt],
"image": image,
}
async for result in send_llm_response(json.dumps(content_obj)):
yield result
return
# Get the (streamed) chat response from the LLM of choice. ## Generate Text Output
async for result in send_event(ChatEvent.STATUS, f"**💭 Generating a well-informed response**"):
yield result
llm_response, chat_metadata = await agenerate_chat_response( llm_response, chat_metadata = await agenerate_chat_response(
defiltered_query, defiltered_query,
meta_log, meta_log,
@ -1136,45 +914,48 @@ async def chat(
conversation_commands, conversation_commands,
user, user,
request.user.client_app, request.user.client_app,
conversation.id, conversation_id,
location, location,
user_name, user_name,
) )
cmd_set = set([cmd.value for cmd in conversation_commands]) # Send Response
chat_metadata["conversation_command"] = cmd_set async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None yield result
update_telemetry_state(
request=request,
telemetry_type="api",
api="chat",
metadata=chat_metadata,
**common.__dict__,
)
if llm_response is None:
return Response(content=llm_response, media_type="text/plain", status_code=500)
if stream:
return StreamingResponse(llm_response, media_type="text/event-stream", status_code=200)
continue_stream = True
iterator = AsyncIteratorWrapper(llm_response) iterator = AsyncIteratorWrapper(llm_response)
# Get the full response from the generator if the stream is not requested.
aggregated_gpt_response = ""
async for item in iterator: async for item in iterator:
if item is None: if item is None:
break async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
aggregated_gpt_response += item yield result
logger.debug("Finished streaming response")
return
if not connection_alive or not continue_stream:
continue
try:
async for result in send_event(ChatEvent.MESSAGE, f"{item}"):
yield result
except Exception as e:
continue_stream = False
logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}")
actual_response = aggregated_gpt_response.split("### compiled references:")[0] ## Stream Text Response
if stream:
response_obj = { return StreamingResponse(event_generator(q), media_type="text/plain")
"response": actual_response, ## Non-Streaming Text Response
"inferredQueries": inferred_queries, else:
"context": compiled_references, # Get the full response from the generator if the stream is not requested.
"online_results": online_results, response_obj = {}
} actual_response = ""
iterator = event_generator(q)
async for item in iterator:
try:
item_json = json.loads(item)
if "type" in item_json and item_json["type"] == ChatEvent.REFERENCES.value:
response_obj = item_json["data"]
except:
actual_response += item
response_obj["response"] = actual_response
return Response(content=json.dumps(response_obj), media_type="application/json", status_code=200) return Response(content=json.dumps(response_obj), media_type="application/json", status_code=200)

View file

@ -8,6 +8,7 @@ import math
import re import re
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from enum import Enum
from functools import partial from functools import partial
from random import random from random import random
from typing import ( from typing import (
@ -753,7 +754,7 @@ async def text_to_image(
references: List[Dict[str, Any]], references: List[Dict[str, Any]],
online_results: Dict[str, Any], online_results: Dict[str, Any],
send_status_func: Optional[Callable] = None, send_status_func: Optional[Callable] = None,
) -> Tuple[Optional[str], int, Optional[str], str]: ):
status_code = 200 status_code = 200
image = None image = None
response = None response = None
@ -765,7 +766,8 @@ async def text_to_image(
# If the user has not configured a text to image model, return an unsupported on server error # If the user has not configured a text to image model, return an unsupported on server error
status_code = 501 status_code = 501
message = "Failed to generate image. Setup image generation on the server." message = "Failed to generate image. Setup image generation on the server."
return image_url or image, status_code, message, intent_type.value yield image_url or image, status_code, message, intent_type.value
return
text2image_model = text_to_image_config.model_name text2image_model = text_to_image_config.model_name
chat_history = "" chat_history = ""
@ -777,9 +779,9 @@ async def text_to_image(
chat_history += f"Q: Prompt: {chat['intent']['query']}\n" chat_history += f"Q: Prompt: {chat['intent']['query']}\n"
chat_history += f"A: Improved Prompt: {chat['intent']['inferred-queries'][0]}\n" chat_history += f"A: Improved Prompt: {chat['intent']['inferred-queries'][0]}\n"
with timer("Improve the original user query", logger):
if send_status_func: if send_status_func:
await send_status_func("**✍🏽 Enhancing the Painting Prompt**") async for event in send_status_func("**✍🏽 Enhancing the Painting Prompt**"):
yield {ChatEvent.STATUS: event}
improved_image_prompt = await generate_better_image_prompt( improved_image_prompt = await generate_better_image_prompt(
message, message,
chat_history, chat_history,
@ -790,7 +792,8 @@ async def text_to_image(
) )
if send_status_func: if send_status_func:
await send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}") async for event in send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}"):
yield {ChatEvent.STATUS: event}
if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI: if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
with timer("Generate image with OpenAI", logger): with timer("Generate image with OpenAI", logger):
@ -815,12 +818,14 @@ async def text_to_image(
logger.error(f"Image Generation blocked by OpenAI: {e}") logger.error(f"Image Generation blocked by OpenAI: {e}")
status_code = e.status_code # type: ignore status_code = e.status_code # type: ignore
message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore
return image_url or image, status_code, message, intent_type.value yield image_url or image, status_code, message, intent_type.value
return
else: else:
logger.error(f"Image Generation failed with {e}", exc_info=True) logger.error(f"Image Generation failed with {e}", exc_info=True)
message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore
status_code = e.status_code # type: ignore status_code = e.status_code # type: ignore
return image_url or image, status_code, message, intent_type.value yield image_url or image, status_code, message, intent_type.value
return
elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.STABILITYAI: elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.STABILITYAI:
with timer("Generate image with Stability AI", logger): with timer("Generate image with Stability AI", logger):
@ -842,7 +847,8 @@ async def text_to_image(
logger.error(f"Image Generation failed with {e}", exc_info=True) logger.error(f"Image Generation failed with {e}", exc_info=True)
message = f"Image generation failed with Stability AI error: {e}" message = f"Image generation failed with Stability AI error: {e}"
status_code = e.status_code # type: ignore status_code = e.status_code # type: ignore
return image_url or image, status_code, message, intent_type.value yield image_url or image, status_code, message, intent_type.value
return
with timer("Convert image to webp", logger): with timer("Convert image to webp", logger):
# Convert png to webp for faster loading # Convert png to webp for faster loading
@ -862,7 +868,7 @@ async def text_to_image(
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3 intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
image = base64.b64encode(webp_image_bytes).decode("utf-8") image = base64.b64encode(webp_image_bytes).decode("utf-8")
return image_url or image, status_code, improved_image_prompt, intent_type.value yield image_url or image, status_code, improved_image_prompt, intent_type.value
class ApiUserRateLimiter: class ApiUserRateLimiter:
@ -1184,3 +1190,11 @@ def construct_automation_created_message(automation: Job, crontime: str, query_t
Manage your automations [here](/automations). Manage your automations [here](/automations).
""".strip() """.strip()
class ChatEvent(Enum):
START_LLM_RESPONSE = "start_llm_response"
END_LLM_RESPONSE = "end_llm_response"
MESSAGE = "message"
REFERENCES = "references"
STATUS = "status"

View file

@ -22,7 +22,7 @@ magika = Magika()
def collect_files(search_type: Optional[SearchType] = SearchType.All, user=None) -> dict: def collect_files(search_type: Optional[SearchType] = SearchType.All, user=None) -> dict:
files = {} files: dict[str, dict] = {"docx": {}, "image": {}}
if search_type == SearchType.All or search_type == SearchType.Org: if search_type == SearchType.All or search_type == SearchType.Org:
org_config = LocalOrgConfig.objects.filter(user=user).first() org_config = LocalOrgConfig.objects.filter(user=user).first()

View file

@ -455,13 +455,13 @@ def test_user_no_data_returns_empty(client, sample_org_data, api_user3: KhojApiU
@pytest.mark.skipif(os.getenv("OPENAI_API_KEY") is None, reason="requires OPENAI_API_KEY") @pytest.mark.skipif(os.getenv("OPENAI_API_KEY") is None, reason="requires OPENAI_API_KEY")
@pytest.mark.django_db(transaction=True) @pytest.mark.django_db(transaction=True)
def test_chat_with_unauthenticated_user(chat_client_with_auth, api_user2: KhojApiUser): async def test_chat_with_unauthenticated_user(chat_client_with_auth, api_user2: KhojApiUser):
# Arrange # Arrange
headers = {"Authorization": f"Bearer {api_user2.token}"} headers = {"Authorization": f"Bearer {api_user2.token}"}
# Act # Act
auth_response = chat_client_with_auth.get(f'/api/chat?q="Hello!"&stream=true', headers=headers) auth_response = chat_client_with_auth.get(f'/api/chat?q="Hello!"', headers=headers)
no_auth_response = chat_client_with_auth.get(f'/api/chat?q="Hello!"&stream=true') no_auth_response = chat_client_with_auth.get(f'/api/chat?q="Hello!"')
# Assert # Assert
assert auth_response.status_code == 200 assert auth_response.status_code == 200

View file

@ -67,10 +67,8 @@ def test_chat_with_online_content(client_offline_chat):
# Act # Act
q = "/online give me the link to paul graham's essay how to do great work" q = "/online give me the link to paul graham's essay how to do great work"
encoded_q = quote(q, safe="") encoded_q = quote(q, safe="")
response = client_offline_chat.get(f"/api/chat?q={encoded_q}&stream=true") response = client_offline_chat.get(f"/api/chat?q={encoded_q}")
response_message = response.content.decode("utf-8") response_message = response.json()["response"]
response_message = response_message.split("### compiled references")[0]
# Assert # Assert
expected_responses = [ expected_responses = [
@ -91,10 +89,8 @@ def test_chat_with_online_webpage_content(client_offline_chat):
# Act # Act
q = "/online how many firefighters were involved in the great chicago fire and which year did it take place?" q = "/online how many firefighters were involved in the great chicago fire and which year did it take place?"
encoded_q = quote(q, safe="") encoded_q = quote(q, safe="")
response = client_offline_chat.get(f"/api/chat?q={encoded_q}&stream=true") response = client_offline_chat.get(f"/api/chat?q={encoded_q}")
response_message = response.content.decode("utf-8") response_message = response.json()["response"]
response_message = response_message.split("### compiled references")[0]
# Assert # Assert
expected_responses = ["185", "1871", "horse"] expected_responses = ["185", "1871", "horse"]

View file

@ -49,8 +49,8 @@ def create_conversation(message_list, user, agent=None):
@pytest.mark.django_db(transaction=True) @pytest.mark.django_db(transaction=True)
def test_chat_with_no_chat_history_or_retrieved_content(chat_client): def test_chat_with_no_chat_history_or_retrieved_content(chat_client):
# Act # Act
response = chat_client.get(f'/api/chat?q="Hello, my name is Testatron. Who are you?"&stream=true') response = chat_client.get(f'/api/chat?q="Hello, my name is Testatron. Who are you?"')
response_message = response.content.decode("utf-8") response_message = response.json()["response"]
# Assert # Assert
expected_responses = ["Khoj", "khoj"] expected_responses = ["Khoj", "khoj"]
@ -67,10 +67,8 @@ def test_chat_with_online_content(chat_client):
# Act # Act
q = "/online give me the link to paul graham's essay how to do great work" q = "/online give me the link to paul graham's essay how to do great work"
encoded_q = quote(q, safe="") encoded_q = quote(q, safe="")
response = chat_client.get(f"/api/chat?q={encoded_q}&stream=true") response = chat_client.get(f"/api/chat?q={encoded_q}")
response_message = response.content.decode("utf-8") response_message = response.json()["response"]
response_message = response_message.split("### compiled references")[0]
# Assert # Assert
expected_responses = [ expected_responses = [
@ -91,10 +89,8 @@ def test_chat_with_online_webpage_content(chat_client):
# Act # Act
q = "/online how many firefighters were involved in the great chicago fire and which year did it take place?" q = "/online how many firefighters were involved in the great chicago fire and which year did it take place?"
encoded_q = quote(q, safe="") encoded_q = quote(q, safe="")
response = chat_client.get(f"/api/chat?q={encoded_q}&stream=true") response = chat_client.get(f"/api/chat?q={encoded_q}")
response_message = response.content.decode("utf-8") response_message = response.json()["response"]
response_message = response_message.split("### compiled references")[0]
# Assert # Assert
expected_responses = ["185", "1871", "horse"] expected_responses = ["185", "1871", "horse"]
@ -144,7 +140,7 @@ def test_answer_from_currently_retrieved_content(chat_client, default_user2: Kho
# Act # Act
response = chat_client.get(f'/api/chat?q="Where was Xi Li born?"') response = chat_client.get(f'/api/chat?q="Where was Xi Li born?"')
response_message = response.content.decode("utf-8") response_message = response.json()["response"]
# Assert # Assert
assert response.status_code == 200 assert response.status_code == 200
@ -168,7 +164,7 @@ def test_answer_from_chat_history_and_previously_retrieved_content(chat_client_n
# Act # Act
response = chat_client_no_background.get(f'/api/chat?q="Where was I born?"') response = chat_client_no_background.get(f'/api/chat?q="Where was I born?"')
response_message = response.content.decode("utf-8") response_message = response.json()["response"]
# Assert # Assert
assert response.status_code == 200 assert response.status_code == 200
@ -191,7 +187,7 @@ def test_answer_from_chat_history_and_currently_retrieved_content(chat_client, d
# Act # Act
response = chat_client.get(f'/api/chat?q="Where was I born?"') response = chat_client.get(f'/api/chat?q="Where was I born?"')
response_message = response.content.decode("utf-8") response_message = response.json()["response"]
# Assert # Assert
assert response.status_code == 200 assert response.status_code == 200
@ -215,8 +211,8 @@ def test_no_answer_in_chat_history_or_retrieved_content(chat_client, default_use
create_conversation(message_list, default_user2) create_conversation(message_list, default_user2)
# Act # Act
response = chat_client.get(f'/api/chat?q="Where was I born?"&stream=true') response = chat_client.get(f'/api/chat?q="Where was I born?"')
response_message = response.content.decode("utf-8") response_message = response.json()["response"]
# Assert # Assert
expected_responses = [ expected_responses = [
@ -226,6 +222,7 @@ def test_no_answer_in_chat_history_or_retrieved_content(chat_client, default_use
"do not have", "do not have",
"don't have", "don't have",
"where were you born?", "where were you born?",
"where you were born?",
] ]
assert response.status_code == 200 assert response.status_code == 200
@ -280,8 +277,8 @@ def test_answer_not_known_using_notes_command(chat_client_no_background, default
create_conversation(message_list, default_user2) create_conversation(message_list, default_user2)
# Act # Act
response = chat_client_no_background.get(f"/api/chat?q={query}&stream=true") response = chat_client_no_background.get(f"/api/chat?q={query}")
response_message = response.content.decode("utf-8") response_message = response.json()["response"]
# Assert # Assert
assert response.status_code == 200 assert response.status_code == 200
@ -527,8 +524,8 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c
create_conversation(message_list, default_user2) create_conversation(message_list, default_user2)
# Act # Act
response = chat_client.get(f'/api/chat?q="Write a haiku about unit testing. Do not say anything else."&stream=true') response = chat_client.get(f'/api/chat?q="Write a haiku about unit testing. Do not say anything else.')
response_message = response.content.decode("utf-8").split("### compiled references")[0] response_message = response.json()["response"]
# Assert # Assert
expected_responses = ["test", "Test"] expected_responses = ["test", "Test"]
@ -544,9 +541,8 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c
@pytest.mark.chatquality @pytest.mark.chatquality
def test_ask_for_clarification_if_not_enough_context_in_question(chat_client_no_background): def test_ask_for_clarification_if_not_enough_context_in_question(chat_client_no_background):
# Act # Act
response = chat_client_no_background.get(f'/api/chat?q="What is the name of Namitas older son?"')
response = chat_client_no_background.get(f'/api/chat?q="What is the name of Namitas older son?"&stream=true') response_message = response.json()["response"].lower()
response_message = response.content.decode("utf-8").split("### compiled references")[0].lower()
# Assert # Assert
expected_responses = [ expected_responses = [
@ -658,8 +654,8 @@ def test_answer_in_chat_history_by_conversation_id_with_agent(
def test_answer_requires_multiple_independent_searches(chat_client): def test_answer_requires_multiple_independent_searches(chat_client):
"Chat director should be able to answer by doing multiple independent searches for required information" "Chat director should be able to answer by doing multiple independent searches for required information"
# Act # Act
response = chat_client.get(f'/api/chat?q="Is Xi older than Namita? Just the older persons full name"&stream=true') response = chat_client.get(f'/api/chat?q="Is Xi older than Namita? Just the older persons full name"')
response_message = response.content.decode("utf-8").split("### compiled references")[0].lower() response_message = response.json()["response"].lower()
# Assert # Assert
expected_responses = ["he is older than namita", "xi is older than namita", "xi li is older than namita"] expected_responses = ["he is older than namita", "xi is older than namita", "xi li is older than namita"]
@ -683,8 +679,8 @@ def test_answer_using_file_filter(chat_client):
'Is Xi older than Namita? Just say the older persons full name. file:"Namita.markdown" file:"Xi Li.markdown"' 'Is Xi older than Namita? Just say the older persons full name. file:"Namita.markdown" file:"Xi Li.markdown"'
) )
response = chat_client.get(f"/api/chat?q={query}&stream=true") response = chat_client.get(f"/api/chat?q={query}")
response_message = response.content.decode("utf-8").split("### compiled references")[0].lower() response_message = response.json()["response"].lower()
# Assert # Assert
expected_responses = ["he is older than namita", "xi is older than namita", "xi li is older than namita"] expected_responses = ["he is older than namita", "xi is older than namita", "xi li is older than namita"]