mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 23:48:56 +01:00
Merge pull request #858 from khoj-ai/use-sse-instead-of-websocket
Use Single HTTP API for Robust, Generalizable Chat Streaming
This commit is contained in:
commit
377f7668c5
17 changed files with 996 additions and 1261 deletions
|
@ -61,6 +61,14 @@
|
||||||
let city = null;
|
let city = null;
|
||||||
let countryName = null;
|
let countryName = null;
|
||||||
let timezone = null;
|
let timezone = null;
|
||||||
|
let chatMessageState = {
|
||||||
|
newResponseTextEl: null,
|
||||||
|
newResponseEl: null,
|
||||||
|
loadingEllipsis: null,
|
||||||
|
references: {},
|
||||||
|
rawResponse: "",
|
||||||
|
isVoice: false,
|
||||||
|
}
|
||||||
|
|
||||||
fetch("https://ipapi.co/json")
|
fetch("https://ipapi.co/json")
|
||||||
.then(response => response.json())
|
.then(response => response.json())
|
||||||
|
@ -75,10 +83,9 @@
|
||||||
return;
|
return;
|
||||||
});
|
});
|
||||||
|
|
||||||
async function chat() {
|
async function chat(isVoice=false) {
|
||||||
// Extract required fields for search from form
|
// Extract chat message from chat input form
|
||||||
let query = document.getElementById("chat-input").value.trim();
|
let query = document.getElementById("chat-input").value.trim();
|
||||||
let resultsCount = localStorage.getItem("khojResultsCount") || 5;
|
|
||||||
console.log(`Query: ${query}`);
|
console.log(`Query: ${query}`);
|
||||||
|
|
||||||
// Short circuit on empty query
|
// Short circuit on empty query
|
||||||
|
@ -106,9 +113,6 @@
|
||||||
await refreshChatSessionsPanel();
|
await refreshChatSessionsPanel();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate backend API URL to execute query
|
|
||||||
let chatApi = `${hostURL}/api/chat?q=${encodeURIComponent(query)}&n=${resultsCount}&client=web&stream=true&conversation_id=${conversationID}®ion=${region}&city=${city}&country=${countryName}&timezone=${timezone}`;
|
|
||||||
|
|
||||||
let newResponseEl = document.createElement("div");
|
let newResponseEl = document.createElement("div");
|
||||||
newResponseEl.classList.add("chat-message", "khoj");
|
newResponseEl.classList.add("chat-message", "khoj");
|
||||||
newResponseEl.attributes["data-meta"] = "🏮 Khoj at " + formatDate(new Date());
|
newResponseEl.attributes["data-meta"] = "🏮 Khoj at " + formatDate(new Date());
|
||||||
|
@ -119,25 +123,7 @@
|
||||||
newResponseEl.appendChild(newResponseTextEl);
|
newResponseEl.appendChild(newResponseTextEl);
|
||||||
|
|
||||||
// Temporary status message to indicate that Khoj is thinking
|
// Temporary status message to indicate that Khoj is thinking
|
||||||
let loadingEllipsis = document.createElement("div");
|
let loadingEllipsis = createLoadingEllipsis();
|
||||||
loadingEllipsis.classList.add("lds-ellipsis");
|
|
||||||
|
|
||||||
let firstEllipsis = document.createElement("div");
|
|
||||||
firstEllipsis.classList.add("lds-ellipsis-item");
|
|
||||||
|
|
||||||
let secondEllipsis = document.createElement("div");
|
|
||||||
secondEllipsis.classList.add("lds-ellipsis-item");
|
|
||||||
|
|
||||||
let thirdEllipsis = document.createElement("div");
|
|
||||||
thirdEllipsis.classList.add("lds-ellipsis-item");
|
|
||||||
|
|
||||||
let fourthEllipsis = document.createElement("div");
|
|
||||||
fourthEllipsis.classList.add("lds-ellipsis-item");
|
|
||||||
|
|
||||||
loadingEllipsis.appendChild(firstEllipsis);
|
|
||||||
loadingEllipsis.appendChild(secondEllipsis);
|
|
||||||
loadingEllipsis.appendChild(thirdEllipsis);
|
|
||||||
loadingEllipsis.appendChild(fourthEllipsis);
|
|
||||||
|
|
||||||
newResponseTextEl.appendChild(loadingEllipsis);
|
newResponseTextEl.appendChild(loadingEllipsis);
|
||||||
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||||
|
@ -148,107 +134,36 @@
|
||||||
let chatInput = document.getElementById("chat-input");
|
let chatInput = document.getElementById("chat-input");
|
||||||
chatInput.classList.remove("option-enabled");
|
chatInput.classList.remove("option-enabled");
|
||||||
|
|
||||||
|
// Setup chat message state
|
||||||
|
chatMessageState = {
|
||||||
|
newResponseTextEl,
|
||||||
|
newResponseEl,
|
||||||
|
loadingEllipsis,
|
||||||
|
references: {},
|
||||||
|
rawResponse: "",
|
||||||
|
rawQuery: query,
|
||||||
|
isVoice: isVoice,
|
||||||
|
}
|
||||||
|
|
||||||
// Call Khoj chat API
|
// Call Khoj chat API
|
||||||
let response = await fetch(chatApi, { headers });
|
let chatApi = `${hostURL}/api/chat?q=${encodeURIComponent(query)}&conversation_id=${conversationID}&stream=true&client=desktop`;
|
||||||
let rawResponse = "";
|
chatApi += (!!region && !!city && !!countryName && !!timezone)
|
||||||
let references = null;
|
? `®ion=${region}&city=${city}&country=${countryName}&timezone=${timezone}`
|
||||||
const contentType = response.headers.get("content-type");
|
: '';
|
||||||
|
|
||||||
|
const response = await fetch(chatApi, { headers });
|
||||||
|
|
||||||
if (contentType === "application/json") {
|
|
||||||
// Handle JSON response
|
|
||||||
try {
|
try {
|
||||||
const responseAsJson = await response.json();
|
if (!response.ok) throw new Error(response.statusText);
|
||||||
if (responseAsJson.image) {
|
if (!response.body) throw new Error("Response body is empty");
|
||||||
// If response has image field, response is a generated image.
|
// Stream and render chat response
|
||||||
if (responseAsJson.intentType === "text-to-image") {
|
await readChatStream(response);
|
||||||
rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`;
|
} catch (err) {
|
||||||
} else if (responseAsJson.intentType === "text-to-image2") {
|
console.error(`Khoj chat response failed with\n${err}`);
|
||||||
rawResponse += `![${query}](${responseAsJson.image})`;
|
if (chatMessageState.newResponseEl.getElementsByClassName("lds-ellipsis").length > 0 && chatMessageState.loadingEllipsis)
|
||||||
} else if (responseAsJson.intentType === "text-to-image-v3") {
|
chatMessageState.newResponseTextEl.removeChild(chatMessageState.loadingEllipsis);
|
||||||
rawResponse += `![${query}](data:image/webp;base64,${responseAsJson.image})`;
|
let errorMsg = "Sorry, unable to get response from Khoj backend ❤️🩹. Retry or contact developers for help at <a href=mailto:'team@khoj.dev'>team@khoj.dev</a> or <a href='https://discord.gg/BDgyabRM6e'>on Discord</a>";
|
||||||
}
|
newResponseTextEl.textContent = errorMsg;
|
||||||
const inferredQueries = responseAsJson.inferredQueries?.[0];
|
|
||||||
if (inferredQueries) {
|
|
||||||
rawResponse += `\n\n**Inferred Query**:\n\n${inferredQueries}`;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (responseAsJson.context) {
|
|
||||||
const rawReferenceAsJson = responseAsJson.context;
|
|
||||||
references = createReferenceSection(rawReferenceAsJson);
|
|
||||||
}
|
|
||||||
if (responseAsJson.detail) {
|
|
||||||
// If response has detail field, response is an error message.
|
|
||||||
rawResponse += responseAsJson.detail;
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
// If the chunk is not a JSON object, just display it as is
|
|
||||||
rawResponse += chunk;
|
|
||||||
} finally {
|
|
||||||
newResponseTextEl.innerHTML = "";
|
|
||||||
newResponseTextEl.appendChild(formatHTMLMessage(rawResponse));
|
|
||||||
|
|
||||||
if (references != null) {
|
|
||||||
newResponseTextEl.appendChild(references);
|
|
||||||
}
|
|
||||||
|
|
||||||
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
|
||||||
document.getElementById("chat-input").removeAttribute("disabled");
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Handle streamed response of type text/event-stream or text/plain
|
|
||||||
const reader = response.body.getReader();
|
|
||||||
const decoder = new TextDecoder();
|
|
||||||
let references = {};
|
|
||||||
|
|
||||||
readStream();
|
|
||||||
|
|
||||||
function readStream() {
|
|
||||||
reader.read().then(({ done, value }) => {
|
|
||||||
if (done) {
|
|
||||||
// Append any references after all the data has been streamed
|
|
||||||
if (references != {}) {
|
|
||||||
newResponseTextEl.appendChild(createReferenceSection(references));
|
|
||||||
}
|
|
||||||
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
|
||||||
document.getElementById("chat-input").removeAttribute("disabled");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Decode message chunk from stream
|
|
||||||
const chunk = decoder.decode(value, { stream: true });
|
|
||||||
|
|
||||||
if (chunk.includes("### compiled references:")) {
|
|
||||||
const additionalResponse = chunk.split("### compiled references:")[0];
|
|
||||||
rawResponse += additionalResponse;
|
|
||||||
newResponseTextEl.innerHTML = "";
|
|
||||||
newResponseTextEl.appendChild(formatHTMLMessage(rawResponse));
|
|
||||||
|
|
||||||
const rawReference = chunk.split("### compiled references:")[1];
|
|
||||||
const rawReferenceAsJson = JSON.parse(rawReference);
|
|
||||||
if (rawReferenceAsJson instanceof Array) {
|
|
||||||
references["notes"] = rawReferenceAsJson;
|
|
||||||
} else if (typeof rawReferenceAsJson === "object" && rawReferenceAsJson !== null) {
|
|
||||||
references["online"] = rawReferenceAsJson;
|
|
||||||
}
|
|
||||||
readStream();
|
|
||||||
} else {
|
|
||||||
// Display response from Khoj
|
|
||||||
if (newResponseTextEl.getElementsByClassName("lds-ellipsis").length > 0) {
|
|
||||||
newResponseTextEl.removeChild(loadingEllipsis);
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the chunk is not a JSON object, just display it as is
|
|
||||||
rawResponse += chunk;
|
|
||||||
newResponseTextEl.innerHTML = "";
|
|
||||||
newResponseTextEl.appendChild(formatHTMLMessage(rawResponse));
|
|
||||||
|
|
||||||
readStream();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Scroll to bottom of chat window as chat response is streamed
|
|
||||||
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -364,3 +364,194 @@ function createReferenceSection(references, createLinkerSection=false) {
|
||||||
|
|
||||||
return referencesDiv;
|
return referencesDiv;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function createLoadingEllipsis() {
|
||||||
|
let loadingEllipsis = document.createElement("div");
|
||||||
|
loadingEllipsis.classList.add("lds-ellipsis");
|
||||||
|
|
||||||
|
let firstEllipsis = document.createElement("div");
|
||||||
|
firstEllipsis.classList.add("lds-ellipsis-item");
|
||||||
|
|
||||||
|
let secondEllipsis = document.createElement("div");
|
||||||
|
secondEllipsis.classList.add("lds-ellipsis-item");
|
||||||
|
|
||||||
|
let thirdEllipsis = document.createElement("div");
|
||||||
|
thirdEllipsis.classList.add("lds-ellipsis-item");
|
||||||
|
|
||||||
|
let fourthEllipsis = document.createElement("div");
|
||||||
|
fourthEllipsis.classList.add("lds-ellipsis-item");
|
||||||
|
|
||||||
|
loadingEllipsis.appendChild(firstEllipsis);
|
||||||
|
loadingEllipsis.appendChild(secondEllipsis);
|
||||||
|
loadingEllipsis.appendChild(thirdEllipsis);
|
||||||
|
loadingEllipsis.appendChild(fourthEllipsis);
|
||||||
|
|
||||||
|
return loadingEllipsis;
|
||||||
|
}
|
||||||
|
|
||||||
|
function handleStreamResponse(newResponseElement, rawResponse, rawQuery, loadingEllipsis, replace=true) {
|
||||||
|
if (!newResponseElement) return;
|
||||||
|
// Remove loading ellipsis if it exists
|
||||||
|
if (newResponseElement.getElementsByClassName("lds-ellipsis").length > 0 && loadingEllipsis)
|
||||||
|
newResponseElement.removeChild(loadingEllipsis);
|
||||||
|
// Clear the response element if replace is true
|
||||||
|
if (replace) newResponseElement.innerHTML = "";
|
||||||
|
|
||||||
|
// Append response to the response element
|
||||||
|
newResponseElement.appendChild(formatHTMLMessage(rawResponse, false, replace, rawQuery));
|
||||||
|
|
||||||
|
// Append loading ellipsis if it exists
|
||||||
|
if (!replace && loadingEllipsis) newResponseElement.appendChild(loadingEllipsis);
|
||||||
|
// Scroll to bottom of chat view
|
||||||
|
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||||
|
}
|
||||||
|
|
||||||
|
function handleImageResponse(imageJson, rawResponse) {
|
||||||
|
if (imageJson.image) {
|
||||||
|
const inferredQuery = imageJson.inferredQueries?.[0] ?? "generated image";
|
||||||
|
|
||||||
|
// If response has image field, response is a generated image.
|
||||||
|
if (imageJson.intentType === "text-to-image") {
|
||||||
|
rawResponse += `![generated_image](data:image/png;base64,${imageJson.image})`;
|
||||||
|
} else if (imageJson.intentType === "text-to-image2") {
|
||||||
|
rawResponse += `![generated_image](${imageJson.image})`;
|
||||||
|
} else if (imageJson.intentType === "text-to-image-v3") {
|
||||||
|
rawResponse = `![](data:image/webp;base64,${imageJson.image})`;
|
||||||
|
}
|
||||||
|
if (inferredQuery) {
|
||||||
|
rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If response has detail field, response is an error message.
|
||||||
|
if (imageJson.detail) rawResponse += imageJson.detail;
|
||||||
|
|
||||||
|
return rawResponse;
|
||||||
|
}
|
||||||
|
|
||||||
|
function finalizeChatBodyResponse(references, newResponseElement) {
|
||||||
|
if (!!newResponseElement && references != null && Object.keys(references).length > 0) {
|
||||||
|
newResponseElement.appendChild(createReferenceSection(references));
|
||||||
|
}
|
||||||
|
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||||
|
document.getElementById("chat-input")?.removeAttribute("disabled");
|
||||||
|
}
|
||||||
|
|
||||||
|
function convertMessageChunkToJson(rawChunk) {
|
||||||
|
// Split the chunk into lines
|
||||||
|
if (rawChunk?.startsWith("{") && rawChunk?.endsWith("}")) {
|
||||||
|
try {
|
||||||
|
let jsonChunk = JSON.parse(rawChunk);
|
||||||
|
if (!jsonChunk.type)
|
||||||
|
jsonChunk = {type: 'message', data: jsonChunk};
|
||||||
|
return jsonChunk;
|
||||||
|
} catch (e) {
|
||||||
|
return {type: 'message', data: rawChunk};
|
||||||
|
}
|
||||||
|
} else if (rawChunk.length > 0) {
|
||||||
|
return {type: 'message', data: rawChunk};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function processMessageChunk(rawChunk) {
|
||||||
|
const chunk = convertMessageChunkToJson(rawChunk);
|
||||||
|
console.debug("Chunk:", chunk);
|
||||||
|
if (!chunk || !chunk.type) return;
|
||||||
|
if (chunk.type ==='status') {
|
||||||
|
console.log(`status: ${chunk.data}`);
|
||||||
|
const statusMessage = chunk.data;
|
||||||
|
handleStreamResponse(chatMessageState.newResponseTextEl, statusMessage, chatMessageState.rawQuery, chatMessageState.loadingEllipsis, false);
|
||||||
|
} else if (chunk.type === 'start_llm_response') {
|
||||||
|
console.log("Started streaming", new Date());
|
||||||
|
} else if (chunk.type === 'end_llm_response') {
|
||||||
|
console.log("Stopped streaming", new Date());
|
||||||
|
|
||||||
|
// Automatically respond with voice if the subscribed user has sent voice message
|
||||||
|
if (chatMessageState.isVoice && "{{ is_active }}" == "True")
|
||||||
|
textToSpeech(chatMessageState.rawResponse);
|
||||||
|
|
||||||
|
// Append any references after all the data has been streamed
|
||||||
|
finalizeChatBodyResponse(chatMessageState.references, chatMessageState.newResponseTextEl);
|
||||||
|
|
||||||
|
const liveQuery = chatMessageState.rawQuery;
|
||||||
|
// Reset variables
|
||||||
|
chatMessageState = {
|
||||||
|
newResponseTextEl: null,
|
||||||
|
newResponseEl: null,
|
||||||
|
loadingEllipsis: null,
|
||||||
|
references: {},
|
||||||
|
rawResponse: "",
|
||||||
|
rawQuery: liveQuery,
|
||||||
|
isVoice: false,
|
||||||
|
}
|
||||||
|
} else if (chunk.type === "references") {
|
||||||
|
chatMessageState.references = {"notes": chunk.data.context, "online": chunk.data.onlineContext};
|
||||||
|
} else if (chunk.type === 'message') {
|
||||||
|
const chunkData = chunk.data;
|
||||||
|
if (typeof chunkData === 'object' && chunkData !== null) {
|
||||||
|
// If chunkData is already a JSON object
|
||||||
|
handleJsonResponse(chunkData);
|
||||||
|
} else if (typeof chunkData === 'string' && chunkData.trim()?.startsWith("{") && chunkData.trim()?.endsWith("}")) {
|
||||||
|
// Try process chunk data as if it is a JSON object
|
||||||
|
try {
|
||||||
|
const jsonData = JSON.parse(chunkData.trim());
|
||||||
|
handleJsonResponse(jsonData);
|
||||||
|
} catch (e) {
|
||||||
|
chatMessageState.rawResponse += chunkData;
|
||||||
|
handleStreamResponse(chatMessageState.newResponseTextEl, chatMessageState.rawResponse, chatMessageState.rawQuery, chatMessageState.loadingEllipsis);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
chatMessageState.rawResponse += chunkData;
|
||||||
|
handleStreamResponse(chatMessageState.newResponseTextEl, chatMessageState.rawResponse, chatMessageState.rawQuery, chatMessageState.loadingEllipsis);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function handleJsonResponse(jsonData) {
|
||||||
|
if (jsonData.image || jsonData.detail) {
|
||||||
|
chatMessageState.rawResponse = handleImageResponse(jsonData, chatMessageState.rawResponse);
|
||||||
|
} else if (jsonData.response) {
|
||||||
|
chatMessageState.rawResponse = jsonData.response;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (chatMessageState.newResponseTextEl) {
|
||||||
|
chatMessageState.newResponseTextEl.innerHTML = "";
|
||||||
|
chatMessageState.newResponseTextEl.appendChild(formatHTMLMessage(chatMessageState.rawResponse));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function readChatStream(response) {
|
||||||
|
if (!response.body) return;
|
||||||
|
const reader = response.body.getReader();
|
||||||
|
const decoder = new TextDecoder();
|
||||||
|
const eventDelimiter = '␃🔚␗';
|
||||||
|
let buffer = '';
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
const { value, done } = await reader.read();
|
||||||
|
// If the stream is done
|
||||||
|
if (done) {
|
||||||
|
// Process the last chunk
|
||||||
|
processMessageChunk(buffer);
|
||||||
|
buffer = '';
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read chunk from stream and append it to the buffer
|
||||||
|
const chunk = decoder.decode(value, { stream: true });
|
||||||
|
console.debug("Raw Chunk:", chunk)
|
||||||
|
// Start buffering chunks until complete event is received
|
||||||
|
buffer += chunk;
|
||||||
|
|
||||||
|
// Once the buffer contains a complete event
|
||||||
|
let newEventIndex;
|
||||||
|
while ((newEventIndex = buffer.indexOf(eventDelimiter)) !== -1) {
|
||||||
|
// Extract the event from the buffer
|
||||||
|
const event = buffer.slice(0, newEventIndex);
|
||||||
|
buffer = buffer.slice(newEventIndex + eventDelimiter.length);
|
||||||
|
|
||||||
|
// Process the event
|
||||||
|
if (event) processMessageChunk(event);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -346,7 +346,7 @@
|
||||||
inp.focus();
|
inp.focus();
|
||||||
}
|
}
|
||||||
|
|
||||||
async function chat() {
|
async function chat(isVoice=false) {
|
||||||
//set chat body to empty
|
//set chat body to empty
|
||||||
let chatBody = document.getElementById("chat-body");
|
let chatBody = document.getElementById("chat-body");
|
||||||
chatBody.innerHTML = "";
|
chatBody.innerHTML = "";
|
||||||
|
@ -375,9 +375,6 @@
|
||||||
chat_body.dataset.conversationId = conversationID;
|
chat_body.dataset.conversationId = conversationID;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate backend API URL to execute query
|
|
||||||
let chatApi = `${hostURL}/api/chat?q=${encodeURIComponent(query)}&n=${resultsCount}&client=web&stream=true&conversation_id=${conversationID}®ion=${region}&city=${city}&country=${countryName}&timezone=${timezone}`;
|
|
||||||
|
|
||||||
let newResponseEl = document.createElement("div");
|
let newResponseEl = document.createElement("div");
|
||||||
newResponseEl.classList.add("chat-message", "khoj");
|
newResponseEl.classList.add("chat-message", "khoj");
|
||||||
newResponseEl.attributes["data-meta"] = "🏮 Khoj at " + formatDate(new Date());
|
newResponseEl.attributes["data-meta"] = "🏮 Khoj at " + formatDate(new Date());
|
||||||
|
@ -388,128 +385,41 @@
|
||||||
newResponseEl.appendChild(newResponseTextEl);
|
newResponseEl.appendChild(newResponseTextEl);
|
||||||
|
|
||||||
// Temporary status message to indicate that Khoj is thinking
|
// Temporary status message to indicate that Khoj is thinking
|
||||||
let loadingEllipsis = document.createElement("div");
|
let loadingEllipsis = createLoadingEllipsis();
|
||||||
loadingEllipsis.classList.add("lds-ellipsis");
|
|
||||||
|
|
||||||
let firstEllipsis = document.createElement("div");
|
|
||||||
firstEllipsis.classList.add("lds-ellipsis-item");
|
|
||||||
|
|
||||||
let secondEllipsis = document.createElement("div");
|
|
||||||
secondEllipsis.classList.add("lds-ellipsis-item");
|
|
||||||
|
|
||||||
let thirdEllipsis = document.createElement("div");
|
|
||||||
thirdEllipsis.classList.add("lds-ellipsis-item");
|
|
||||||
|
|
||||||
let fourthEllipsis = document.createElement("div");
|
|
||||||
fourthEllipsis.classList.add("lds-ellipsis-item");
|
|
||||||
|
|
||||||
loadingEllipsis.appendChild(firstEllipsis);
|
|
||||||
loadingEllipsis.appendChild(secondEllipsis);
|
|
||||||
loadingEllipsis.appendChild(thirdEllipsis);
|
|
||||||
loadingEllipsis.appendChild(fourthEllipsis);
|
|
||||||
|
|
||||||
newResponseTextEl.appendChild(loadingEllipsis);
|
|
||||||
document.body.scrollTop = document.getElementById("chat-body").scrollHeight;
|
document.body.scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||||
|
|
||||||
// Call Khoj chat API
|
|
||||||
let response = await fetch(chatApi, { headers });
|
|
||||||
let rawResponse = "";
|
|
||||||
let references = null;
|
|
||||||
const contentType = response.headers.get("content-type");
|
|
||||||
toggleLoading();
|
toggleLoading();
|
||||||
if (contentType === "application/json") {
|
|
||||||
// Handle JSON response
|
// Setup chat message state
|
||||||
|
chatMessageState = {
|
||||||
|
newResponseTextEl,
|
||||||
|
newResponseEl,
|
||||||
|
loadingEllipsis,
|
||||||
|
references: {},
|
||||||
|
rawResponse: "",
|
||||||
|
rawQuery: query,
|
||||||
|
isVoice: isVoice,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Construct API URL to execute chat query
|
||||||
|
let chatApi = `${hostURL}/api/chat?q=${encodeURIComponent(query)}&conversation_id=${conversationID}&stream=true&client=desktop`;
|
||||||
|
chatApi += (!!region && !!city && !!countryName && !!timezone)
|
||||||
|
? `®ion=${region}&city=${city}&country=${countryName}&timezone=${timezone}`
|
||||||
|
: '';
|
||||||
|
|
||||||
|
const response = await fetch(chatApi, { headers });
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const responseAsJson = await response.json();
|
if (!response.ok) throw new Error(response.statusText);
|
||||||
if (responseAsJson.image) {
|
if (!response.body) throw new Error("Response body is empty");
|
||||||
// If response has image field, response is a generated image.
|
// Stream and render chat response
|
||||||
if (responseAsJson.intentType === "text-to-image") {
|
await readChatStream(response);
|
||||||
rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`;
|
} catch (err) {
|
||||||
} else if (responseAsJson.intentType === "text-to-image2") {
|
console.error(`Khoj chat response failed with\n${err}`);
|
||||||
rawResponse += `![${query}](${responseAsJson.image})`;
|
if (chatMessageState.newResponseEl.getElementsByClassName("lds-ellipsis").length > 0 && chatMessageState.loadingEllipsis)
|
||||||
} else if (responseAsJson.intentType === "text-to-image-v3") {
|
chatMessageState.newResponseTextEl.removeChild(chatMessageState.loadingEllipsis);
|
||||||
rawResponse += `![${query}](data:image/webp;base64,${responseAsJson.image})`;
|
let errorMsg = "Sorry, unable to get response from Khoj backend ❤️🩹. Retry or contact developers for help at <a href=mailto:'team@khoj.dev'>team@khoj.dev</a> or <a href='https://discord.gg/BDgyabRM6e'>on Discord</a>";
|
||||||
}
|
newResponseTextEl.textContent = errorMsg;
|
||||||
const inferredQueries = responseAsJson.inferredQueries?.[0];
|
|
||||||
if (inferredQueries) {
|
|
||||||
rawResponse += `\n\n**Inferred Query**:\n\n${inferredQueries}`;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (responseAsJson.context) {
|
|
||||||
const rawReferenceAsJson = responseAsJson.context;
|
|
||||||
references = createReferenceSection(rawReferenceAsJson, createLinkerSection=true);
|
|
||||||
}
|
|
||||||
if (responseAsJson.detail) {
|
|
||||||
// If response has detail field, response is an error message.
|
|
||||||
rawResponse += responseAsJson.detail;
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
// If the chunk is not a JSON object, just display it as is
|
|
||||||
rawResponse += chunk;
|
|
||||||
} finally {
|
|
||||||
newResponseTextEl.innerHTML = "";
|
|
||||||
newResponseTextEl.appendChild(formatHTMLMessage(rawResponse));
|
|
||||||
|
|
||||||
if (references != null) {
|
|
||||||
newResponseTextEl.appendChild(references);
|
|
||||||
}
|
|
||||||
|
|
||||||
document.body.scrollTop = document.getElementById("chat-body").scrollHeight;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Handle streamed response of type text/event-stream or text/plain
|
|
||||||
const reader = response.body.getReader();
|
|
||||||
const decoder = new TextDecoder();
|
|
||||||
let references = {};
|
|
||||||
|
|
||||||
readStream();
|
|
||||||
|
|
||||||
function readStream() {
|
|
||||||
reader.read().then(({ done, value }) => {
|
|
||||||
if (done) {
|
|
||||||
// Append any references after all the data has been streamed
|
|
||||||
if (references != {}) {
|
|
||||||
newResponseTextEl.appendChild(createReferenceSection(references, createLinkerSection=true));
|
|
||||||
}
|
|
||||||
document.body.scrollTop = document.getElementById("chat-body").scrollHeight;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Decode message chunk from stream
|
|
||||||
const chunk = decoder.decode(value, { stream: true });
|
|
||||||
|
|
||||||
if (chunk.includes("### compiled references:")) {
|
|
||||||
const additionalResponse = chunk.split("### compiled references:")[0];
|
|
||||||
rawResponse += additionalResponse;
|
|
||||||
newResponseTextEl.innerHTML = "";
|
|
||||||
newResponseTextEl.appendChild(formatHTMLMessage(rawResponse));
|
|
||||||
|
|
||||||
const rawReference = chunk.split("### compiled references:")[1];
|
|
||||||
const rawReferenceAsJson = JSON.parse(rawReference);
|
|
||||||
if (rawReferenceAsJson instanceof Array) {
|
|
||||||
references["notes"] = rawReferenceAsJson;
|
|
||||||
} else if (typeof rawReferenceAsJson === "object" && rawReferenceAsJson !== null) {
|
|
||||||
references["online"] = rawReferenceAsJson;
|
|
||||||
}
|
|
||||||
readStream();
|
|
||||||
} else {
|
|
||||||
// Display response from Khoj
|
|
||||||
if (newResponseTextEl.getElementsByClassName("lds-ellipsis").length > 0) {
|
|
||||||
newResponseTextEl.removeChild(loadingEllipsis);
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the chunk is not a JSON object, just display it as is
|
|
||||||
rawResponse += chunk;
|
|
||||||
newResponseTextEl.innerHTML = "";
|
|
||||||
newResponseTextEl.appendChild(formatHTMLMessage(rawResponse));
|
|
||||||
|
|
||||||
readStream();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Scroll to bottom of chat window as chat response is streamed
|
|
||||||
document.body.scrollTop = document.getElementById("chat-body").scrollHeight;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
document.body.scrollTop = document.getElementById("chat-body").scrollHeight;
|
document.body.scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||||
}
|
}
|
||||||
|
|
|
@ -34,8 +34,8 @@ function toggleNavMenu() {
|
||||||
document.addEventListener('click', function(event) {
|
document.addEventListener('click', function(event) {
|
||||||
let menu = document.getElementById("khoj-nav-menu");
|
let menu = document.getElementById("khoj-nav-menu");
|
||||||
let menuContainer = document.getElementById("khoj-nav-menu-container");
|
let menuContainer = document.getElementById("khoj-nav-menu-container");
|
||||||
let isClickOnMenu = menuContainer.contains(event.target) || menuContainer === event.target;
|
let isClickOnMenu = menuContainer?.contains(event.target) || menuContainer === event.target;
|
||||||
if (isClickOnMenu === false && menu.classList.contains("show")) {
|
if (menu && isClickOnMenu === false && menu.classList.contains("show")) {
|
||||||
menu.classList.remove("show");
|
menu.classList.remove("show");
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
|
@ -12,6 +12,25 @@ export interface ChatJsonResult {
|
||||||
inferredQueries?: string[];
|
inferredQueries?: string[];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
interface ChunkResult {
|
||||||
|
objects: string[];
|
||||||
|
remainder: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
interface MessageChunk {
|
||||||
|
type: string;
|
||||||
|
data: any;
|
||||||
|
}
|
||||||
|
|
||||||
|
interface ChatMessageState {
|
||||||
|
newResponseTextEl: HTMLElement | null;
|
||||||
|
newResponseEl: HTMLElement | null;
|
||||||
|
loadingEllipsis: HTMLElement | null;
|
||||||
|
references: any;
|
||||||
|
rawResponse: string;
|
||||||
|
rawQuery: string;
|
||||||
|
isVoice: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
interface Location {
|
interface Location {
|
||||||
region: string;
|
region: string;
|
||||||
|
@ -26,6 +45,7 @@ export class KhojChatView extends KhojPaneView {
|
||||||
waitingForLocation: boolean;
|
waitingForLocation: boolean;
|
||||||
location: Location;
|
location: Location;
|
||||||
keyPressTimeout: NodeJS.Timeout | null = null;
|
keyPressTimeout: NodeJS.Timeout | null = null;
|
||||||
|
chatMessageState: ChatMessageState;
|
||||||
|
|
||||||
constructor(leaf: WorkspaceLeaf, setting: KhojSetting) {
|
constructor(leaf: WorkspaceLeaf, setting: KhojSetting) {
|
||||||
super(leaf, setting);
|
super(leaf, setting);
|
||||||
|
@ -409,16 +429,15 @@ export class KhojChatView extends KhojPaneView {
|
||||||
message = DOMPurify.sanitize(message);
|
message = DOMPurify.sanitize(message);
|
||||||
|
|
||||||
// Convert the message to html, sanitize the message html and render it to the real DOM
|
// Convert the message to html, sanitize the message html and render it to the real DOM
|
||||||
let chat_message_body_text_el = this.contentEl.createDiv();
|
let chatMessageBodyTextEl = this.contentEl.createDiv();
|
||||||
chat_message_body_text_el.className = "chat-message-text-response";
|
chatMessageBodyTextEl.innerHTML = this.markdownTextToSanitizedHtml(message, this);
|
||||||
chat_message_body_text_el.innerHTML = this.markdownTextToSanitizedHtml(message, this);
|
|
||||||
|
|
||||||
// Add a copy button to each chat message, if it doesn't already exist
|
// Add a copy button to each chat message, if it doesn't already exist
|
||||||
if (willReplace === true) {
|
if (willReplace === true) {
|
||||||
this.renderActionButtons(message, chat_message_body_text_el);
|
this.renderActionButtons(message, chatMessageBodyTextEl);
|
||||||
}
|
}
|
||||||
|
|
||||||
return chat_message_body_text_el;
|
return chatMessageBodyTextEl;
|
||||||
}
|
}
|
||||||
|
|
||||||
markdownTextToSanitizedHtml(markdownText: string, component: ItemView): string {
|
markdownTextToSanitizedHtml(markdownText: string, component: ItemView): string {
|
||||||
|
@ -502,23 +521,23 @@ export class KhojChatView extends KhojPaneView {
|
||||||
class: `khoj-chat-message ${sender}`
|
class: `khoj-chat-message ${sender}`
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
let chat_message_body_el = chatMessageEl.createDiv();
|
let chatMessageBodyEl = chatMessageEl.createDiv();
|
||||||
chat_message_body_el.addClasses(["khoj-chat-message-text", sender]);
|
chatMessageBodyEl.addClasses(["khoj-chat-message-text", sender]);
|
||||||
let chat_message_body_text_el = chat_message_body_el.createDiv();
|
let chatMessageBodyTextEl = chatMessageBodyEl.createDiv();
|
||||||
|
|
||||||
// Sanitize the markdown to render
|
// Sanitize the markdown to render
|
||||||
message = DOMPurify.sanitize(message);
|
message = DOMPurify.sanitize(message);
|
||||||
|
|
||||||
if (raw) {
|
if (raw) {
|
||||||
chat_message_body_text_el.innerHTML = message;
|
chatMessageBodyTextEl.innerHTML = message;
|
||||||
} else {
|
} else {
|
||||||
// @ts-ignore
|
// @ts-ignore
|
||||||
chat_message_body_text_el.innerHTML = this.markdownTextToSanitizedHtml(message, this);
|
chatMessageBodyTextEl.innerHTML = this.markdownTextToSanitizedHtml(message, this);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add action buttons to each chat message element
|
// Add action buttons to each chat message element
|
||||||
if (willReplace === true) {
|
if (willReplace === true) {
|
||||||
this.renderActionButtons(message, chat_message_body_text_el);
|
this.renderActionButtons(message, chatMessageBodyTextEl);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove user-select: none property to make text selectable
|
// Remove user-select: none property to make text selectable
|
||||||
|
@ -531,42 +550,38 @@ export class KhojChatView extends KhojPaneView {
|
||||||
}
|
}
|
||||||
|
|
||||||
createKhojResponseDiv(dt?: Date): HTMLDivElement {
|
createKhojResponseDiv(dt?: Date): HTMLDivElement {
|
||||||
let message_time = this.formatDate(dt ?? new Date());
|
let messageTime = this.formatDate(dt ?? new Date());
|
||||||
|
|
||||||
// Append message to conversation history HTML element.
|
// Append message to conversation history HTML element.
|
||||||
// The chat logs should display above the message input box to follow standard UI semantics
|
// The chat logs should display above the message input box to follow standard UI semantics
|
||||||
let chat_body_el = this.contentEl.getElementsByClassName("khoj-chat-body")[0];
|
let chatBodyEl = this.contentEl.getElementsByClassName("khoj-chat-body")[0];
|
||||||
let chat_message_el = chat_body_el.createDiv({
|
let chatMessageEl = chatBodyEl.createDiv({
|
||||||
attr: {
|
attr: {
|
||||||
"data-meta": `🏮 Khoj at ${message_time}`,
|
"data-meta": `🏮 Khoj at ${messageTime}`,
|
||||||
class: `khoj-chat-message khoj`
|
class: `khoj-chat-message khoj`
|
||||||
},
|
},
|
||||||
}).createDiv({
|
})
|
||||||
attr: {
|
|
||||||
class: `khoj-chat-message-text khoj`
|
|
||||||
},
|
|
||||||
}).createDiv();
|
|
||||||
|
|
||||||
// Scroll to bottom after inserting chat messages
|
// Scroll to bottom after inserting chat messages
|
||||||
this.scrollChatToBottom();
|
this.scrollChatToBottom();
|
||||||
|
|
||||||
return chat_message_el;
|
return chatMessageEl;
|
||||||
}
|
}
|
||||||
|
|
||||||
async renderIncrementalMessage(htmlElement: HTMLDivElement, additionalMessage: string) {
|
async renderIncrementalMessage(htmlElement: HTMLDivElement, additionalMessage: string) {
|
||||||
this.result += additionalMessage;
|
this.chatMessageState.rawResponse += additionalMessage;
|
||||||
htmlElement.innerHTML = "";
|
htmlElement.innerHTML = "";
|
||||||
// Sanitize the markdown to render
|
// Sanitize the markdown to render
|
||||||
this.result = DOMPurify.sanitize(this.result);
|
this.chatMessageState.rawResponse = DOMPurify.sanitize(this.chatMessageState.rawResponse);
|
||||||
// @ts-ignore
|
// @ts-ignore
|
||||||
htmlElement.innerHTML = this.markdownTextToSanitizedHtml(this.result, this);
|
htmlElement.innerHTML = this.markdownTextToSanitizedHtml(this.chatMessageState.rawResponse, this);
|
||||||
// Render action buttons for the message
|
// Render action buttons for the message
|
||||||
this.renderActionButtons(this.result, htmlElement);
|
this.renderActionButtons(this.chatMessageState.rawResponse, htmlElement);
|
||||||
// Scroll to bottom of modal, till the send message input box
|
// Scroll to bottom of modal, till the send message input box
|
||||||
this.scrollChatToBottom();
|
this.scrollChatToBottom();
|
||||||
}
|
}
|
||||||
|
|
||||||
renderActionButtons(message: string, chat_message_body_text_el: HTMLElement) {
|
renderActionButtons(message: string, chatMessageBodyTextEl: HTMLElement) {
|
||||||
let copyButton = this.contentEl.createEl('button');
|
let copyButton = this.contentEl.createEl('button');
|
||||||
copyButton.classList.add("chat-action-button");
|
copyButton.classList.add("chat-action-button");
|
||||||
copyButton.title = "Copy Message to Clipboard";
|
copyButton.title = "Copy Message to Clipboard";
|
||||||
|
@ -593,10 +608,10 @@ export class KhojChatView extends KhojPaneView {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Append buttons to parent element
|
// Append buttons to parent element
|
||||||
chat_message_body_text_el.append(copyButton, pasteToFile);
|
chatMessageBodyTextEl.append(copyButton, pasteToFile);
|
||||||
|
|
||||||
if (speechButton) {
|
if (speechButton) {
|
||||||
chat_message_body_text_el.append(speechButton);
|
chatMessageBodyTextEl.append(speechButton);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -854,35 +869,122 @@ export class KhojChatView extends KhojPaneView {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
async readChatStream(response: Response, responseElement: HTMLDivElement, isVoice: boolean = false): Promise<void> {
|
convertMessageChunkToJson(rawChunk: string): MessageChunk {
|
||||||
|
if (rawChunk?.startsWith("{") && rawChunk?.endsWith("}")) {
|
||||||
|
try {
|
||||||
|
let jsonChunk = JSON.parse(rawChunk);
|
||||||
|
if (!jsonChunk.type)
|
||||||
|
jsonChunk = {type: 'message', data: jsonChunk};
|
||||||
|
return jsonChunk;
|
||||||
|
} catch (e) {
|
||||||
|
return {type: 'message', data: rawChunk};
|
||||||
|
}
|
||||||
|
} else if (rawChunk.length > 0) {
|
||||||
|
return {type: 'message', data: rawChunk};
|
||||||
|
}
|
||||||
|
return {type: '', data: ''};
|
||||||
|
}
|
||||||
|
|
||||||
|
processMessageChunk(rawChunk: string): void {
|
||||||
|
const chunk = this.convertMessageChunkToJson(rawChunk);
|
||||||
|
console.debug("Chunk:", chunk);
|
||||||
|
if (!chunk || !chunk.type) return;
|
||||||
|
if (chunk.type === 'status') {
|
||||||
|
console.log(`status: ${chunk.data}`);
|
||||||
|
const statusMessage = chunk.data;
|
||||||
|
this.handleStreamResponse(this.chatMessageState.newResponseTextEl, statusMessage, this.chatMessageState.loadingEllipsis, false);
|
||||||
|
} else if (chunk.type === 'start_llm_response') {
|
||||||
|
console.log("Started streaming", new Date());
|
||||||
|
} else if (chunk.type === 'end_llm_response') {
|
||||||
|
console.log("Stopped streaming", new Date());
|
||||||
|
|
||||||
|
// Automatically respond with voice if the subscribed user has sent voice message
|
||||||
|
if (this.chatMessageState.isVoice && this.setting.userInfo?.is_active)
|
||||||
|
this.textToSpeech(this.chatMessageState.rawResponse);
|
||||||
|
|
||||||
|
// Append any references after all the data has been streamed
|
||||||
|
this.finalizeChatBodyResponse(this.chatMessageState.references, this.chatMessageState.newResponseTextEl);
|
||||||
|
|
||||||
|
const liveQuery = this.chatMessageState.rawQuery;
|
||||||
|
// Reset variables
|
||||||
|
this.chatMessageState = {
|
||||||
|
newResponseTextEl: null,
|
||||||
|
newResponseEl: null,
|
||||||
|
loadingEllipsis: null,
|
||||||
|
references: {},
|
||||||
|
rawResponse: "",
|
||||||
|
rawQuery: liveQuery,
|
||||||
|
isVoice: false,
|
||||||
|
};
|
||||||
|
} else if (chunk.type === "references") {
|
||||||
|
this.chatMessageState.references = {"notes": chunk.data.context, "online": chunk.data.onlineContext};
|
||||||
|
} else if (chunk.type === 'message') {
|
||||||
|
const chunkData = chunk.data;
|
||||||
|
if (typeof chunkData === 'object' && chunkData !== null) {
|
||||||
|
// If chunkData is already a JSON object
|
||||||
|
this.handleJsonResponse(chunkData);
|
||||||
|
} else if (typeof chunkData === 'string' && chunkData.trim()?.startsWith("{") && chunkData.trim()?.endsWith("}")) {
|
||||||
|
// Try process chunk data as if it is a JSON object
|
||||||
|
try {
|
||||||
|
const jsonData = JSON.parse(chunkData.trim());
|
||||||
|
this.handleJsonResponse(jsonData);
|
||||||
|
} catch (e) {
|
||||||
|
this.chatMessageState.rawResponse += chunkData;
|
||||||
|
this.handleStreamResponse(this.chatMessageState.newResponseTextEl, this.chatMessageState.rawResponse, this.chatMessageState.loadingEllipsis);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
this.chatMessageState.rawResponse += chunkData;
|
||||||
|
this.handleStreamResponse(this.chatMessageState.newResponseTextEl, this.chatMessageState.rawResponse, this.chatMessageState.loadingEllipsis);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
handleJsonResponse(jsonData: any): void {
|
||||||
|
if (jsonData.image || jsonData.detail) {
|
||||||
|
this.chatMessageState.rawResponse = this.handleImageResponse(jsonData, this.chatMessageState.rawResponse);
|
||||||
|
} else if (jsonData.response) {
|
||||||
|
this.chatMessageState.rawResponse = jsonData.response;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (this.chatMessageState.newResponseTextEl) {
|
||||||
|
this.chatMessageState.newResponseTextEl.innerHTML = "";
|
||||||
|
this.chatMessageState.newResponseTextEl.appendChild(this.formatHTMLMessage(this.chatMessageState.rawResponse));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async readChatStream(response: Response): Promise<void> {
|
||||||
// Exit if response body is empty
|
// Exit if response body is empty
|
||||||
if (response.body == null) return;
|
if (response.body == null) return;
|
||||||
|
|
||||||
const reader = response.body.getReader();
|
const reader = response.body.getReader();
|
||||||
const decoder = new TextDecoder();
|
const decoder = new TextDecoder();
|
||||||
|
const eventDelimiter = '␃🔚␗';
|
||||||
|
let buffer = '';
|
||||||
|
|
||||||
while (true) {
|
while (true) {
|
||||||
const { value, done } = await reader.read();
|
const { value, done } = await reader.read();
|
||||||
|
|
||||||
if (done) {
|
if (done) {
|
||||||
// Automatically respond with voice if the subscribed user has sent voice message
|
this.processMessageChunk(buffer);
|
||||||
if (isVoice && this.setting.userInfo?.is_active) this.textToSpeech(this.result);
|
buffer = '';
|
||||||
// Break if the stream is done
|
// Break if the stream is done
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
let responseText = decoder.decode(value);
|
const chunk = decoder.decode(value, { stream: true });
|
||||||
if (responseText.includes("### compiled references:")) {
|
console.debug("Raw Chunk:", chunk)
|
||||||
// Render any references used to generate the response
|
// Start buffering chunks until complete event is received
|
||||||
const [additionalResponse, rawReference] = responseText.split("### compiled references:", 2);
|
buffer += chunk;
|
||||||
await this.renderIncrementalMessage(responseElement, additionalResponse);
|
|
||||||
|
|
||||||
const rawReferenceAsJson = JSON.parse(rawReference);
|
// Once the buffer contains a complete event
|
||||||
let references = this.extractReferences(rawReferenceAsJson);
|
let newEventIndex;
|
||||||
responseElement.appendChild(this.createReferenceSection(references));
|
while ((newEventIndex = buffer.indexOf(eventDelimiter)) !== -1) {
|
||||||
} else {
|
// Extract the event from the buffer
|
||||||
// Render incremental chat response
|
const event = buffer.slice(0, newEventIndex);
|
||||||
await this.renderIncrementalMessage(responseElement, responseText);
|
buffer = buffer.slice(newEventIndex + eventDelimiter.length);
|
||||||
|
|
||||||
|
// Process the event
|
||||||
|
if (event) this.processMessageChunk(event);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -895,83 +997,59 @@ export class KhojChatView extends KhojPaneView {
|
||||||
let chatBodyEl = this.contentEl.getElementsByClassName("khoj-chat-body")[0] as HTMLElement;
|
let chatBodyEl = this.contentEl.getElementsByClassName("khoj-chat-body")[0] as HTMLElement;
|
||||||
this.renderMessage(chatBodyEl, query, "you");
|
this.renderMessage(chatBodyEl, query, "you");
|
||||||
|
|
||||||
let conversationID = chatBodyEl.dataset.conversationId;
|
let conversationId = chatBodyEl.dataset.conversationId;
|
||||||
if (!conversationID) {
|
if (!conversationId) {
|
||||||
let chatUrl = `${this.setting.khojUrl}/api/chat/sessions?client=obsidian`;
|
let chatUrl = `${this.setting.khojUrl}/api/chat/sessions?client=obsidian`;
|
||||||
let response = await fetch(chatUrl, {
|
let response = await fetch(chatUrl, {
|
||||||
method: "POST",
|
method: "POST",
|
||||||
headers: { "Authorization": `Bearer ${this.setting.khojApiKey}` },
|
headers: { "Authorization": `Bearer ${this.setting.khojApiKey}` },
|
||||||
});
|
});
|
||||||
let data = await response.json();
|
let data = await response.json();
|
||||||
conversationID = data.conversation_id;
|
conversationId = data.conversation_id;
|
||||||
chatBodyEl.dataset.conversationId = conversationID;
|
chatBodyEl.dataset.conversationId = conversationId;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get chat response from Khoj backend
|
// Get chat response from Khoj backend
|
||||||
let encodedQuery = encodeURIComponent(query);
|
let encodedQuery = encodeURIComponent(query);
|
||||||
let chatUrl = `${this.setting.khojUrl}/api/chat?q=${encodedQuery}&n=${this.setting.resultsCount}&client=obsidian&stream=true®ion=${this.location.region}&city=${this.location.city}&country=${this.location.countryName}&timezone=${this.location.timezone}`;
|
let chatUrl = `${this.setting.khojUrl}/api/chat?q=${encodedQuery}&conversation_id=${conversationId}&n=${this.setting.resultsCount}&stream=true&client=obsidian`;
|
||||||
let responseElement = this.createKhojResponseDiv();
|
if (!!this.location) chatUrl += `®ion=${this.location.region}&city=${this.location.city}&country=${this.location.countryName}&timezone=${this.location.timezone}`;
|
||||||
|
|
||||||
|
let newResponseEl = this.createKhojResponseDiv();
|
||||||
|
let newResponseTextEl = newResponseEl.createDiv();
|
||||||
|
newResponseTextEl.classList.add("khoj-chat-message-text", "khoj");
|
||||||
|
|
||||||
// Temporary status message to indicate that Khoj is thinking
|
// Temporary status message to indicate that Khoj is thinking
|
||||||
this.result = "";
|
|
||||||
let loadingEllipsis = this.createLoadingEllipse();
|
let loadingEllipsis = this.createLoadingEllipse();
|
||||||
responseElement.appendChild(loadingEllipsis);
|
newResponseTextEl.appendChild(loadingEllipsis);
|
||||||
|
|
||||||
|
// Set chat message state
|
||||||
|
this.chatMessageState = {
|
||||||
|
newResponseEl: newResponseEl,
|
||||||
|
newResponseTextEl: newResponseTextEl,
|
||||||
|
loadingEllipsis: loadingEllipsis,
|
||||||
|
references: {},
|
||||||
|
rawQuery: query,
|
||||||
|
rawResponse: "",
|
||||||
|
isVoice: isVoice,
|
||||||
|
};
|
||||||
|
|
||||||
let response = await fetch(chatUrl, {
|
let response = await fetch(chatUrl, {
|
||||||
method: "GET",
|
method: "GET",
|
||||||
headers: {
|
headers: {
|
||||||
"Content-Type": "text/event-stream",
|
"Content-Type": "text/plain",
|
||||||
"Authorization": `Bearer ${this.setting.khojApiKey}`,
|
"Authorization": `Bearer ${this.setting.khojApiKey}`,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
try {
|
try {
|
||||||
if (response.body === null) {
|
if (response.body === null) throw new Error("Response body is null");
|
||||||
throw new Error("Response body is null");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clear loading status message
|
|
||||||
if (responseElement.getElementsByClassName("lds-ellipsis").length > 0 && loadingEllipsis) {
|
|
||||||
responseElement.removeChild(loadingEllipsis);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reset collated chat result to empty string
|
|
||||||
this.result = "";
|
|
||||||
responseElement.innerHTML = "";
|
|
||||||
if (response.headers.get("content-type") === "application/json") {
|
|
||||||
let responseText = ""
|
|
||||||
try {
|
|
||||||
const responseAsJson = await response.json() as ChatJsonResult;
|
|
||||||
if (responseAsJson.image) {
|
|
||||||
// If response has image field, response is a generated image.
|
|
||||||
if (responseAsJson.intentType === "text-to-image") {
|
|
||||||
responseText += `![${query}](data:image/png;base64,${responseAsJson.image})`;
|
|
||||||
} else if (responseAsJson.intentType === "text-to-image2") {
|
|
||||||
responseText += `![${query}](${responseAsJson.image})`;
|
|
||||||
} else if (responseAsJson.intentType === "text-to-image-v3") {
|
|
||||||
responseText += `![${query}](data:image/webp;base64,${responseAsJson.image})`;
|
|
||||||
}
|
|
||||||
const inferredQuery = responseAsJson.inferredQueries?.[0];
|
|
||||||
if (inferredQuery) {
|
|
||||||
responseText += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
|
|
||||||
}
|
|
||||||
} else if (responseAsJson.detail) {
|
|
||||||
responseText = responseAsJson.detail;
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
// If the chunk is not a JSON object, just display it as is
|
|
||||||
responseText = await response.text();
|
|
||||||
} finally {
|
|
||||||
await this.renderIncrementalMessage(responseElement, responseText);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Stream and render chat response
|
// Stream and render chat response
|
||||||
await this.readChatStream(response, responseElement, isVoice);
|
await this.readChatStream(response);
|
||||||
}
|
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
console.log(`Khoj chat response failed with\n${err}`);
|
console.error(`Khoj chat response failed with\n${err}`);
|
||||||
let errorMsg = "Sorry, unable to get response from Khoj backend ❤️🩹. Retry or contact developers for help at <a href=mailto:'team@khoj.dev'>team@khoj.dev</a> or <a href='https://discord.gg/BDgyabRM6e'>on Discord</a>";
|
let errorMsg = "Sorry, unable to get response from Khoj backend ❤️🩹. Retry or contact developers for help at <a href=mailto:'team@khoj.dev'>team@khoj.dev</a> or <a href='https://discord.gg/BDgyabRM6e'>on Discord</a>";
|
||||||
responseElement.innerHTML = errorMsg
|
newResponseTextEl.textContent = errorMsg;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1196,30 +1274,21 @@ export class KhojChatView extends KhojPaneView {
|
||||||
|
|
||||||
handleStreamResponse(newResponseElement: HTMLElement | null, rawResponse: string, loadingEllipsis: HTMLElement | null, replace = true) {
|
handleStreamResponse(newResponseElement: HTMLElement | null, rawResponse: string, loadingEllipsis: HTMLElement | null, replace = true) {
|
||||||
if (!newResponseElement) return;
|
if (!newResponseElement) return;
|
||||||
if (newResponseElement.getElementsByClassName("lds-ellipsis").length > 0 && loadingEllipsis) {
|
// Remove loading ellipsis if it exists
|
||||||
|
if (newResponseElement.getElementsByClassName("lds-ellipsis").length > 0 && loadingEllipsis)
|
||||||
newResponseElement.removeChild(loadingEllipsis);
|
newResponseElement.removeChild(loadingEllipsis);
|
||||||
}
|
// Clear the response element if replace is true
|
||||||
if (replace) {
|
if (replace) newResponseElement.innerHTML = "";
|
||||||
newResponseElement.innerHTML = "";
|
|
||||||
}
|
// Append response to the response element
|
||||||
newResponseElement.appendChild(this.formatHTMLMessage(rawResponse, false, replace));
|
newResponseElement.appendChild(this.formatHTMLMessage(rawResponse, false, replace));
|
||||||
|
|
||||||
|
// Append loading ellipsis if it exists
|
||||||
|
if (!replace && loadingEllipsis) newResponseElement.appendChild(loadingEllipsis);
|
||||||
|
// Scroll to bottom of chat view
|
||||||
this.scrollChatToBottom();
|
this.scrollChatToBottom();
|
||||||
}
|
}
|
||||||
|
|
||||||
handleCompiledReferences(rawResponseElement: HTMLElement | null, chunk: string, references: any, rawResponse: string) {
|
|
||||||
if (!rawResponseElement || !chunk) return { rawResponse, references };
|
|
||||||
|
|
||||||
const [additionalResponse, rawReference] = chunk.split("### compiled references:", 2);
|
|
||||||
rawResponse += additionalResponse;
|
|
||||||
rawResponseElement.innerHTML = "";
|
|
||||||
rawResponseElement.appendChild(this.formatHTMLMessage(rawResponse));
|
|
||||||
|
|
||||||
const rawReferenceAsJson = JSON.parse(rawReference);
|
|
||||||
references = this.extractReferences(rawReferenceAsJson);
|
|
||||||
|
|
||||||
return { rawResponse, references };
|
|
||||||
}
|
|
||||||
|
|
||||||
handleImageResponse(imageJson: any, rawResponse: string) {
|
handleImageResponse(imageJson: any, rawResponse: string) {
|
||||||
if (imageJson.image) {
|
if (imageJson.image) {
|
||||||
const inferredQuery = imageJson.inferredQueries?.[0] ?? "generated image";
|
const inferredQuery = imageJson.inferredQueries?.[0] ?? "generated image";
|
||||||
|
@ -1236,33 +1305,10 @@ export class KhojChatView extends KhojPaneView {
|
||||||
rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
|
rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
let references = {};
|
|
||||||
if (imageJson.context && imageJson.context.length > 0) {
|
|
||||||
references = this.extractReferences(imageJson.context);
|
|
||||||
}
|
|
||||||
if (imageJson.detail) {
|
|
||||||
// If response has detail field, response is an error message.
|
// If response has detail field, response is an error message.
|
||||||
rawResponse += imageJson.detail;
|
if (imageJson.detail) rawResponse += imageJson.detail;
|
||||||
}
|
|
||||||
return { rawResponse, references };
|
|
||||||
}
|
|
||||||
|
|
||||||
extractReferences(rawReferenceAsJson: any): object {
|
return rawResponse;
|
||||||
let references: any = {};
|
|
||||||
if (rawReferenceAsJson instanceof Array) {
|
|
||||||
references["notes"] = rawReferenceAsJson;
|
|
||||||
} else if (typeof rawReferenceAsJson === "object" && rawReferenceAsJson !== null) {
|
|
||||||
references["online"] = rawReferenceAsJson;
|
|
||||||
}
|
|
||||||
return references;
|
|
||||||
}
|
|
||||||
|
|
||||||
addMessageToChatBody(rawResponse: string, newResponseElement: HTMLElement | null, references: any) {
|
|
||||||
if (!newResponseElement) return;
|
|
||||||
newResponseElement.innerHTML = "";
|
|
||||||
newResponseElement.appendChild(this.formatHTMLMessage(rawResponse));
|
|
||||||
|
|
||||||
this.finalizeChatBodyResponse(references, newResponseElement);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
finalizeChatBodyResponse(references: object, newResponseElement: HTMLElement | null) {
|
finalizeChatBodyResponse(references: object, newResponseElement: HTMLElement | null) {
|
||||||
|
|
|
@ -85,6 +85,12 @@ If your plugin does not need CSS, delete this file.
|
||||||
margin-left: auto;
|
margin-left: auto;
|
||||||
white-space: pre-line;
|
white-space: pre-line;
|
||||||
}
|
}
|
||||||
|
/* Override white-space for ul, ol, li under khoj-chat-message-text.khoj */
|
||||||
|
.khoj-chat-message-text.khoj ul,
|
||||||
|
.khoj-chat-message-text.khoj ol,
|
||||||
|
.khoj-chat-message-text.khoj li {
|
||||||
|
white-space: normal;
|
||||||
|
}
|
||||||
/* add left protrusion to khoj chat bubble */
|
/* add left protrusion to khoj chat bubble */
|
||||||
.khoj-chat-message-text.khoj:after {
|
.khoj-chat-message-text.khoj:after {
|
||||||
content: '';
|
content: '';
|
||||||
|
|
|
@ -680,34 +680,18 @@ class ConversationAdapters:
|
||||||
async def aget_conversation_by_user(
|
async def aget_conversation_by_user(
|
||||||
user: KhojUser, client_application: ClientApplication = None, conversation_id: int = None, title: str = None
|
user: KhojUser, client_application: ClientApplication = None, conversation_id: int = None, title: str = None
|
||||||
) -> Optional[Conversation]:
|
) -> Optional[Conversation]:
|
||||||
|
query = Conversation.objects.filter(user=user, client=client_application).prefetch_related("agent")
|
||||||
|
|
||||||
if conversation_id:
|
if conversation_id:
|
||||||
return (
|
return await query.filter(id=conversation_id).afirst()
|
||||||
await Conversation.objects.filter(user=user, client=client_application, id=conversation_id)
|
|
||||||
.prefetch_related("agent")
|
|
||||||
.afirst()
|
|
||||||
)
|
|
||||||
elif title:
|
elif title:
|
||||||
return (
|
return await query.filter(title=title).afirst()
|
||||||
await Conversation.objects.filter(user=user, client=client_application, title=title)
|
|
||||||
.prefetch_related("agent")
|
|
||||||
.afirst()
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
conversation = (
|
|
||||||
Conversation.objects.filter(user=user, client=client_application)
|
|
||||||
.prefetch_related("agent")
|
|
||||||
.order_by("-updated_at")
|
|
||||||
)
|
|
||||||
|
|
||||||
if await conversation.aexists():
|
conversation = await query.order_by("-updated_at").afirst()
|
||||||
return await conversation.prefetch_related("agent").afirst()
|
|
||||||
|
|
||||||
return await (
|
return conversation or await Conversation.objects.prefetch_related("agent").acreate(
|
||||||
Conversation.objects.filter(user=user, client=client_application)
|
user=user, client=client_application
|
||||||
.prefetch_related("agent")
|
)
|
||||||
.order_by("-updated_at")
|
|
||||||
.afirst()
|
|
||||||
) or await Conversation.objects.prefetch_related("agent").acreate(user=user, client=client_application)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def adelete_conversation_by_user(
|
async def adelete_conversation_by_user(
|
||||||
|
|
|
@ -74,14 +74,13 @@ To get started, just start typing below. You can also type / to see a list of co
|
||||||
}, 1000);
|
}, 1000);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
var websocket = null;
|
|
||||||
let region = null;
|
let region = null;
|
||||||
let city = null;
|
let city = null;
|
||||||
let countryName = null;
|
let countryName = null;
|
||||||
let timezone = null;
|
let timezone = null;
|
||||||
let waitingForLocation = true;
|
let waitingForLocation = true;
|
||||||
|
let chatMessageState = {
|
||||||
let websocketState = {
|
|
||||||
newResponseTextEl: null,
|
newResponseTextEl: null,
|
||||||
newResponseEl: null,
|
newResponseEl: null,
|
||||||
loadingEllipsis: null,
|
loadingEllipsis: null,
|
||||||
|
@ -105,7 +104,7 @@ To get started, just start typing below. You can also type / to see a list of co
|
||||||
.finally(() => {
|
.finally(() => {
|
||||||
console.debug("Region:", region, "City:", city, "Country:", countryName, "Timezone:", timezone);
|
console.debug("Region:", region, "City:", city, "Country:", countryName, "Timezone:", timezone);
|
||||||
waitingForLocation = false;
|
waitingForLocation = false;
|
||||||
setupWebSocket();
|
initMessageState();
|
||||||
});
|
});
|
||||||
|
|
||||||
function formatDate(date) {
|
function formatDate(date) {
|
||||||
|
@ -599,13 +598,8 @@ To get started, just start typing below. You can also type / to see a list of co
|
||||||
}
|
}
|
||||||
|
|
||||||
async function chat(isVoice=false) {
|
async function chat(isVoice=false) {
|
||||||
if (websocket) {
|
// Extract chat message from chat input form
|
||||||
sendMessageViaWebSocket(isVoice);
|
var query = document.getElementById("chat-input").value.trim();
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
let query = document.getElementById("chat-input").value.trim();
|
|
||||||
let resultsCount = localStorage.getItem("khojResultsCount") || 5;
|
|
||||||
console.log(`Query: ${query}`);
|
console.log(`Query: ${query}`);
|
||||||
|
|
||||||
// Short circuit on empty query
|
// Short circuit on empty query
|
||||||
|
@ -624,31 +618,30 @@ To get started, just start typing below. You can also type / to see a list of co
|
||||||
document.getElementById("chat-input").value = "";
|
document.getElementById("chat-input").value = "";
|
||||||
autoResize();
|
autoResize();
|
||||||
document.getElementById("chat-input").setAttribute("disabled", "disabled");
|
document.getElementById("chat-input").setAttribute("disabled", "disabled");
|
||||||
let chat_body = document.getElementById("chat-body");
|
|
||||||
|
|
||||||
let conversationID = chat_body.dataset.conversationId;
|
|
||||||
|
|
||||||
|
let chatBody = document.getElementById("chat-body");
|
||||||
|
let conversationID = chatBody.dataset.conversationId;
|
||||||
if (!conversationID) {
|
if (!conversationID) {
|
||||||
let response = await fetch('/api/chat/sessions', { method: "POST" });
|
let response = await fetch(`${hostURL}/api/chat/sessions`, { method: "POST" });
|
||||||
let data = await response.json();
|
let data = await response.json();
|
||||||
conversationID = data.conversation_id;
|
conversationID = data.conversation_id;
|
||||||
chat_body.dataset.conversationId = conversationID;
|
chatBody.dataset.conversationId = conversationID;
|
||||||
refreshChatSessionsPanel();
|
await refreshChatSessionsPanel();
|
||||||
}
|
}
|
||||||
|
|
||||||
let new_response = document.createElement("div");
|
let newResponseEl = document.createElement("div");
|
||||||
new_response.classList.add("chat-message", "khoj");
|
newResponseEl.classList.add("chat-message", "khoj");
|
||||||
new_response.attributes["data-meta"] = "🏮 Khoj at " + formatDate(new Date());
|
newResponseEl.attributes["data-meta"] = "🏮 Khoj at " + formatDate(new Date());
|
||||||
chat_body.appendChild(new_response);
|
chatBody.appendChild(newResponseEl);
|
||||||
|
|
||||||
let newResponseText = document.createElement("div");
|
let newResponseTextEl = document.createElement("div");
|
||||||
newResponseText.classList.add("chat-message-text", "khoj");
|
newResponseTextEl.classList.add("chat-message-text", "khoj");
|
||||||
new_response.appendChild(newResponseText);
|
newResponseEl.appendChild(newResponseTextEl);
|
||||||
|
|
||||||
// Temporary status message to indicate that Khoj is thinking
|
// Temporary status message to indicate that Khoj is thinking
|
||||||
let loadingEllipsis = createLoadingEllipse();
|
let loadingEllipsis = createLoadingEllipse();
|
||||||
|
|
||||||
newResponseText.appendChild(loadingEllipsis);
|
newResponseTextEl.appendChild(loadingEllipsis);
|
||||||
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||||
|
|
||||||
let chatTooltip = document.getElementById("chat-tooltip");
|
let chatTooltip = document.getElementById("chat-tooltip");
|
||||||
|
@ -657,65 +650,38 @@ To get started, just start typing below. You can also type / to see a list of co
|
||||||
let chatInput = document.getElementById("chat-input");
|
let chatInput = document.getElementById("chat-input");
|
||||||
chatInput.classList.remove("option-enabled");
|
chatInput.classList.remove("option-enabled");
|
||||||
|
|
||||||
// Generate backend API URL to execute query
|
// Setup chat message state
|
||||||
let url = `/api/chat?q=${encodeURIComponent(query)}&n=${resultsCount}&client=web&stream=true&conversation_id=${conversationID}®ion=${region}&city=${city}&country=${countryName}&timezone=${timezone}`;
|
chatMessageState = {
|
||||||
|
newResponseTextEl,
|
||||||
|
newResponseEl,
|
||||||
|
loadingEllipsis,
|
||||||
|
references: {},
|
||||||
|
rawResponse: "",
|
||||||
|
rawQuery: query,
|
||||||
|
isVoice: isVoice,
|
||||||
|
}
|
||||||
|
|
||||||
// Call specified Khoj API
|
// Call Khoj chat API
|
||||||
let response = await fetch(url);
|
let chatApi = `/api/chat?q=${encodeURIComponent(query)}&conversation_id=${conversationID}&stream=true&client=web`;
|
||||||
let rawResponse = "";
|
chatApi += (!!region && !!city && !!countryName && !!timezone)
|
||||||
let references = null;
|
? `®ion=${region}&city=${city}&country=${countryName}&timezone=${timezone}`
|
||||||
const contentType = response.headers.get("content-type");
|
: '';
|
||||||
|
|
||||||
|
const response = await fetch(chatApi);
|
||||||
|
|
||||||
if (contentType === "application/json") {
|
|
||||||
// Handle JSON response
|
|
||||||
try {
|
try {
|
||||||
const responseAsJson = await response.json();
|
if (!response.ok) throw new Error(response.statusText);
|
||||||
if (responseAsJson.image || responseAsJson.detail) {
|
if (!response.body) throw new Error("Response body is empty");
|
||||||
({rawResponse, references } = handleImageResponse(responseAsJson, rawResponse));
|
// Stream and render chat response
|
||||||
} else {
|
await readChatStream(response);
|
||||||
rawResponse = responseAsJson.response;
|
} catch (err) {
|
||||||
|
console.error(`Khoj chat response failed with\n${err}`);
|
||||||
|
if (chatMessageState.newResponseEl.getElementsByClassName("lds-ellipsis").length > 0 && chatMessageState.loadingEllipsis)
|
||||||
|
chatMessageState.newResponseTextEl.removeChild(chatMessageState.loadingEllipsis);
|
||||||
|
let errorMsg = "Sorry, unable to get response from Khoj backend ❤️🩹. Retry or contact developers for help at <a href=mailto:'team@khoj.dev'>team@khoj.dev</a> or <a href='https://discord.gg/BDgyabRM6e'>on Discord</a>";
|
||||||
|
newResponseTextEl.innerHTML = errorMsg;
|
||||||
}
|
}
|
||||||
} catch (error) {
|
|
||||||
// If the chunk is not a JSON object, just display it as is
|
|
||||||
rawResponse += chunk;
|
|
||||||
} finally {
|
|
||||||
addMessageToChatBody(rawResponse, newResponseText, references);
|
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
// Handle streamed response of type text/event-stream or text/plain
|
|
||||||
const reader = response.body.getReader();
|
|
||||||
const decoder = new TextDecoder();
|
|
||||||
let references = {};
|
|
||||||
|
|
||||||
readStream();
|
|
||||||
|
|
||||||
function readStream() {
|
|
||||||
reader.read().then(({ done, value }) => {
|
|
||||||
if (done) {
|
|
||||||
// Append any references after all the data has been streamed
|
|
||||||
finalizeChatBodyResponse(references, newResponseText);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Decode message chunk from stream
|
|
||||||
const chunk = decoder.decode(value, { stream: true });
|
|
||||||
|
|
||||||
if (chunk.includes("### compiled references:")) {
|
|
||||||
({ rawResponse, references } = handleCompiledReferences(newResponseText, chunk, references, rawResponse));
|
|
||||||
readStream();
|
|
||||||
} else {
|
|
||||||
// If the chunk is not a JSON object, just display it as is
|
|
||||||
rawResponse += chunk;
|
|
||||||
handleStreamResponse(newResponseText, rawResponse, query, loadingEllipsis);
|
|
||||||
readStream();
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
// Scroll to bottom of chat window as chat response is streamed
|
|
||||||
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
|
||||||
};
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
function createLoadingEllipse() {
|
function createLoadingEllipse() {
|
||||||
// Temporary status message to indicate that Khoj is thinking
|
// Temporary status message to indicate that Khoj is thinking
|
||||||
|
@ -743,32 +709,22 @@ To get started, just start typing below. You can also type / to see a list of co
|
||||||
}
|
}
|
||||||
|
|
||||||
function handleStreamResponse(newResponseElement, rawResponse, rawQuery, loadingEllipsis, replace=true) {
|
function handleStreamResponse(newResponseElement, rawResponse, rawQuery, loadingEllipsis, replace=true) {
|
||||||
if (newResponseElement.getElementsByClassName("lds-ellipsis").length > 0 && loadingEllipsis) {
|
if (!newResponseElement) return;
|
||||||
|
// Remove loading ellipsis if it exists
|
||||||
|
if (newResponseElement.getElementsByClassName("lds-ellipsis").length > 0 && loadingEllipsis)
|
||||||
newResponseElement.removeChild(loadingEllipsis);
|
newResponseElement.removeChild(loadingEllipsis);
|
||||||
}
|
// Clear the response element if replace is true
|
||||||
if (replace) {
|
if (replace) newResponseElement.innerHTML = "";
|
||||||
newResponseElement.innerHTML = "";
|
|
||||||
}
|
// Append response to the response element
|
||||||
newResponseElement.appendChild(formatHTMLMessage(rawResponse, false, replace, rawQuery));
|
newResponseElement.appendChild(formatHTMLMessage(rawResponse, false, replace, rawQuery));
|
||||||
|
|
||||||
|
// Append loading ellipsis if it exists
|
||||||
|
if (!replace && loadingEllipsis) newResponseElement.appendChild(loadingEllipsis);
|
||||||
|
// Scroll to bottom of chat view
|
||||||
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||||
}
|
}
|
||||||
|
|
||||||
function handleCompiledReferences(rawResponseElement, chunk, references, rawResponse) {
|
|
||||||
const additionalResponse = chunk.split("### compiled references:")[0];
|
|
||||||
rawResponse += additionalResponse;
|
|
||||||
rawResponseElement.innerHTML = "";
|
|
||||||
rawResponseElement.appendChild(formatHTMLMessage(rawResponse));
|
|
||||||
|
|
||||||
const rawReference = chunk.split("### compiled references:")[1];
|
|
||||||
const rawReferenceAsJson = JSON.parse(rawReference);
|
|
||||||
if (rawReferenceAsJson instanceof Array) {
|
|
||||||
references["notes"] = rawReferenceAsJson;
|
|
||||||
} else if (typeof rawReferenceAsJson === "object" && rawReferenceAsJson !== null) {
|
|
||||||
references["online"] = rawReferenceAsJson;
|
|
||||||
}
|
|
||||||
return { rawResponse, references };
|
|
||||||
}
|
|
||||||
|
|
||||||
function handleImageResponse(imageJson, rawResponse) {
|
function handleImageResponse(imageJson, rawResponse) {
|
||||||
if (imageJson.image) {
|
if (imageJson.image) {
|
||||||
const inferredQuery = imageJson.inferredQueries?.[0] ?? "generated image";
|
const inferredQuery = imageJson.inferredQueries?.[0] ?? "generated image";
|
||||||
|
@ -785,35 +741,139 @@ To get started, just start typing below. You can also type / to see a list of co
|
||||||
rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
|
rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
let references = {};
|
|
||||||
if (imageJson.context && imageJson.context.length > 0) {
|
|
||||||
const rawReferenceAsJson = imageJson.context;
|
|
||||||
if (rawReferenceAsJson instanceof Array) {
|
|
||||||
references["notes"] = rawReferenceAsJson;
|
|
||||||
} else if (typeof rawReferenceAsJson === "object" && rawReferenceAsJson !== null) {
|
|
||||||
references["online"] = rawReferenceAsJson;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (imageJson.detail) {
|
|
||||||
// If response has detail field, response is an error message.
|
// If response has detail field, response is an error message.
|
||||||
rawResponse += imageJson.detail;
|
if (imageJson.detail) rawResponse += imageJson.detail;
|
||||||
}
|
|
||||||
return { rawResponse, references };
|
|
||||||
}
|
|
||||||
|
|
||||||
function addMessageToChatBody(rawResponse, newResponseElement, references) {
|
return rawResponse;
|
||||||
newResponseElement.innerHTML = "";
|
|
||||||
newResponseElement.appendChild(formatHTMLMessage(rawResponse));
|
|
||||||
|
|
||||||
finalizeChatBodyResponse(references, newResponseElement);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function finalizeChatBodyResponse(references, newResponseElement) {
|
function finalizeChatBodyResponse(references, newResponseElement) {
|
||||||
if (references != null && Object.keys(references).length > 0) {
|
if (!!newResponseElement && references != null && Object.keys(references).length > 0) {
|
||||||
newResponseElement.appendChild(createReferenceSection(references));
|
newResponseElement.appendChild(createReferenceSection(references));
|
||||||
}
|
}
|
||||||
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||||
document.getElementById("chat-input").removeAttribute("disabled");
|
document.getElementById("chat-input")?.removeAttribute("disabled");
|
||||||
|
}
|
||||||
|
|
||||||
|
function convertMessageChunkToJson(rawChunk) {
|
||||||
|
// Split the chunk into lines
|
||||||
|
console.debug("Raw Event:", rawChunk);
|
||||||
|
if (rawChunk?.startsWith("{") && rawChunk?.endsWith("}")) {
|
||||||
|
try {
|
||||||
|
let jsonChunk = JSON.parse(rawChunk);
|
||||||
|
if (!jsonChunk.type)
|
||||||
|
jsonChunk = {type: 'message', data: jsonChunk};
|
||||||
|
return jsonChunk;
|
||||||
|
} catch (e) {
|
||||||
|
return {type: 'message', data: rawChunk};
|
||||||
|
}
|
||||||
|
} else if (rawChunk.length > 0) {
|
||||||
|
return {type: 'message', data: rawChunk};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function processMessageChunk(rawChunk) {
|
||||||
|
const chunk = convertMessageChunkToJson(rawChunk);
|
||||||
|
console.debug("Json Event:", chunk);
|
||||||
|
if (!chunk || !chunk.type) return;
|
||||||
|
if (chunk.type ==='status') {
|
||||||
|
console.log(`status: ${chunk.data}`);
|
||||||
|
const statusMessage = chunk.data;
|
||||||
|
handleStreamResponse(chatMessageState.newResponseTextEl, statusMessage, chatMessageState.rawQuery, chatMessageState.loadingEllipsis, false);
|
||||||
|
} else if (chunk.type === 'start_llm_response') {
|
||||||
|
console.log("Started streaming", new Date());
|
||||||
|
} else if (chunk.type === 'end_llm_response') {
|
||||||
|
console.log("Stopped streaming", new Date());
|
||||||
|
|
||||||
|
// Automatically respond with voice if the subscribed user has sent voice message
|
||||||
|
if (chatMessageState.isVoice && "{{ is_active }}" == "True")
|
||||||
|
textToSpeech(chatMessageState.rawResponse);
|
||||||
|
|
||||||
|
// Append any references after all the data has been streamed
|
||||||
|
finalizeChatBodyResponse(chatMessageState.references, chatMessageState.newResponseTextEl);
|
||||||
|
|
||||||
|
const liveQuery = chatMessageState.rawQuery;
|
||||||
|
// Reset variables
|
||||||
|
chatMessageState = {
|
||||||
|
newResponseTextEl: null,
|
||||||
|
newResponseEl: null,
|
||||||
|
loadingEllipsis: null,
|
||||||
|
references: {},
|
||||||
|
rawResponse: "",
|
||||||
|
rawQuery: liveQuery,
|
||||||
|
isVoice: false,
|
||||||
|
}
|
||||||
|
} else if (chunk.type === "references") {
|
||||||
|
chatMessageState.references = {"notes": chunk.data.context, "online": chunk.data.onlineContext};
|
||||||
|
} else if (chunk.type === 'message') {
|
||||||
|
const chunkData = chunk.data;
|
||||||
|
if (typeof chunkData === 'object' && chunkData !== null) {
|
||||||
|
// If chunkData is already a JSON object
|
||||||
|
handleJsonResponse(chunkData);
|
||||||
|
} else if (typeof chunkData === 'string' && chunkData.trim()?.startsWith("{") && chunkData.trim()?.endsWith("}")) {
|
||||||
|
// Try process chunk data as if it is a JSON object
|
||||||
|
try {
|
||||||
|
const jsonData = JSON.parse(chunkData.trim());
|
||||||
|
handleJsonResponse(jsonData);
|
||||||
|
} catch (e) {
|
||||||
|
chatMessageState.rawResponse += chunkData;
|
||||||
|
handleStreamResponse(chatMessageState.newResponseTextEl, chatMessageState.rawResponse, chatMessageState.rawQuery, chatMessageState.loadingEllipsis);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
chatMessageState.rawResponse += chunkData;
|
||||||
|
handleStreamResponse(chatMessageState.newResponseTextEl, chatMessageState.rawResponse, chatMessageState.rawQuery, chatMessageState.loadingEllipsis);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function handleJsonResponse(jsonData) {
|
||||||
|
if (jsonData.image || jsonData.detail) {
|
||||||
|
chatMessageState.rawResponse = handleImageResponse(jsonData, chatMessageState.rawResponse);
|
||||||
|
} else if (jsonData.response) {
|
||||||
|
chatMessageState.rawResponse = jsonData.response;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (chatMessageState.newResponseTextEl) {
|
||||||
|
chatMessageState.newResponseTextEl.innerHTML = "";
|
||||||
|
chatMessageState.newResponseTextEl.appendChild(formatHTMLMessage(chatMessageState.rawResponse));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function readChatStream(response) {
|
||||||
|
if (!response.body) return;
|
||||||
|
const reader = response.body.getReader();
|
||||||
|
const decoder = new TextDecoder();
|
||||||
|
const eventDelimiter = '␃🔚␗';
|
||||||
|
let buffer = '';
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
const { value, done } = await reader.read();
|
||||||
|
// If the stream is done
|
||||||
|
if (done) {
|
||||||
|
// Process the last chunk
|
||||||
|
processMessageChunk(buffer);
|
||||||
|
buffer = '';
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read chunk from stream and append it to the buffer
|
||||||
|
const chunk = decoder.decode(value, { stream: true });
|
||||||
|
console.debug("Raw Chunk:", chunk)
|
||||||
|
// Start buffering chunks until complete event is received
|
||||||
|
buffer += chunk;
|
||||||
|
|
||||||
|
// Once the buffer contains a complete event
|
||||||
|
let newEventIndex;
|
||||||
|
while ((newEventIndex = buffer.indexOf(eventDelimiter)) !== -1) {
|
||||||
|
// Extract the event from the buffer
|
||||||
|
const event = buffer.slice(0, newEventIndex);
|
||||||
|
buffer = buffer.slice(newEventIndex + eventDelimiter.length);
|
||||||
|
|
||||||
|
// Process the event
|
||||||
|
if (event) processMessageChunk(event);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function incrementalChat(event) {
|
function incrementalChat(event) {
|
||||||
|
@ -1069,17 +1129,13 @@ To get started, just start typing below. You can also type / to see a list of co
|
||||||
|
|
||||||
window.onload = loadChat;
|
window.onload = loadChat;
|
||||||
|
|
||||||
function setupWebSocket(isVoice=false) {
|
function initMessageState(isVoice=false) {
|
||||||
let chatBody = document.getElementById("chat-body");
|
|
||||||
let wsProtocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
|
|
||||||
let webSocketUrl = `${wsProtocol}//${window.location.host}/api/chat/ws`;
|
|
||||||
|
|
||||||
if (waitingForLocation) {
|
if (waitingForLocation) {
|
||||||
console.debug("Waiting for location data to be fetched. Will setup WebSocket once location data is available.");
|
console.debug("Waiting for location data to be fetched. Will setup WebSocket once location data is available.");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
websocketState = {
|
chatMessageState = {
|
||||||
newResponseTextEl: null,
|
newResponseTextEl: null,
|
||||||
newResponseEl: null,
|
newResponseEl: null,
|
||||||
loadingEllipsis: null,
|
loadingEllipsis: null,
|
||||||
|
@ -1088,174 +1144,8 @@ To get started, just start typing below. You can also type / to see a list of co
|
||||||
rawQuery: "",
|
rawQuery: "",
|
||||||
isVoice: isVoice,
|
isVoice: isVoice,
|
||||||
}
|
}
|
||||||
|
|
||||||
if (chatBody.dataset.conversationId) {
|
|
||||||
webSocketUrl += `?conversation_id=${chatBody.dataset.conversationId}`;
|
|
||||||
webSocketUrl += (!!region && !!city && !!countryName) && !!timezone ? `®ion=${region}&city=${city}&country=${countryName}&timezone=${timezone}` : '';
|
|
||||||
|
|
||||||
websocket = new WebSocket(webSocketUrl);
|
|
||||||
websocket.onmessage = function(event) {
|
|
||||||
|
|
||||||
// Get the last element in the chat-body
|
|
||||||
let chunk = event.data;
|
|
||||||
if (chunk == "start_llm_response") {
|
|
||||||
console.log("Started streaming", new Date());
|
|
||||||
} else if (chunk == "end_llm_response") {
|
|
||||||
console.log("Stopped streaming", new Date());
|
|
||||||
|
|
||||||
// Automatically respond with voice if the subscribed user has sent voice message
|
|
||||||
if (websocketState.isVoice && "{{ is_active }}" == "True")
|
|
||||||
textToSpeech(websocketState.rawResponse);
|
|
||||||
|
|
||||||
// Append any references after all the data has been streamed
|
|
||||||
finalizeChatBodyResponse(websocketState.references, websocketState.newResponseTextEl);
|
|
||||||
|
|
||||||
const liveQuery = websocketState.rawQuery;
|
|
||||||
// Reset variables
|
|
||||||
websocketState = {
|
|
||||||
newResponseTextEl: null,
|
|
||||||
newResponseEl: null,
|
|
||||||
loadingEllipsis: null,
|
|
||||||
references: {},
|
|
||||||
rawResponse: "",
|
|
||||||
rawQuery: liveQuery,
|
|
||||||
isVoice: false,
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
try {
|
|
||||||
if (chunk.includes("application/json"))
|
|
||||||
{
|
|
||||||
chunk = JSON.parse(chunk);
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
// If the chunk is not a JSON object, continue.
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const contentType = chunk["content-type"]
|
|
||||||
|
|
||||||
if (contentType === "application/json") {
|
|
||||||
// Handle JSON response
|
|
||||||
try {
|
|
||||||
if (chunk.image || chunk.detail) {
|
|
||||||
({rawResponse, references } = handleImageResponse(chunk, websocketState.rawResponse));
|
|
||||||
websocketState.rawResponse = rawResponse;
|
|
||||||
websocketState.references = references;
|
|
||||||
} else if (chunk.type == "status") {
|
|
||||||
handleStreamResponse(websocketState.newResponseTextEl, chunk.message, websocketState.rawQuery, null, false);
|
|
||||||
} else if (chunk.type == "rate_limit") {
|
|
||||||
handleStreamResponse(websocketState.newResponseTextEl, chunk.message, websocketState.rawQuery, websocketState.loadingEllipsis, true);
|
|
||||||
} else {
|
|
||||||
rawResponse = chunk.response;
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
// If the chunk is not a JSON object, just display it as is
|
|
||||||
websocketState.rawResponse += chunk;
|
|
||||||
} finally {
|
|
||||||
if (chunk.type != "status" && chunk.type != "rate_limit") {
|
|
||||||
addMessageToChatBody(websocketState.rawResponse, websocketState.newResponseTextEl, websocketState.references);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
|
|
||||||
// Handle streamed response of type text/event-stream or text/plain
|
|
||||||
if (chunk && chunk.includes("### compiled references:")) {
|
|
||||||
({ rawResponse, references } = handleCompiledReferences(websocketState.newResponseTextEl, chunk, websocketState.references, websocketState.rawResponse));
|
|
||||||
websocketState.rawResponse = rawResponse;
|
|
||||||
websocketState.references = references;
|
|
||||||
} else {
|
|
||||||
// If the chunk is not a JSON object, just display it as is
|
|
||||||
websocketState.rawResponse += chunk;
|
|
||||||
if (websocketState.newResponseTextEl) {
|
|
||||||
handleStreamResponse(websocketState.newResponseTextEl, websocketState.rawResponse, websocketState.rawQuery, websocketState.loadingEllipsis);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Scroll to bottom of chat window as chat response is streamed
|
|
||||||
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
|
||||||
};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
websocket.onclose = function(event) {
|
|
||||||
websocket = null;
|
|
||||||
console.log("WebSocket is closed now.");
|
|
||||||
let setupWebSocketButton = document.createElement("button");
|
|
||||||
setupWebSocketButton.textContent = "Reconnect to Server";
|
|
||||||
setupWebSocketButton.onclick = setupWebSocket;
|
|
||||||
let statusDotIcon = document.getElementById("connection-status-icon");
|
|
||||||
statusDotIcon.style.backgroundColor = "red";
|
|
||||||
let statusDotText = document.getElementById("connection-status-text");
|
|
||||||
statusDotText.innerHTML = "";
|
|
||||||
statusDotText.style.marginTop = "5px";
|
|
||||||
statusDotText.appendChild(setupWebSocketButton);
|
|
||||||
}
|
|
||||||
websocket.onerror = function(event) {
|
|
||||||
console.log("WebSocket error observed:", event);
|
|
||||||
}
|
|
||||||
|
|
||||||
websocket.onopen = function(event) {
|
|
||||||
console.log("WebSocket is open now.")
|
|
||||||
let statusDotIcon = document.getElementById("connection-status-icon");
|
|
||||||
statusDotIcon.style.backgroundColor = "green";
|
|
||||||
let statusDotText = document.getElementById("connection-status-text");
|
|
||||||
statusDotText.textContent = "Connected to Server";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
function sendMessageViaWebSocket(isVoice=false) {
|
|
||||||
let chatBody = document.getElementById("chat-body");
|
|
||||||
|
|
||||||
var query = document.getElementById("chat-input").value.trim();
|
|
||||||
console.log(`Query: ${query}`);
|
|
||||||
|
|
||||||
if (userMessages.length >= 10) {
|
|
||||||
userMessages.shift();
|
|
||||||
}
|
|
||||||
userMessages.push(query);
|
|
||||||
resetUserMessageIndex();
|
|
||||||
|
|
||||||
// Add message by user to chat body
|
|
||||||
renderMessage(query, "you");
|
|
||||||
document.getElementById("chat-input").value = "";
|
|
||||||
autoResize();
|
|
||||||
document.getElementById("chat-input").setAttribute("disabled", "disabled");
|
|
||||||
|
|
||||||
let newResponseEl = document.createElement("div");
|
|
||||||
newResponseEl.classList.add("chat-message", "khoj");
|
|
||||||
newResponseEl.attributes["data-meta"] = "🏮 Khoj at " + formatDate(new Date());
|
|
||||||
chatBody.appendChild(newResponseEl);
|
|
||||||
|
|
||||||
let newResponseTextEl = document.createElement("div");
|
|
||||||
newResponseTextEl.classList.add("chat-message-text", "khoj");
|
|
||||||
newResponseEl.appendChild(newResponseTextEl);
|
|
||||||
|
|
||||||
// Temporary status message to indicate that Khoj is thinking
|
|
||||||
let loadingEllipsis = createLoadingEllipse();
|
|
||||||
|
|
||||||
newResponseTextEl.appendChild(loadingEllipsis);
|
|
||||||
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
|
||||||
|
|
||||||
let chatTooltip = document.getElementById("chat-tooltip");
|
|
||||||
chatTooltip.style.display = "none";
|
|
||||||
|
|
||||||
let chatInput = document.getElementById("chat-input");
|
|
||||||
chatInput.classList.remove("option-enabled");
|
|
||||||
|
|
||||||
// Call specified Khoj API
|
|
||||||
websocket.send(query);
|
|
||||||
let rawResponse = "";
|
|
||||||
let references = {};
|
|
||||||
|
|
||||||
websocketState = {
|
|
||||||
newResponseTextEl,
|
|
||||||
newResponseEl,
|
|
||||||
loadingEllipsis,
|
|
||||||
references,
|
|
||||||
rawResponse,
|
|
||||||
rawQuery: query,
|
|
||||||
isVoice: isVoice,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
var userMessages = [];
|
var userMessages = [];
|
||||||
var userMessageIndex = -1;
|
var userMessageIndex = -1;
|
||||||
function loadChat() {
|
function loadChat() {
|
||||||
|
@ -1265,7 +1155,7 @@ To get started, just start typing below. You can also type / to see a list of co
|
||||||
let chatHistoryUrl = `/api/chat/history?client=web`;
|
let chatHistoryUrl = `/api/chat/history?client=web`;
|
||||||
if (chatBody.dataset.conversationId) {
|
if (chatBody.dataset.conversationId) {
|
||||||
chatHistoryUrl += `&conversation_id=${chatBody.dataset.conversationId}`;
|
chatHistoryUrl += `&conversation_id=${chatBody.dataset.conversationId}`;
|
||||||
setupWebSocket();
|
initMessageState();
|
||||||
loadFileFiltersFromConversation();
|
loadFileFiltersFromConversation();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1305,7 +1195,7 @@ To get started, just start typing below. You can also type / to see a list of co
|
||||||
let chatBody = document.getElementById("chat-body");
|
let chatBody = document.getElementById("chat-body");
|
||||||
chatBody.dataset.conversationId = response.conversation_id;
|
chatBody.dataset.conversationId = response.conversation_id;
|
||||||
loadFileFiltersFromConversation();
|
loadFileFiltersFromConversation();
|
||||||
setupWebSocket();
|
initMessageState();
|
||||||
chatBody.dataset.conversationTitle = response.slug || `New conversation 🌱`;
|
chatBody.dataset.conversationTitle = response.slug || `New conversation 🌱`;
|
||||||
|
|
||||||
let agentMetadata = response.agent;
|
let agentMetadata = response.agent;
|
||||||
|
|
|
@ -62,10 +62,6 @@ class ThreadedGenerator:
|
||||||
self.queue.put(data)
|
self.queue.put(data)
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
if self.compiled_references and len(self.compiled_references) > 0:
|
|
||||||
self.queue.put(f"### compiled references:{json.dumps(self.compiled_references)}")
|
|
||||||
if self.online_results and len(self.online_results) > 0:
|
|
||||||
self.queue.put(f"### compiled references:{json.dumps(self.online_results)}")
|
|
||||||
self.queue.put(StopIteration)
|
self.queue.put(StopIteration)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -11,6 +11,7 @@ from bs4 import BeautifulSoup
|
||||||
from markdownify import markdownify
|
from markdownify import markdownify
|
||||||
|
|
||||||
from khoj.routers.helpers import (
|
from khoj.routers.helpers import (
|
||||||
|
ChatEvent,
|
||||||
extract_relevant_info,
|
extract_relevant_info,
|
||||||
generate_online_subqueries,
|
generate_online_subqueries,
|
||||||
infer_webpage_urls,
|
infer_webpage_urls,
|
||||||
|
@ -56,7 +57,8 @@ async def search_online(
|
||||||
query += " ".join(custom_filters)
|
query += " ".join(custom_filters)
|
||||||
if not is_internet_connected():
|
if not is_internet_connected():
|
||||||
logger.warn("Cannot search online as not connected to internet")
|
logger.warn("Cannot search online as not connected to internet")
|
||||||
return {}
|
yield {}
|
||||||
|
return
|
||||||
|
|
||||||
# Breakdown the query into subqueries to get the correct answer
|
# Breakdown the query into subqueries to get the correct answer
|
||||||
subqueries = await generate_online_subqueries(query, conversation_history, location)
|
subqueries = await generate_online_subqueries(query, conversation_history, location)
|
||||||
|
@ -66,7 +68,8 @@ async def search_online(
|
||||||
logger.info(f"🌐 Searching the Internet for {list(subqueries)}")
|
logger.info(f"🌐 Searching the Internet for {list(subqueries)}")
|
||||||
if send_status_func:
|
if send_status_func:
|
||||||
subqueries_str = "\n- " + "\n- ".join(list(subqueries))
|
subqueries_str = "\n- " + "\n- ".join(list(subqueries))
|
||||||
await send_status_func(f"**🌐 Searching the Internet for**: {subqueries_str}")
|
async for event in send_status_func(f"**🌐 Searching the Internet for**: {subqueries_str}"):
|
||||||
|
yield {ChatEvent.STATUS: event}
|
||||||
|
|
||||||
with timer(f"Internet searches for {list(subqueries)} took", logger):
|
with timer(f"Internet searches for {list(subqueries)} took", logger):
|
||||||
search_func = search_with_google if SERPER_DEV_API_KEY else search_with_jina
|
search_func = search_with_google if SERPER_DEV_API_KEY else search_with_jina
|
||||||
|
@ -89,7 +92,8 @@ async def search_online(
|
||||||
logger.info(f"🌐👀 Reading web pages at: {list(webpage_links)}")
|
logger.info(f"🌐👀 Reading web pages at: {list(webpage_links)}")
|
||||||
if send_status_func:
|
if send_status_func:
|
||||||
webpage_links_str = "\n- " + "\n- ".join(list(webpage_links))
|
webpage_links_str = "\n- " + "\n- ".join(list(webpage_links))
|
||||||
await send_status_func(f"**📖 Reading web pages**: {webpage_links_str}")
|
async for event in send_status_func(f"**📖 Reading web pages**: {webpage_links_str}"):
|
||||||
|
yield {ChatEvent.STATUS: event}
|
||||||
tasks = [read_webpage_and_extract_content(subquery, link, content) for link, subquery, content in webpages]
|
tasks = [read_webpage_and_extract_content(subquery, link, content) for link, subquery, content in webpages]
|
||||||
results = await asyncio.gather(*tasks)
|
results = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
@ -98,7 +102,7 @@ async def search_online(
|
||||||
if webpage_extract is not None:
|
if webpage_extract is not None:
|
||||||
response_dict[subquery]["webpages"] = {"link": url, "snippet": webpage_extract}
|
response_dict[subquery]["webpages"] = {"link": url, "snippet": webpage_extract}
|
||||||
|
|
||||||
return response_dict
|
yield response_dict
|
||||||
|
|
||||||
|
|
||||||
async def search_with_google(query: str) -> Tuple[str, Dict[str, List[Dict]]]:
|
async def search_with_google(query: str) -> Tuple[str, Dict[str, List[Dict]]]:
|
||||||
|
@ -127,13 +131,15 @@ async def read_webpages(
|
||||||
"Infer web pages to read from the query and extract relevant information from them"
|
"Infer web pages to read from the query and extract relevant information from them"
|
||||||
logger.info(f"Inferring web pages to read")
|
logger.info(f"Inferring web pages to read")
|
||||||
if send_status_func:
|
if send_status_func:
|
||||||
await send_status_func(f"**🧐 Inferring web pages to read**")
|
async for event in send_status_func(f"**🧐 Inferring web pages to read**"):
|
||||||
|
yield {ChatEvent.STATUS: event}
|
||||||
urls = await infer_webpage_urls(query, conversation_history, location)
|
urls = await infer_webpage_urls(query, conversation_history, location)
|
||||||
|
|
||||||
logger.info(f"Reading web pages at: {urls}")
|
logger.info(f"Reading web pages at: {urls}")
|
||||||
if send_status_func:
|
if send_status_func:
|
||||||
webpage_links_str = "\n- " + "\n- ".join(list(urls))
|
webpage_links_str = "\n- " + "\n- ".join(list(urls))
|
||||||
await send_status_func(f"**📖 Reading web pages**: {webpage_links_str}")
|
async for event in send_status_func(f"**📖 Reading web pages**: {webpage_links_str}"):
|
||||||
|
yield {ChatEvent.STATUS: event}
|
||||||
tasks = [read_webpage_and_extract_content(query, url) for url in urls]
|
tasks = [read_webpage_and_extract_content(query, url) for url in urls]
|
||||||
results = await asyncio.gather(*tasks)
|
results = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
@ -141,7 +147,7 @@ async def read_webpages(
|
||||||
response[query]["webpages"] = [
|
response[query]["webpages"] = [
|
||||||
{"query": q, "link": url, "snippet": web_extract} for q, web_extract, url in results if web_extract is not None
|
{"query": q, "link": url, "snippet": web_extract} for q, web_extract, url in results if web_extract is not None
|
||||||
]
|
]
|
||||||
return response
|
yield response
|
||||||
|
|
||||||
|
|
||||||
async def read_webpage_and_extract_content(
|
async def read_webpage_and_extract_content(
|
||||||
|
|
|
@ -6,7 +6,6 @@ import os
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from random import random
|
|
||||||
from typing import Any, Callable, List, Optional, Union
|
from typing import Any, Callable, List, Optional, Union
|
||||||
|
|
||||||
import cron_descriptor
|
import cron_descriptor
|
||||||
|
@ -37,6 +36,7 @@ from khoj.processor.conversation.openai.gpt import extract_questions
|
||||||
from khoj.processor.conversation.openai.whisper import transcribe_audio
|
from khoj.processor.conversation.openai.whisper import transcribe_audio
|
||||||
from khoj.routers.helpers import (
|
from khoj.routers.helpers import (
|
||||||
ApiUserRateLimiter,
|
ApiUserRateLimiter,
|
||||||
|
ChatEvent,
|
||||||
CommonQueryParams,
|
CommonQueryParams,
|
||||||
ConversationCommandRateLimiter,
|
ConversationCommandRateLimiter,
|
||||||
acreate_title_from_query,
|
acreate_title_from_query,
|
||||||
|
@ -298,11 +298,13 @@ async def extract_references_and_questions(
|
||||||
not ConversationCommand.Notes in conversation_commands
|
not ConversationCommand.Notes in conversation_commands
|
||||||
and not ConversationCommand.Default in conversation_commands
|
and not ConversationCommand.Default in conversation_commands
|
||||||
):
|
):
|
||||||
return compiled_references, inferred_queries, q
|
yield compiled_references, inferred_queries, q
|
||||||
|
return
|
||||||
|
|
||||||
if not await sync_to_async(EntryAdapters.user_has_entries)(user=user):
|
if not await sync_to_async(EntryAdapters.user_has_entries)(user=user):
|
||||||
logger.debug("No documents in knowledge base. Use a Khoj client to sync and chat with your docs.")
|
logger.debug("No documents in knowledge base. Use a Khoj client to sync and chat with your docs.")
|
||||||
return compiled_references, inferred_queries, q
|
yield compiled_references, inferred_queries, q
|
||||||
|
return
|
||||||
|
|
||||||
# Extract filter terms from user message
|
# Extract filter terms from user message
|
||||||
defiltered_query = q
|
defiltered_query = q
|
||||||
|
@ -313,7 +315,8 @@ async def extract_references_and_questions(
|
||||||
|
|
||||||
if not conversation:
|
if not conversation:
|
||||||
logger.error(f"Conversation with id {conversation_id} not found.")
|
logger.error(f"Conversation with id {conversation_id} not found.")
|
||||||
return compiled_references, inferred_queries, defiltered_query
|
yield compiled_references, inferred_queries, defiltered_query
|
||||||
|
return
|
||||||
|
|
||||||
filters_in_query += " ".join([f'file:"{filter}"' for filter in conversation.file_filters])
|
filters_in_query += " ".join([f'file:"{filter}"' for filter in conversation.file_filters])
|
||||||
using_offline_chat = False
|
using_offline_chat = False
|
||||||
|
@ -373,7 +376,8 @@ async def extract_references_and_questions(
|
||||||
logger.info(f"🔍 Searching knowledge base with queries: {inferred_queries}")
|
logger.info(f"🔍 Searching knowledge base with queries: {inferred_queries}")
|
||||||
if send_status_func:
|
if send_status_func:
|
||||||
inferred_queries_str = "\n- " + "\n- ".join(inferred_queries)
|
inferred_queries_str = "\n- " + "\n- ".join(inferred_queries)
|
||||||
await send_status_func(f"**🔍 Searching Documents for:** {inferred_queries_str}")
|
async for event in send_status_func(f"**🔍 Searching Documents for:** {inferred_queries_str}"):
|
||||||
|
yield {ChatEvent.STATUS: event}
|
||||||
for query in inferred_queries:
|
for query in inferred_queries:
|
||||||
n_items = min(n, 3) if using_offline_chat else n
|
n_items = min(n, 3) if using_offline_chat else n
|
||||||
search_results.extend(
|
search_results.extend(
|
||||||
|
@ -392,7 +396,7 @@ async def extract_references_and_questions(
|
||||||
{"compiled": item.additional["compiled"], "file": item.additional["file"]} for item in search_results
|
{"compiled": item.additional["compiled"], "file": item.additional["file"]} for item in search_results
|
||||||
]
|
]
|
||||||
|
|
||||||
return compiled_references, inferred_queries, defiltered_query
|
yield compiled_references, inferred_queries, defiltered_query
|
||||||
|
|
||||||
|
|
||||||
@api.get("/health", response_class=Response)
|
@api.get("/health", response_class=Response)
|
||||||
|
|
|
@ -1,17 +1,17 @@
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import math
|
import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from functools import partial
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
from urllib.parse import unquote
|
from urllib.parse import unquote
|
||||||
|
|
||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request, WebSocket
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||||
from fastapi.requests import Request
|
from fastapi.requests import Request
|
||||||
from fastapi.responses import Response, StreamingResponse
|
from fastapi.responses import Response, StreamingResponse
|
||||||
from starlette.authentication import requires
|
from starlette.authentication import requires
|
||||||
from starlette.websockets import WebSocketDisconnect
|
|
||||||
from websockets import ConnectionClosedOK
|
|
||||||
|
|
||||||
from khoj.app.settings import ALLOWED_HOSTS
|
from khoj.app.settings import ALLOWED_HOSTS
|
||||||
from khoj.database.adapters import (
|
from khoj.database.adapters import (
|
||||||
|
@ -23,19 +23,15 @@ from khoj.database.adapters import (
|
||||||
aget_user_name,
|
aget_user_name,
|
||||||
)
|
)
|
||||||
from khoj.database.models import KhojUser
|
from khoj.database.models import KhojUser
|
||||||
from khoj.processor.conversation.prompts import (
|
from khoj.processor.conversation.prompts import help_message, no_entries_found
|
||||||
help_message,
|
|
||||||
no_entries_found,
|
|
||||||
no_notes_found,
|
|
||||||
)
|
|
||||||
from khoj.processor.conversation.utils import save_to_conversation_log
|
from khoj.processor.conversation.utils import save_to_conversation_log
|
||||||
from khoj.processor.speech.text_to_speech import generate_text_to_speech
|
from khoj.processor.speech.text_to_speech import generate_text_to_speech
|
||||||
from khoj.processor.tools.online_search import read_webpages, search_online
|
from khoj.processor.tools.online_search import read_webpages, search_online
|
||||||
from khoj.routers.api import extract_references_and_questions
|
from khoj.routers.api import extract_references_and_questions
|
||||||
from khoj.routers.helpers import (
|
from khoj.routers.helpers import (
|
||||||
ApiUserRateLimiter,
|
ApiUserRateLimiter,
|
||||||
|
ChatEvent,
|
||||||
CommonQueryParams,
|
CommonQueryParams,
|
||||||
CommonQueryParamsClass,
|
|
||||||
ConversationCommandRateLimiter,
|
ConversationCommandRateLimiter,
|
||||||
agenerate_chat_response,
|
agenerate_chat_response,
|
||||||
aget_relevant_information_sources,
|
aget_relevant_information_sources,
|
||||||
|
@ -526,141 +522,142 @@ async def set_conversation_title(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@api_chat.websocket("/ws")
|
@api_chat.get("")
|
||||||
async def websocket_endpoint(
|
async def chat(
|
||||||
websocket: WebSocket,
|
request: Request,
|
||||||
conversation_id: int,
|
common: CommonQueryParams,
|
||||||
|
q: str,
|
||||||
|
n: int = 7,
|
||||||
|
d: float = 0.18,
|
||||||
|
stream: Optional[bool] = False,
|
||||||
|
title: Optional[str] = None,
|
||||||
|
conversation_id: Optional[int] = None,
|
||||||
city: Optional[str] = None,
|
city: Optional[str] = None,
|
||||||
region: Optional[str] = None,
|
region: Optional[str] = None,
|
||||||
country: Optional[str] = None,
|
country: Optional[str] = None,
|
||||||
timezone: Optional[str] = None,
|
timezone: Optional[str] = None,
|
||||||
|
rate_limiter_per_minute=Depends(
|
||||||
|
ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute")
|
||||||
|
),
|
||||||
|
rate_limiter_per_day=Depends(
|
||||||
|
ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
|
||||||
|
),
|
||||||
):
|
):
|
||||||
|
async def event_generator(q: str):
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
ttft = None
|
||||||
|
chat_metadata: dict = {}
|
||||||
connection_alive = True
|
connection_alive = True
|
||||||
|
user: KhojUser = request.user.object
|
||||||
|
event_delimiter = "␃🔚␗"
|
||||||
|
q = unquote(q)
|
||||||
|
|
||||||
async def send_status_update(message: str):
|
async def send_event(event_type: ChatEvent, data: str | dict):
|
||||||
nonlocal connection_alive
|
nonlocal connection_alive, ttft
|
||||||
if not connection_alive:
|
if not connection_alive or await request.is_disconnected():
|
||||||
return
|
|
||||||
|
|
||||||
status_packet = {
|
|
||||||
"type": "status",
|
|
||||||
"message": message,
|
|
||||||
"content-type": "application/json",
|
|
||||||
}
|
|
||||||
try:
|
|
||||||
await websocket.send_text(json.dumps(status_packet))
|
|
||||||
except ConnectionClosedOK:
|
|
||||||
connection_alive = False
|
connection_alive = False
|
||||||
logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")
|
logger.warn(f"User {user} disconnected from {common.client} client")
|
||||||
|
|
||||||
async def send_complete_llm_response(llm_response: str):
|
|
||||||
nonlocal connection_alive
|
|
||||||
if not connection_alive:
|
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
await websocket.send_text("start_llm_response")
|
if event_type == ChatEvent.END_LLM_RESPONSE:
|
||||||
await websocket.send_text(llm_response)
|
collect_telemetry()
|
||||||
await websocket.send_text("end_llm_response")
|
if event_type == ChatEvent.START_LLM_RESPONSE:
|
||||||
except ConnectionClosedOK:
|
ttft = time.perf_counter() - start_time
|
||||||
|
if event_type == ChatEvent.MESSAGE:
|
||||||
|
yield data
|
||||||
|
elif event_type == ChatEvent.REFERENCES or stream:
|
||||||
|
yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False)
|
||||||
|
except asyncio.CancelledError as e:
|
||||||
connection_alive = False
|
connection_alive = False
|
||||||
logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")
|
logger.warn(f"User {user} disconnected from {common.client} client: {e}")
|
||||||
|
|
||||||
async def send_message(message: str):
|
|
||||||
nonlocal connection_alive
|
|
||||||
if not connection_alive:
|
|
||||||
return
|
return
|
||||||
try:
|
except Exception as e:
|
||||||
await websocket.send_text(message)
|
|
||||||
except ConnectionClosedOK:
|
|
||||||
connection_alive = False
|
connection_alive = False
|
||||||
logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")
|
logger.error(f"Failed to stream chat API response to {user} on {common.client}: {e}", exc_info=True)
|
||||||
|
|
||||||
async def send_rate_limit_message(message: str):
|
|
||||||
nonlocal connection_alive
|
|
||||||
if not connection_alive:
|
|
||||||
return
|
return
|
||||||
|
finally:
|
||||||
|
if stream:
|
||||||
|
yield event_delimiter
|
||||||
|
|
||||||
status_packet = {
|
async def send_llm_response(response: str):
|
||||||
"type": "rate_limit",
|
async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
|
||||||
"message": message,
|
yield result
|
||||||
"content-type": "application/json",
|
async for result in send_event(ChatEvent.MESSAGE, response):
|
||||||
}
|
yield result
|
||||||
try:
|
async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
|
||||||
await websocket.send_text(json.dumps(status_packet))
|
yield result
|
||||||
except ConnectionClosedOK:
|
|
||||||
connection_alive = False
|
|
||||||
logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")
|
|
||||||
|
|
||||||
user: KhojUser = websocket.user.object
|
def collect_telemetry():
|
||||||
conversation = await ConversationAdapters.aget_conversation_by_user(
|
# Gather chat response telemetry
|
||||||
user, client_application=websocket.user.client_app, conversation_id=conversation_id
|
nonlocal chat_metadata
|
||||||
|
latency = time.perf_counter() - start_time
|
||||||
|
cmd_set = set([cmd.value for cmd in conversation_commands])
|
||||||
|
chat_metadata = chat_metadata or {}
|
||||||
|
chat_metadata["conversation_command"] = cmd_set
|
||||||
|
chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None
|
||||||
|
chat_metadata["latency"] = f"{latency:.3f}"
|
||||||
|
chat_metadata["ttft_latency"] = f"{ttft:.3f}"
|
||||||
|
|
||||||
|
logger.info(f"Chat response time to first token: {ttft:.3f} seconds")
|
||||||
|
logger.info(f"Chat response total time: {latency:.3f} seconds")
|
||||||
|
update_telemetry_state(
|
||||||
|
request=request,
|
||||||
|
telemetry_type="api",
|
||||||
|
api="chat",
|
||||||
|
client=request.user.client_app,
|
||||||
|
user_agent=request.headers.get("user-agent"),
|
||||||
|
host=request.headers.get("host"),
|
||||||
|
metadata=chat_metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
hourly_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute")
|
conversation = await ConversationAdapters.aget_conversation_by_user(
|
||||||
|
user, client_application=request.user.client_app, conversation_id=conversation_id, title=title
|
||||||
daily_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
|
)
|
||||||
|
if not conversation:
|
||||||
|
async for result in send_llm_response(f"Conversation {conversation_id} not found"):
|
||||||
|
yield result
|
||||||
|
return
|
||||||
|
|
||||||
await is_ready_to_chat(user)
|
await is_ready_to_chat(user)
|
||||||
|
|
||||||
user_name = await aget_user_name(user)
|
user_name = await aget_user_name(user)
|
||||||
|
|
||||||
location = None
|
location = None
|
||||||
|
|
||||||
if city or region or country:
|
if city or region or country:
|
||||||
location = LocationData(city=city, region=region, country=country)
|
location = LocationData(city=city, region=region, country=country)
|
||||||
|
|
||||||
await websocket.accept()
|
|
||||||
while connection_alive:
|
|
||||||
try:
|
|
||||||
if conversation:
|
|
||||||
await sync_to_async(conversation.refresh_from_db)(fields=["conversation_log"])
|
|
||||||
q = await websocket.receive_text()
|
|
||||||
|
|
||||||
# Refresh these because the connection to the database might have been closed
|
|
||||||
await conversation.arefresh_from_db()
|
|
||||||
|
|
||||||
except WebSocketDisconnect:
|
|
||||||
logger.debug(f"User {user} disconnected web socket")
|
|
||||||
break
|
|
||||||
|
|
||||||
try:
|
|
||||||
await sync_to_async(hourly_limiter)(websocket)
|
|
||||||
await sync_to_async(daily_limiter)(websocket)
|
|
||||||
except HTTPException as e:
|
|
||||||
await send_rate_limit_message(e.detail)
|
|
||||||
break
|
|
||||||
|
|
||||||
if is_query_empty(q):
|
if is_query_empty(q):
|
||||||
await send_message("start_llm_response")
|
async for result in send_llm_response("Please ask your query to get started."):
|
||||||
await send_message(
|
yield result
|
||||||
"It seems like your query is incomplete. Could you please provide more details or specify what you need help with?"
|
return
|
||||||
)
|
|
||||||
await send_message("end_llm_response")
|
|
||||||
continue
|
|
||||||
|
|
||||||
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
conversation_commands = [get_conversation_command(query=q, any_references=True)]
|
conversation_commands = [get_conversation_command(query=q, any_references=True)]
|
||||||
|
|
||||||
await send_status_update(f"**👀 Understanding Query**: {q}")
|
async for result in send_event(ChatEvent.STATUS, f"**👀 Understanding Query**: {q}"):
|
||||||
|
yield result
|
||||||
|
|
||||||
meta_log = conversation.conversation_log
|
meta_log = conversation.conversation_log
|
||||||
is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask]
|
is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask]
|
||||||
used_slash_summarize = conversation_commands == [ConversationCommand.Summarize]
|
|
||||||
|
|
||||||
if conversation_commands == [ConversationCommand.Default] or is_automated_task:
|
if conversation_commands == [ConversationCommand.Default] or is_automated_task:
|
||||||
conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task)
|
conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task)
|
||||||
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
|
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
|
||||||
await send_status_update(f"**🗃️ Chose Data Sources to Search:** {conversation_commands_str}")
|
async for result in send_event(
|
||||||
|
ChatEvent.STATUS, f"**🗃️ Chose Data Sources to Search:** {conversation_commands_str}"
|
||||||
|
):
|
||||||
|
yield result
|
||||||
|
|
||||||
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task)
|
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task)
|
||||||
await send_status_update(f"**🧑🏾💻 Decided Response Mode:** {mode.value}")
|
async for result in send_event(ChatEvent.STATUS, f"**🧑🏾💻 Decided Response Mode:** {mode.value}"):
|
||||||
|
yield result
|
||||||
if mode not in conversation_commands:
|
if mode not in conversation_commands:
|
||||||
conversation_commands.append(mode)
|
conversation_commands.append(mode)
|
||||||
|
|
||||||
for cmd in conversation_commands:
|
for cmd in conversation_commands:
|
||||||
await conversation_command_rate_limiter.update_and_check_if_valid(websocket, cmd)
|
await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
|
||||||
q = q.replace(f"/{cmd.value}", "").strip()
|
q = q.replace(f"/{cmd.value}", "").strip()
|
||||||
|
|
||||||
|
used_slash_summarize = conversation_commands == [ConversationCommand.Summarize]
|
||||||
file_filters = conversation.file_filters if conversation else []
|
file_filters = conversation.file_filters if conversation else []
|
||||||
# Skip trying to summarize if
|
# Skip trying to summarize if
|
||||||
if (
|
if (
|
||||||
|
@ -676,28 +673,37 @@ async def websocket_endpoint(
|
||||||
response_log = ""
|
response_log = ""
|
||||||
if len(file_filters) == 0:
|
if len(file_filters) == 0:
|
||||||
response_log = "No files selected for summarization. Please add files using the section on the left."
|
response_log = "No files selected for summarization. Please add files using the section on the left."
|
||||||
await send_complete_llm_response(response_log)
|
async for result in send_llm_response(response_log):
|
||||||
|
yield result
|
||||||
elif len(file_filters) > 1:
|
elif len(file_filters) > 1:
|
||||||
response_log = "Only one file can be selected for summarization."
|
response_log = "Only one file can be selected for summarization."
|
||||||
await send_complete_llm_response(response_log)
|
async for result in send_llm_response(response_log):
|
||||||
|
yield result
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
|
file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
|
||||||
if len(file_object) == 0:
|
if len(file_object) == 0:
|
||||||
response_log = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again."
|
response_log = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again."
|
||||||
await send_complete_llm_response(response_log)
|
async for result in send_llm_response(response_log):
|
||||||
continue
|
yield result
|
||||||
|
return
|
||||||
contextual_data = " ".join([file.raw_text for file in file_object])
|
contextual_data = " ".join([file.raw_text for file in file_object])
|
||||||
if not q:
|
if not q:
|
||||||
q = "Create a general summary of the file"
|
q = "Create a general summary of the file"
|
||||||
await send_status_update(f"**🧑🏾💻 Constructing Summary Using:** {file_object[0].file_name}")
|
async for result in send_event(
|
||||||
|
ChatEvent.STATUS, f"**🧑🏾💻 Constructing Summary Using:** {file_object[0].file_name}"
|
||||||
|
):
|
||||||
|
yield result
|
||||||
|
|
||||||
response = await extract_relevant_summary(q, contextual_data)
|
response = await extract_relevant_summary(q, contextual_data)
|
||||||
response_log = str(response)
|
response_log = str(response)
|
||||||
await send_complete_llm_response(response_log)
|
async for result in send_llm_response(response_log):
|
||||||
|
yield result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
response_log = "Error summarizing file."
|
response_log = "Error summarizing file."
|
||||||
logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)
|
logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)
|
||||||
await send_complete_llm_response(response_log)
|
async for result in send_llm_response(response_log):
|
||||||
|
yield result
|
||||||
await sync_to_async(save_to_conversation_log)(
|
await sync_to_async(save_to_conversation_log)(
|
||||||
q,
|
q,
|
||||||
response_log,
|
response_log,
|
||||||
|
@ -705,16 +711,10 @@ async def websocket_endpoint(
|
||||||
meta_log,
|
meta_log,
|
||||||
user_message_time,
|
user_message_time,
|
||||||
intent_type="summarize",
|
intent_type="summarize",
|
||||||
client_application=websocket.user.client_app,
|
client_application=request.user.client_app,
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
)
|
)
|
||||||
update_telemetry_state(
|
return
|
||||||
request=websocket,
|
|
||||||
telemetry_type="api",
|
|
||||||
api="chat",
|
|
||||||
metadata={"conversation_command": conversation_commands[0].value},
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
custom_filters = []
|
custom_filters = []
|
||||||
if conversation_commands == [ConversationCommand.Help]:
|
if conversation_commands == [ConversationCommand.Help]:
|
||||||
|
@ -724,8 +724,9 @@ async def websocket_endpoint(
|
||||||
conversation_config = await ConversationAdapters.aget_default_conversation_config()
|
conversation_config = await ConversationAdapters.aget_default_conversation_config()
|
||||||
model_type = conversation_config.model_type
|
model_type = conversation_config.model_type
|
||||||
formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
|
formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
|
||||||
await send_complete_llm_response(formatted_help)
|
async for result in send_llm_response(formatted_help):
|
||||||
continue
|
yield result
|
||||||
|
return
|
||||||
# Adding specification to search online specifically on khoj.dev pages.
|
# Adding specification to search online specifically on khoj.dev pages.
|
||||||
custom_filters.append("site:khoj.dev")
|
custom_filters.append("site:khoj.dev")
|
||||||
conversation_commands.append(ConversationCommand.Online)
|
conversation_commands.append(ConversationCommand.Online)
|
||||||
|
@ -733,14 +734,14 @@ async def websocket_endpoint(
|
||||||
if ConversationCommand.Automation in conversation_commands:
|
if ConversationCommand.Automation in conversation_commands:
|
||||||
try:
|
try:
|
||||||
automation, crontime, query_to_run, subject = await create_automation(
|
automation, crontime, query_to_run, subject = await create_automation(
|
||||||
q, timezone, user, websocket.url, meta_log
|
q, timezone, user, request.url, meta_log
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error scheduling task {q} for {user.email}: {e}")
|
logger.error(f"Error scheduling task {q} for {user.email}: {e}")
|
||||||
await send_complete_llm_response(
|
error_message = f"Unable to create automation. Ensure the automation doesn't already exist."
|
||||||
f"Unable to create automation. Ensure the automation doesn't already exist."
|
async for result in send_llm_response(error_message):
|
||||||
)
|
yield result
|
||||||
continue
|
return
|
||||||
|
|
||||||
llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject)
|
llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject)
|
||||||
await sync_to_async(save_to_conversation_log)(
|
await sync_to_async(save_to_conversation_log)(
|
||||||
|
@ -750,57 +751,78 @@ async def websocket_endpoint(
|
||||||
meta_log,
|
meta_log,
|
||||||
user_message_time,
|
user_message_time,
|
||||||
intent_type="automation",
|
intent_type="automation",
|
||||||
client_application=websocket.user.client_app,
|
client_application=request.user.client_app,
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
inferred_queries=[query_to_run],
|
inferred_queries=[query_to_run],
|
||||||
automation_id=automation.id,
|
automation_id=automation.id,
|
||||||
)
|
)
|
||||||
common = CommonQueryParamsClass(
|
async for result in send_llm_response(llm_response):
|
||||||
client=websocket.user.client_app,
|
yield result
|
||||||
user_agent=websocket.headers.get("user-agent"),
|
return
|
||||||
host=websocket.headers.get("host"),
|
|
||||||
)
|
|
||||||
update_telemetry_state(
|
|
||||||
request=websocket,
|
|
||||||
telemetry_type="api",
|
|
||||||
api="chat",
|
|
||||||
**common.__dict__,
|
|
||||||
)
|
|
||||||
await send_complete_llm_response(llm_response)
|
|
||||||
continue
|
|
||||||
|
|
||||||
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
|
# Gather Context
|
||||||
websocket, meta_log, q, 7, 0.18, conversation_id, conversation_commands, location, send_status_update
|
## Extract Document References
|
||||||
)
|
compiled_references, inferred_queries, defiltered_query = [], [], None
|
||||||
|
async for result in extract_references_and_questions(
|
||||||
|
request,
|
||||||
|
meta_log,
|
||||||
|
q,
|
||||||
|
(n or 7),
|
||||||
|
(d or 0.18),
|
||||||
|
conversation_id,
|
||||||
|
conversation_commands,
|
||||||
|
location,
|
||||||
|
partial(send_event, ChatEvent.STATUS),
|
||||||
|
):
|
||||||
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
|
yield result[ChatEvent.STATUS]
|
||||||
|
else:
|
||||||
|
compiled_references.extend(result[0])
|
||||||
|
inferred_queries.extend(result[1])
|
||||||
|
defiltered_query = result[2]
|
||||||
|
|
||||||
if compiled_references:
|
if not is_none_or_empty(compiled_references):
|
||||||
headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references]))
|
headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references]))
|
||||||
await send_status_update(f"**📜 Found Relevant Notes**: {headings}")
|
async for result in send_event(ChatEvent.STATUS, f"**📜 Found Relevant Notes**: {headings}"):
|
||||||
|
yield result
|
||||||
|
|
||||||
online_results: Dict = dict()
|
online_results: Dict = dict()
|
||||||
|
|
||||||
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
|
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
|
||||||
await send_complete_llm_response(f"{no_entries_found.format()}")
|
async for result in send_llm_response(f"{no_entries_found.format()}"):
|
||||||
continue
|
yield result
|
||||||
|
return
|
||||||
|
|
||||||
if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
|
if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
|
||||||
conversation_commands.remove(ConversationCommand.Notes)
|
conversation_commands.remove(ConversationCommand.Notes)
|
||||||
|
|
||||||
|
## Gather Online References
|
||||||
if ConversationCommand.Online in conversation_commands:
|
if ConversationCommand.Online in conversation_commands:
|
||||||
try:
|
try:
|
||||||
online_results = await search_online(
|
async for result in search_online(
|
||||||
defiltered_query, meta_log, location, send_status_update, custom_filters
|
defiltered_query, meta_log, location, partial(send_event, ChatEvent.STATUS), custom_filters
|
||||||
)
|
):
|
||||||
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
|
yield result[ChatEvent.STATUS]
|
||||||
|
else:
|
||||||
|
online_results = result
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.warning(f"Error searching online: {e}. Attempting to respond without online results")
|
error_message = f"Error searching online: {e}. Attempting to respond without online results"
|
||||||
await send_complete_llm_response(
|
logger.warning(error_message)
|
||||||
f"Error searching online: {e}. Attempting to respond without online results"
|
async for result in send_llm_response(error_message):
|
||||||
)
|
yield result
|
||||||
continue
|
return
|
||||||
|
|
||||||
|
## Gather Webpage References
|
||||||
if ConversationCommand.Webpage in conversation_commands:
|
if ConversationCommand.Webpage in conversation_commands:
|
||||||
try:
|
try:
|
||||||
direct_web_pages = await read_webpages(defiltered_query, meta_log, location, send_status_update)
|
async for result in read_webpages(
|
||||||
|
defiltered_query, meta_log, location, partial(send_event, ChatEvent.STATUS)
|
||||||
|
):
|
||||||
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
|
yield result[ChatEvent.STATUS]
|
||||||
|
else:
|
||||||
|
direct_web_pages = result
|
||||||
webpages = []
|
webpages = []
|
||||||
for query in direct_web_pages:
|
for query in direct_web_pages:
|
||||||
if online_results.get(query):
|
if online_results.get(query):
|
||||||
|
@ -810,304 +832,52 @@ async def websocket_endpoint(
|
||||||
|
|
||||||
for webpage in direct_web_pages[query]["webpages"]:
|
for webpage in direct_web_pages[query]["webpages"]:
|
||||||
webpages.append(webpage["link"])
|
webpages.append(webpage["link"])
|
||||||
|
async for result in send_event(ChatEvent.STATUS, f"**📚 Read web pages**: {webpages}"):
|
||||||
await send_status_update(f"**📚 Read web pages**: {webpages}")
|
yield result
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Error directly reading webpages: {e}. Attempting to respond without online results", exc_info=True
|
f"Error directly reading webpages: {e}. Attempting to respond without online results",
|
||||||
|
exc_info=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
## Send Gathered References
|
||||||
|
async for result in send_event(
|
||||||
|
ChatEvent.REFERENCES,
|
||||||
|
{
|
||||||
|
"inferredQueries": inferred_queries,
|
||||||
|
"context": compiled_references,
|
||||||
|
"onlineContext": online_results,
|
||||||
|
},
|
||||||
|
):
|
||||||
|
yield result
|
||||||
|
|
||||||
|
# Generate Output
|
||||||
|
## Generate Image Output
|
||||||
if ConversationCommand.Image in conversation_commands:
|
if ConversationCommand.Image in conversation_commands:
|
||||||
update_telemetry_state(
|
async for result in text_to_image(
|
||||||
request=websocket,
|
|
||||||
telemetry_type="api",
|
|
||||||
api="chat",
|
|
||||||
metadata={"conversation_command": conversation_commands[0].value},
|
|
||||||
)
|
|
||||||
image, status_code, improved_image_prompt, intent_type = await text_to_image(
|
|
||||||
q,
|
q,
|
||||||
user,
|
user,
|
||||||
meta_log,
|
meta_log,
|
||||||
location_data=location,
|
location_data=location,
|
||||||
references=compiled_references,
|
references=compiled_references,
|
||||||
online_results=online_results,
|
online_results=online_results,
|
||||||
send_status_func=send_status_update,
|
send_status_func=partial(send_event, ChatEvent.STATUS),
|
||||||
)
|
):
|
||||||
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
|
yield result[ChatEvent.STATUS]
|
||||||
|
else:
|
||||||
|
image, status_code, improved_image_prompt, intent_type = result
|
||||||
|
|
||||||
if image is None or status_code != 200:
|
if image is None or status_code != 200:
|
||||||
content_obj = {
|
content_obj = {
|
||||||
"image": image,
|
"content-type": "application/json",
|
||||||
"intentType": intent_type,
|
"intentType": intent_type,
|
||||||
"detail": improved_image_prompt,
|
"detail": improved_image_prompt,
|
||||||
"content-type": "application/json",
|
"image": image,
|
||||||
}
|
}
|
||||||
await send_complete_llm_response(json.dumps(content_obj))
|
async for result in send_llm_response(json.dumps(content_obj)):
|
||||||
continue
|
yield result
|
||||||
|
return
|
||||||
await sync_to_async(save_to_conversation_log)(
|
|
||||||
q,
|
|
||||||
image,
|
|
||||||
user,
|
|
||||||
meta_log,
|
|
||||||
user_message_time,
|
|
||||||
intent_type=intent_type,
|
|
||||||
inferred_queries=[improved_image_prompt],
|
|
||||||
client_application=websocket.user.client_app,
|
|
||||||
conversation_id=conversation_id,
|
|
||||||
compiled_references=compiled_references,
|
|
||||||
online_results=online_results,
|
|
||||||
)
|
|
||||||
content_obj = {"image": image, "intentType": intent_type, "inferredQueries": [improved_image_prompt], "context": compiled_references, "content-type": "application/json", "online_results": online_results} # type: ignore
|
|
||||||
|
|
||||||
await send_complete_llm_response(json.dumps(content_obj))
|
|
||||||
continue
|
|
||||||
|
|
||||||
await send_status_update(f"**💭 Generating a well-informed response**")
|
|
||||||
llm_response, chat_metadata = await agenerate_chat_response(
|
|
||||||
defiltered_query,
|
|
||||||
meta_log,
|
|
||||||
conversation,
|
|
||||||
compiled_references,
|
|
||||||
online_results,
|
|
||||||
inferred_queries,
|
|
||||||
conversation_commands,
|
|
||||||
user,
|
|
||||||
websocket.user.client_app,
|
|
||||||
conversation_id,
|
|
||||||
location,
|
|
||||||
user_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None
|
|
||||||
|
|
||||||
update_telemetry_state(
|
|
||||||
request=websocket,
|
|
||||||
telemetry_type="api",
|
|
||||||
api="chat",
|
|
||||||
metadata=chat_metadata,
|
|
||||||
)
|
|
||||||
iterator = AsyncIteratorWrapper(llm_response)
|
|
||||||
|
|
||||||
await send_message("start_llm_response")
|
|
||||||
|
|
||||||
async for item in iterator:
|
|
||||||
if item is None:
|
|
||||||
break
|
|
||||||
if connection_alive:
|
|
||||||
try:
|
|
||||||
await send_message(f"{item}")
|
|
||||||
except ConnectionClosedOK:
|
|
||||||
connection_alive = False
|
|
||||||
logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")
|
|
||||||
|
|
||||||
await send_message("end_llm_response")
|
|
||||||
|
|
||||||
|
|
||||||
@api_chat.get("", response_class=Response)
|
|
||||||
@requires(["authenticated"])
|
|
||||||
async def chat(
|
|
||||||
request: Request,
|
|
||||||
common: CommonQueryParams,
|
|
||||||
q: str,
|
|
||||||
n: Optional[int] = 5,
|
|
||||||
d: Optional[float] = 0.22,
|
|
||||||
stream: Optional[bool] = False,
|
|
||||||
title: Optional[str] = None,
|
|
||||||
conversation_id: Optional[int] = None,
|
|
||||||
city: Optional[str] = None,
|
|
||||||
region: Optional[str] = None,
|
|
||||||
country: Optional[str] = None,
|
|
||||||
timezone: Optional[str] = None,
|
|
||||||
rate_limiter_per_minute=Depends(
|
|
||||||
ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute")
|
|
||||||
),
|
|
||||||
rate_limiter_per_day=Depends(
|
|
||||||
ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
|
|
||||||
),
|
|
||||||
) -> Response:
|
|
||||||
user: KhojUser = request.user.object
|
|
||||||
q = unquote(q)
|
|
||||||
if is_query_empty(q):
|
|
||||||
return Response(
|
|
||||||
content="It seems like your query is incomplete. Could you please provide more details or specify what you need help with?",
|
|
||||||
media_type="text/plain",
|
|
||||||
status_code=400,
|
|
||||||
)
|
|
||||||
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
|
||||||
logger.info(f"Chat request by {user.username}: {q}")
|
|
||||||
|
|
||||||
await is_ready_to_chat(user)
|
|
||||||
conversation_commands = [get_conversation_command(query=q, any_references=True)]
|
|
||||||
|
|
||||||
_custom_filters = []
|
|
||||||
if conversation_commands == [ConversationCommand.Help]:
|
|
||||||
help_str = "/" + ConversationCommand.Help
|
|
||||||
if q.strip() == help_str:
|
|
||||||
conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
|
|
||||||
if conversation_config == None:
|
|
||||||
conversation_config = await ConversationAdapters.aget_default_conversation_config()
|
|
||||||
model_type = conversation_config.model_type
|
|
||||||
formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
|
|
||||||
return StreamingResponse(iter([formatted_help]), media_type="text/event-stream", status_code=200)
|
|
||||||
# Adding specification to search online specifically on khoj.dev pages.
|
|
||||||
_custom_filters.append("site:khoj.dev")
|
|
||||||
conversation_commands.append(ConversationCommand.Online)
|
|
||||||
|
|
||||||
conversation = await ConversationAdapters.aget_conversation_by_user(
|
|
||||||
user, request.user.client_app, conversation_id, title
|
|
||||||
)
|
|
||||||
conversation_id = conversation.id if conversation else None
|
|
||||||
|
|
||||||
if not conversation:
|
|
||||||
return Response(
|
|
||||||
content=f"No conversation found with requested id, title", media_type="text/plain", status_code=400
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
meta_log = conversation.conversation_log
|
|
||||||
|
|
||||||
if ConversationCommand.Summarize in conversation_commands:
|
|
||||||
file_filters = conversation.file_filters
|
|
||||||
llm_response = ""
|
|
||||||
if len(file_filters) == 0:
|
|
||||||
llm_response = "No files selected for summarization. Please add files using the section on the left."
|
|
||||||
elif len(file_filters) > 1:
|
|
||||||
llm_response = "Only one file can be selected for summarization."
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
|
|
||||||
if len(file_object) == 0:
|
|
||||||
llm_response = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again."
|
|
||||||
return StreamingResponse(content=llm_response, media_type="text/event-stream", status_code=200)
|
|
||||||
contextual_data = " ".join([file.raw_text for file in file_object])
|
|
||||||
summarizeStr = "/" + ConversationCommand.Summarize
|
|
||||||
if q.strip() == summarizeStr:
|
|
||||||
q = "Create a general summary of the file"
|
|
||||||
response = await extract_relevant_summary(q, contextual_data)
|
|
||||||
llm_response = str(response)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error summarizing file for {user.email}: {e}")
|
|
||||||
llm_response = "Error summarizing file."
|
|
||||||
await sync_to_async(save_to_conversation_log)(
|
|
||||||
q,
|
|
||||||
llm_response,
|
|
||||||
user,
|
|
||||||
conversation.conversation_log,
|
|
||||||
user_message_time,
|
|
||||||
intent_type="summarize",
|
|
||||||
client_application=request.user.client_app,
|
|
||||||
conversation_id=conversation_id,
|
|
||||||
)
|
|
||||||
update_telemetry_state(
|
|
||||||
request=request,
|
|
||||||
telemetry_type="api",
|
|
||||||
api="chat",
|
|
||||||
metadata={"conversation_command": conversation_commands[0].value},
|
|
||||||
**common.__dict__,
|
|
||||||
)
|
|
||||||
return StreamingResponse(content=llm_response, media_type="text/event-stream", status_code=200)
|
|
||||||
|
|
||||||
is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask]
|
|
||||||
|
|
||||||
if conversation_commands == [ConversationCommand.Default] or is_automated_task:
|
|
||||||
conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task)
|
|
||||||
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task)
|
|
||||||
if mode not in conversation_commands:
|
|
||||||
conversation_commands.append(mode)
|
|
||||||
|
|
||||||
for cmd in conversation_commands:
|
|
||||||
await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
|
|
||||||
q = q.replace(f"/{cmd.value}", "").strip()
|
|
||||||
|
|
||||||
location = None
|
|
||||||
|
|
||||||
if city or region or country:
|
|
||||||
location = LocationData(city=city, region=region, country=country)
|
|
||||||
|
|
||||||
user_name = await aget_user_name(user)
|
|
||||||
|
|
||||||
if ConversationCommand.Automation in conversation_commands:
|
|
||||||
try:
|
|
||||||
automation, crontime, query_to_run, subject = await create_automation(
|
|
||||||
q, timezone, user, request.url, meta_log
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error creating automation {q} for {user.email}: {e}", exc_info=True)
|
|
||||||
return Response(
|
|
||||||
content=f"Unable to create automation. Ensure the automation doesn't already exist.",
|
|
||||||
media_type="text/plain",
|
|
||||||
status_code=500,
|
|
||||||
)
|
|
||||||
|
|
||||||
llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject)
|
|
||||||
await sync_to_async(save_to_conversation_log)(
|
|
||||||
q,
|
|
||||||
llm_response,
|
|
||||||
user,
|
|
||||||
meta_log,
|
|
||||||
user_message_time,
|
|
||||||
intent_type="automation",
|
|
||||||
client_application=request.user.client_app,
|
|
||||||
conversation_id=conversation_id,
|
|
||||||
inferred_queries=[query_to_run],
|
|
||||||
automation_id=automation.id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if stream:
|
|
||||||
return StreamingResponse(llm_response, media_type="text/event-stream", status_code=200)
|
|
||||||
else:
|
|
||||||
return Response(content=llm_response, media_type="text/plain", status_code=200)
|
|
||||||
|
|
||||||
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
|
|
||||||
request, meta_log, q, (n or 5), (d or math.inf), conversation_id, conversation_commands, location
|
|
||||||
)
|
|
||||||
online_results: Dict[str, Dict] = {}
|
|
||||||
|
|
||||||
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
|
|
||||||
no_entries_found_format = no_entries_found.format()
|
|
||||||
if stream:
|
|
||||||
return StreamingResponse(iter([no_entries_found_format]), media_type="text/event-stream", status_code=200)
|
|
||||||
else:
|
|
||||||
response_obj = {"response": no_entries_found_format}
|
|
||||||
return Response(content=json.dumps(response_obj), media_type="text/plain", status_code=200)
|
|
||||||
|
|
||||||
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references):
|
|
||||||
no_notes_found_format = no_notes_found.format()
|
|
||||||
if stream:
|
|
||||||
return StreamingResponse(iter([no_notes_found_format]), media_type="text/event-stream", status_code=200)
|
|
||||||
else:
|
|
||||||
response_obj = {"response": no_notes_found_format}
|
|
||||||
return Response(content=json.dumps(response_obj), media_type="text/plain", status_code=200)
|
|
||||||
|
|
||||||
if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
|
|
||||||
conversation_commands.remove(ConversationCommand.Notes)
|
|
||||||
|
|
||||||
if ConversationCommand.Online in conversation_commands:
|
|
||||||
try:
|
|
||||||
online_results = await search_online(defiltered_query, meta_log, location, custom_filters=_custom_filters)
|
|
||||||
except ValueError as e:
|
|
||||||
logger.warning(f"Error searching online: {e}. Attempting to respond without online results")
|
|
||||||
|
|
||||||
if ConversationCommand.Webpage in conversation_commands:
|
|
||||||
try:
|
|
||||||
online_results = await read_webpages(defiltered_query, meta_log, location)
|
|
||||||
except ValueError as e:
|
|
||||||
logger.warning(
|
|
||||||
f"Error directly reading webpages: {e}. Attempting to respond without online results", exc_info=True
|
|
||||||
)
|
|
||||||
|
|
||||||
if ConversationCommand.Image in conversation_commands:
|
|
||||||
update_telemetry_state(
|
|
||||||
request=request,
|
|
||||||
telemetry_type="api",
|
|
||||||
api="chat",
|
|
||||||
metadata={"conversation_command": conversation_commands[0].value},
|
|
||||||
**common.__dict__,
|
|
||||||
)
|
|
||||||
image, status_code, improved_image_prompt, intent_type = await text_to_image(
|
|
||||||
q, user, meta_log, location_data=location, references=compiled_references, online_results=online_results
|
|
||||||
)
|
|
||||||
if image is None:
|
|
||||||
content_obj = {"image": image, "intentType": intent_type, "detail": improved_image_prompt}
|
|
||||||
return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
|
|
||||||
|
|
||||||
await sync_to_async(save_to_conversation_log)(
|
await sync_to_async(save_to_conversation_log)(
|
||||||
q,
|
q,
|
||||||
|
@ -1118,14 +888,22 @@ async def chat(
|
||||||
intent_type=intent_type,
|
intent_type=intent_type,
|
||||||
inferred_queries=[improved_image_prompt],
|
inferred_queries=[improved_image_prompt],
|
||||||
client_application=request.user.client_app,
|
client_application=request.user.client_app,
|
||||||
conversation_id=conversation.id,
|
conversation_id=conversation_id,
|
||||||
compiled_references=compiled_references,
|
compiled_references=compiled_references,
|
||||||
online_results=online_results,
|
online_results=online_results,
|
||||||
)
|
)
|
||||||
content_obj = {"image": image, "intentType": intent_type, "inferredQueries": [improved_image_prompt], "context": compiled_references, "online_results": online_results} # type: ignore
|
content_obj = {
|
||||||
return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
|
"intentType": intent_type,
|
||||||
|
"inferredQueries": [improved_image_prompt],
|
||||||
|
"image": image,
|
||||||
|
}
|
||||||
|
async for result in send_llm_response(json.dumps(content_obj)):
|
||||||
|
yield result
|
||||||
|
return
|
||||||
|
|
||||||
# Get the (streamed) chat response from the LLM of choice.
|
## Generate Text Output
|
||||||
|
async for result in send_event(ChatEvent.STATUS, f"**💭 Generating a well-informed response**"):
|
||||||
|
yield result
|
||||||
llm_response, chat_metadata = await agenerate_chat_response(
|
llm_response, chat_metadata = await agenerate_chat_response(
|
||||||
defiltered_query,
|
defiltered_query,
|
||||||
meta_log,
|
meta_log,
|
||||||
|
@ -1136,45 +914,48 @@ async def chat(
|
||||||
conversation_commands,
|
conversation_commands,
|
||||||
user,
|
user,
|
||||||
request.user.client_app,
|
request.user.client_app,
|
||||||
conversation.id,
|
conversation_id,
|
||||||
location,
|
location,
|
||||||
user_name,
|
user_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
cmd_set = set([cmd.value for cmd in conversation_commands])
|
# Send Response
|
||||||
chat_metadata["conversation_command"] = cmd_set
|
async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
|
||||||
chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None
|
yield result
|
||||||
|
|
||||||
update_telemetry_state(
|
|
||||||
request=request,
|
|
||||||
telemetry_type="api",
|
|
||||||
api="chat",
|
|
||||||
metadata=chat_metadata,
|
|
||||||
**common.__dict__,
|
|
||||||
)
|
|
||||||
|
|
||||||
if llm_response is None:
|
|
||||||
return Response(content=llm_response, media_type="text/plain", status_code=500)
|
|
||||||
|
|
||||||
if stream:
|
|
||||||
return StreamingResponse(llm_response, media_type="text/event-stream", status_code=200)
|
|
||||||
|
|
||||||
|
continue_stream = True
|
||||||
iterator = AsyncIteratorWrapper(llm_response)
|
iterator = AsyncIteratorWrapper(llm_response)
|
||||||
|
|
||||||
# Get the full response from the generator if the stream is not requested.
|
|
||||||
aggregated_gpt_response = ""
|
|
||||||
async for item in iterator:
|
async for item in iterator:
|
||||||
if item is None:
|
if item is None:
|
||||||
break
|
async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
|
||||||
aggregated_gpt_response += item
|
yield result
|
||||||
|
logger.debug("Finished streaming response")
|
||||||
|
return
|
||||||
|
if not connection_alive or not continue_stream:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
async for result in send_event(ChatEvent.MESSAGE, f"{item}"):
|
||||||
|
yield result
|
||||||
|
except Exception as e:
|
||||||
|
continue_stream = False
|
||||||
|
logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}")
|
||||||
|
|
||||||
actual_response = aggregated_gpt_response.split("### compiled references:")[0]
|
## Stream Text Response
|
||||||
|
if stream:
|
||||||
response_obj = {
|
return StreamingResponse(event_generator(q), media_type="text/plain")
|
||||||
"response": actual_response,
|
## Non-Streaming Text Response
|
||||||
"inferredQueries": inferred_queries,
|
else:
|
||||||
"context": compiled_references,
|
# Get the full response from the generator if the stream is not requested.
|
||||||
"online_results": online_results,
|
response_obj = {}
|
||||||
}
|
actual_response = ""
|
||||||
|
iterator = event_generator(q)
|
||||||
|
async for item in iterator:
|
||||||
|
try:
|
||||||
|
item_json = json.loads(item)
|
||||||
|
if "type" in item_json and item_json["type"] == ChatEvent.REFERENCES.value:
|
||||||
|
response_obj = item_json["data"]
|
||||||
|
except:
|
||||||
|
actual_response += item
|
||||||
|
response_obj["response"] = actual_response
|
||||||
|
|
||||||
return Response(content=json.dumps(response_obj), media_type="application/json", status_code=200)
|
return Response(content=json.dumps(response_obj), media_type="application/json", status_code=200)
|
||||||
|
|
|
@ -8,6 +8,7 @@ import math
|
||||||
import re
|
import re
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from enum import Enum
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from random import random
|
from random import random
|
||||||
from typing import (
|
from typing import (
|
||||||
|
@ -753,7 +754,7 @@ async def text_to_image(
|
||||||
references: List[Dict[str, Any]],
|
references: List[Dict[str, Any]],
|
||||||
online_results: Dict[str, Any],
|
online_results: Dict[str, Any],
|
||||||
send_status_func: Optional[Callable] = None,
|
send_status_func: Optional[Callable] = None,
|
||||||
) -> Tuple[Optional[str], int, Optional[str], str]:
|
):
|
||||||
status_code = 200
|
status_code = 200
|
||||||
image = None
|
image = None
|
||||||
response = None
|
response = None
|
||||||
|
@ -765,7 +766,8 @@ async def text_to_image(
|
||||||
# If the user has not configured a text to image model, return an unsupported on server error
|
# If the user has not configured a text to image model, return an unsupported on server error
|
||||||
status_code = 501
|
status_code = 501
|
||||||
message = "Failed to generate image. Setup image generation on the server."
|
message = "Failed to generate image. Setup image generation on the server."
|
||||||
return image_url or image, status_code, message, intent_type.value
|
yield image_url or image, status_code, message, intent_type.value
|
||||||
|
return
|
||||||
|
|
||||||
text2image_model = text_to_image_config.model_name
|
text2image_model = text_to_image_config.model_name
|
||||||
chat_history = ""
|
chat_history = ""
|
||||||
|
@ -777,9 +779,9 @@ async def text_to_image(
|
||||||
chat_history += f"Q: Prompt: {chat['intent']['query']}\n"
|
chat_history += f"Q: Prompt: {chat['intent']['query']}\n"
|
||||||
chat_history += f"A: Improved Prompt: {chat['intent']['inferred-queries'][0]}\n"
|
chat_history += f"A: Improved Prompt: {chat['intent']['inferred-queries'][0]}\n"
|
||||||
|
|
||||||
with timer("Improve the original user query", logger):
|
|
||||||
if send_status_func:
|
if send_status_func:
|
||||||
await send_status_func("**✍🏽 Enhancing the Painting Prompt**")
|
async for event in send_status_func("**✍🏽 Enhancing the Painting Prompt**"):
|
||||||
|
yield {ChatEvent.STATUS: event}
|
||||||
improved_image_prompt = await generate_better_image_prompt(
|
improved_image_prompt = await generate_better_image_prompt(
|
||||||
message,
|
message,
|
||||||
chat_history,
|
chat_history,
|
||||||
|
@ -790,7 +792,8 @@ async def text_to_image(
|
||||||
)
|
)
|
||||||
|
|
||||||
if send_status_func:
|
if send_status_func:
|
||||||
await send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}")
|
async for event in send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}"):
|
||||||
|
yield {ChatEvent.STATUS: event}
|
||||||
|
|
||||||
if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
|
if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
|
||||||
with timer("Generate image with OpenAI", logger):
|
with timer("Generate image with OpenAI", logger):
|
||||||
|
@ -815,12 +818,14 @@ async def text_to_image(
|
||||||
logger.error(f"Image Generation blocked by OpenAI: {e}")
|
logger.error(f"Image Generation blocked by OpenAI: {e}")
|
||||||
status_code = e.status_code # type: ignore
|
status_code = e.status_code # type: ignore
|
||||||
message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore
|
message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore
|
||||||
return image_url or image, status_code, message, intent_type.value
|
yield image_url or image, status_code, message, intent_type.value
|
||||||
|
return
|
||||||
else:
|
else:
|
||||||
logger.error(f"Image Generation failed with {e}", exc_info=True)
|
logger.error(f"Image Generation failed with {e}", exc_info=True)
|
||||||
message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore
|
message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore
|
||||||
status_code = e.status_code # type: ignore
|
status_code = e.status_code # type: ignore
|
||||||
return image_url or image, status_code, message, intent_type.value
|
yield image_url or image, status_code, message, intent_type.value
|
||||||
|
return
|
||||||
|
|
||||||
elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.STABILITYAI:
|
elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.STABILITYAI:
|
||||||
with timer("Generate image with Stability AI", logger):
|
with timer("Generate image with Stability AI", logger):
|
||||||
|
@ -842,7 +847,8 @@ async def text_to_image(
|
||||||
logger.error(f"Image Generation failed with {e}", exc_info=True)
|
logger.error(f"Image Generation failed with {e}", exc_info=True)
|
||||||
message = f"Image generation failed with Stability AI error: {e}"
|
message = f"Image generation failed with Stability AI error: {e}"
|
||||||
status_code = e.status_code # type: ignore
|
status_code = e.status_code # type: ignore
|
||||||
return image_url or image, status_code, message, intent_type.value
|
yield image_url or image, status_code, message, intent_type.value
|
||||||
|
return
|
||||||
|
|
||||||
with timer("Convert image to webp", logger):
|
with timer("Convert image to webp", logger):
|
||||||
# Convert png to webp for faster loading
|
# Convert png to webp for faster loading
|
||||||
|
@ -862,7 +868,7 @@ async def text_to_image(
|
||||||
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
|
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
|
||||||
image = base64.b64encode(webp_image_bytes).decode("utf-8")
|
image = base64.b64encode(webp_image_bytes).decode("utf-8")
|
||||||
|
|
||||||
return image_url or image, status_code, improved_image_prompt, intent_type.value
|
yield image_url or image, status_code, improved_image_prompt, intent_type.value
|
||||||
|
|
||||||
|
|
||||||
class ApiUserRateLimiter:
|
class ApiUserRateLimiter:
|
||||||
|
@ -1184,3 +1190,11 @@ def construct_automation_created_message(automation: Job, crontime: str, query_t
|
||||||
|
|
||||||
Manage your automations [here](/automations).
|
Manage your automations [here](/automations).
|
||||||
""".strip()
|
""".strip()
|
||||||
|
|
||||||
|
|
||||||
|
class ChatEvent(Enum):
|
||||||
|
START_LLM_RESPONSE = "start_llm_response"
|
||||||
|
END_LLM_RESPONSE = "end_llm_response"
|
||||||
|
MESSAGE = "message"
|
||||||
|
REFERENCES = "references"
|
||||||
|
STATUS = "status"
|
||||||
|
|
|
@ -22,7 +22,7 @@ magika = Magika()
|
||||||
|
|
||||||
|
|
||||||
def collect_files(search_type: Optional[SearchType] = SearchType.All, user=None) -> dict:
|
def collect_files(search_type: Optional[SearchType] = SearchType.All, user=None) -> dict:
|
||||||
files = {}
|
files: dict[str, dict] = {"docx": {}, "image": {}}
|
||||||
|
|
||||||
if search_type == SearchType.All or search_type == SearchType.Org:
|
if search_type == SearchType.All or search_type == SearchType.Org:
|
||||||
org_config = LocalOrgConfig.objects.filter(user=user).first()
|
org_config = LocalOrgConfig.objects.filter(user=user).first()
|
||||||
|
|
|
@ -455,13 +455,13 @@ def test_user_no_data_returns_empty(client, sample_org_data, api_user3: KhojApiU
|
||||||
|
|
||||||
@pytest.mark.skipif(os.getenv("OPENAI_API_KEY") is None, reason="requires OPENAI_API_KEY")
|
@pytest.mark.skipif(os.getenv("OPENAI_API_KEY") is None, reason="requires OPENAI_API_KEY")
|
||||||
@pytest.mark.django_db(transaction=True)
|
@pytest.mark.django_db(transaction=True)
|
||||||
def test_chat_with_unauthenticated_user(chat_client_with_auth, api_user2: KhojApiUser):
|
async def test_chat_with_unauthenticated_user(chat_client_with_auth, api_user2: KhojApiUser):
|
||||||
# Arrange
|
# Arrange
|
||||||
headers = {"Authorization": f"Bearer {api_user2.token}"}
|
headers = {"Authorization": f"Bearer {api_user2.token}"}
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
auth_response = chat_client_with_auth.get(f'/api/chat?q="Hello!"&stream=true', headers=headers)
|
auth_response = chat_client_with_auth.get(f'/api/chat?q="Hello!"', headers=headers)
|
||||||
no_auth_response = chat_client_with_auth.get(f'/api/chat?q="Hello!"&stream=true')
|
no_auth_response = chat_client_with_auth.get(f'/api/chat?q="Hello!"')
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert auth_response.status_code == 200
|
assert auth_response.status_code == 200
|
||||||
|
|
|
@ -67,10 +67,8 @@ def test_chat_with_online_content(client_offline_chat):
|
||||||
# Act
|
# Act
|
||||||
q = "/online give me the link to paul graham's essay how to do great work"
|
q = "/online give me the link to paul graham's essay how to do great work"
|
||||||
encoded_q = quote(q, safe="")
|
encoded_q = quote(q, safe="")
|
||||||
response = client_offline_chat.get(f"/api/chat?q={encoded_q}&stream=true")
|
response = client_offline_chat.get(f"/api/chat?q={encoded_q}")
|
||||||
response_message = response.content.decode("utf-8")
|
response_message = response.json()["response"]
|
||||||
|
|
||||||
response_message = response_message.split("### compiled references")[0]
|
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
expected_responses = [
|
expected_responses = [
|
||||||
|
@ -91,10 +89,8 @@ def test_chat_with_online_webpage_content(client_offline_chat):
|
||||||
# Act
|
# Act
|
||||||
q = "/online how many firefighters were involved in the great chicago fire and which year did it take place?"
|
q = "/online how many firefighters were involved in the great chicago fire and which year did it take place?"
|
||||||
encoded_q = quote(q, safe="")
|
encoded_q = quote(q, safe="")
|
||||||
response = client_offline_chat.get(f"/api/chat?q={encoded_q}&stream=true")
|
response = client_offline_chat.get(f"/api/chat?q={encoded_q}")
|
||||||
response_message = response.content.decode("utf-8")
|
response_message = response.json()["response"]
|
||||||
|
|
||||||
response_message = response_message.split("### compiled references")[0]
|
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
expected_responses = ["185", "1871", "horse"]
|
expected_responses = ["185", "1871", "horse"]
|
||||||
|
|
|
@ -49,8 +49,8 @@ def create_conversation(message_list, user, agent=None):
|
||||||
@pytest.mark.django_db(transaction=True)
|
@pytest.mark.django_db(transaction=True)
|
||||||
def test_chat_with_no_chat_history_or_retrieved_content(chat_client):
|
def test_chat_with_no_chat_history_or_retrieved_content(chat_client):
|
||||||
# Act
|
# Act
|
||||||
response = chat_client.get(f'/api/chat?q="Hello, my name is Testatron. Who are you?"&stream=true')
|
response = chat_client.get(f'/api/chat?q="Hello, my name is Testatron. Who are you?"')
|
||||||
response_message = response.content.decode("utf-8")
|
response_message = response.json()["response"]
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
expected_responses = ["Khoj", "khoj"]
|
expected_responses = ["Khoj", "khoj"]
|
||||||
|
@ -67,10 +67,8 @@ def test_chat_with_online_content(chat_client):
|
||||||
# Act
|
# Act
|
||||||
q = "/online give me the link to paul graham's essay how to do great work"
|
q = "/online give me the link to paul graham's essay how to do great work"
|
||||||
encoded_q = quote(q, safe="")
|
encoded_q = quote(q, safe="")
|
||||||
response = chat_client.get(f"/api/chat?q={encoded_q}&stream=true")
|
response = chat_client.get(f"/api/chat?q={encoded_q}")
|
||||||
response_message = response.content.decode("utf-8")
|
response_message = response.json()["response"]
|
||||||
|
|
||||||
response_message = response_message.split("### compiled references")[0]
|
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
expected_responses = [
|
expected_responses = [
|
||||||
|
@ -91,10 +89,8 @@ def test_chat_with_online_webpage_content(chat_client):
|
||||||
# Act
|
# Act
|
||||||
q = "/online how many firefighters were involved in the great chicago fire and which year did it take place?"
|
q = "/online how many firefighters were involved in the great chicago fire and which year did it take place?"
|
||||||
encoded_q = quote(q, safe="")
|
encoded_q = quote(q, safe="")
|
||||||
response = chat_client.get(f"/api/chat?q={encoded_q}&stream=true")
|
response = chat_client.get(f"/api/chat?q={encoded_q}")
|
||||||
response_message = response.content.decode("utf-8")
|
response_message = response.json()["response"]
|
||||||
|
|
||||||
response_message = response_message.split("### compiled references")[0]
|
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
expected_responses = ["185", "1871", "horse"]
|
expected_responses = ["185", "1871", "horse"]
|
||||||
|
@ -144,7 +140,7 @@ def test_answer_from_currently_retrieved_content(chat_client, default_user2: Kho
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
response = chat_client.get(f'/api/chat?q="Where was Xi Li born?"')
|
response = chat_client.get(f'/api/chat?q="Where was Xi Li born?"')
|
||||||
response_message = response.content.decode("utf-8")
|
response_message = response.json()["response"]
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
@ -168,7 +164,7 @@ def test_answer_from_chat_history_and_previously_retrieved_content(chat_client_n
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
response = chat_client_no_background.get(f'/api/chat?q="Where was I born?"')
|
response = chat_client_no_background.get(f'/api/chat?q="Where was I born?"')
|
||||||
response_message = response.content.decode("utf-8")
|
response_message = response.json()["response"]
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
@ -191,7 +187,7 @@ def test_answer_from_chat_history_and_currently_retrieved_content(chat_client, d
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
response = chat_client.get(f'/api/chat?q="Where was I born?"')
|
response = chat_client.get(f'/api/chat?q="Where was I born?"')
|
||||||
response_message = response.content.decode("utf-8")
|
response_message = response.json()["response"]
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
@ -215,8 +211,8 @@ def test_no_answer_in_chat_history_or_retrieved_content(chat_client, default_use
|
||||||
create_conversation(message_list, default_user2)
|
create_conversation(message_list, default_user2)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
response = chat_client.get(f'/api/chat?q="Where was I born?"&stream=true')
|
response = chat_client.get(f'/api/chat?q="Where was I born?"')
|
||||||
response_message = response.content.decode("utf-8")
|
response_message = response.json()["response"]
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
expected_responses = [
|
expected_responses = [
|
||||||
|
@ -226,6 +222,7 @@ def test_no_answer_in_chat_history_or_retrieved_content(chat_client, default_use
|
||||||
"do not have",
|
"do not have",
|
||||||
"don't have",
|
"don't have",
|
||||||
"where were you born?",
|
"where were you born?",
|
||||||
|
"where you were born?",
|
||||||
]
|
]
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
@ -280,8 +277,8 @@ def test_answer_not_known_using_notes_command(chat_client_no_background, default
|
||||||
create_conversation(message_list, default_user2)
|
create_conversation(message_list, default_user2)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
response = chat_client_no_background.get(f"/api/chat?q={query}&stream=true")
|
response = chat_client_no_background.get(f"/api/chat?q={query}")
|
||||||
response_message = response.content.decode("utf-8")
|
response_message = response.json()["response"]
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
@ -527,8 +524,8 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c
|
||||||
create_conversation(message_list, default_user2)
|
create_conversation(message_list, default_user2)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
response = chat_client.get(f'/api/chat?q="Write a haiku about unit testing. Do not say anything else."&stream=true')
|
response = chat_client.get(f'/api/chat?q="Write a haiku about unit testing. Do not say anything else.')
|
||||||
response_message = response.content.decode("utf-8").split("### compiled references")[0]
|
response_message = response.json()["response"]
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
expected_responses = ["test", "Test"]
|
expected_responses = ["test", "Test"]
|
||||||
|
@ -544,9 +541,8 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c
|
||||||
@pytest.mark.chatquality
|
@pytest.mark.chatquality
|
||||||
def test_ask_for_clarification_if_not_enough_context_in_question(chat_client_no_background):
|
def test_ask_for_clarification_if_not_enough_context_in_question(chat_client_no_background):
|
||||||
# Act
|
# Act
|
||||||
|
response = chat_client_no_background.get(f'/api/chat?q="What is the name of Namitas older son?"')
|
||||||
response = chat_client_no_background.get(f'/api/chat?q="What is the name of Namitas older son?"&stream=true')
|
response_message = response.json()["response"].lower()
|
||||||
response_message = response.content.decode("utf-8").split("### compiled references")[0].lower()
|
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
expected_responses = [
|
expected_responses = [
|
||||||
|
@ -658,8 +654,8 @@ def test_answer_in_chat_history_by_conversation_id_with_agent(
|
||||||
def test_answer_requires_multiple_independent_searches(chat_client):
|
def test_answer_requires_multiple_independent_searches(chat_client):
|
||||||
"Chat director should be able to answer by doing multiple independent searches for required information"
|
"Chat director should be able to answer by doing multiple independent searches for required information"
|
||||||
# Act
|
# Act
|
||||||
response = chat_client.get(f'/api/chat?q="Is Xi older than Namita? Just the older persons full name"&stream=true')
|
response = chat_client.get(f'/api/chat?q="Is Xi older than Namita? Just the older persons full name"')
|
||||||
response_message = response.content.decode("utf-8").split("### compiled references")[0].lower()
|
response_message = response.json()["response"].lower()
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
expected_responses = ["he is older than namita", "xi is older than namita", "xi li is older than namita"]
|
expected_responses = ["he is older than namita", "xi is older than namita", "xi li is older than namita"]
|
||||||
|
@ -683,8 +679,8 @@ def test_answer_using_file_filter(chat_client):
|
||||||
'Is Xi older than Namita? Just say the older persons full name. file:"Namita.markdown" file:"Xi Li.markdown"'
|
'Is Xi older than Namita? Just say the older persons full name. file:"Namita.markdown" file:"Xi Li.markdown"'
|
||||||
)
|
)
|
||||||
|
|
||||||
response = chat_client.get(f"/api/chat?q={query}&stream=true")
|
response = chat_client.get(f"/api/chat?q={query}")
|
||||||
response_message = response.content.decode("utf-8").split("### compiled references")[0].lower()
|
response_message = response.json()["response"].lower()
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
expected_responses = ["he is older than namita", "xi is older than namita", "xi li is older than namita"]
|
expected_responses = ["he is older than namita", "xi is older than namita", "xi li is older than namita"]
|
||||||
|
|
Loading…
Reference in a new issue