mirror of
https://github.com/khoj-ai/khoj.git
synced 2025-02-17 08:04:21 +00:00
Merge pull request #679 from khoj-ai/features/chat-socket-streaming
Add a websocket for streaming from the chat UI
This commit is contained in:
commit
9c42c8be6b
4 changed files with 603 additions and 121 deletions
|
@ -75,6 +75,7 @@ dependencies = [
|
|||
"django-phonenumber-field == 7.3.0",
|
||||
"phonenumbers == 8.13.27",
|
||||
"markdownify ~= 0.11.6",
|
||||
"websockets == 12.0",
|
||||
]
|
||||
dynamic = ["version"]
|
||||
|
||||
|
|
|
@ -47,11 +47,22 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||
}, 1000);
|
||||
});
|
||||
}
|
||||
var websocket = null;
|
||||
var timeout = null;
|
||||
var timeoutDuration = 600000; // 10 minutes
|
||||
|
||||
let region = null;
|
||||
let city = null;
|
||||
let countryName = null;
|
||||
|
||||
let websocketState = {
|
||||
newResponseText: null,
|
||||
newResponseElement: null,
|
||||
loadingEllipsis: null,
|
||||
references: {},
|
||||
rawResponse: "",
|
||||
}
|
||||
|
||||
fetch("https://ipapi.co/json")
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
|
@ -415,6 +426,12 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||
|
||||
async function chat() {
|
||||
// Extract required fields for search from form
|
||||
|
||||
if (websocket) {
|
||||
sendMessageViaWebSocket();
|
||||
return;
|
||||
}
|
||||
|
||||
let query = document.getElementById("chat-input").value.trim();
|
||||
let resultsCount = localStorage.getItem("khojResultsCount") || 5;
|
||||
console.log(`Query: ${query}`);
|
||||
|
@ -440,9 +457,6 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||
refreshChatSessionsPanel();
|
||||
}
|
||||
|
||||
// Generate backend API URL to execute query
|
||||
let url = `/api/chat?q=${encodeURIComponent(query)}&n=${resultsCount}&client=web&stream=true&conversation_id=${conversationID}®ion=${region}&city=${city}&country=${countryName}`;
|
||||
|
||||
let new_response = document.createElement("div");
|
||||
new_response.classList.add("chat-message", "khoj");
|
||||
new_response.attributes["data-meta"] = "🏮 Khoj at " + formatDate(new Date());
|
||||
|
@ -452,6 +466,79 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||
newResponseText.classList.add("chat-message-text", "khoj");
|
||||
new_response.appendChild(newResponseText);
|
||||
|
||||
// Temporary status message to indicate that Khoj is thinking
|
||||
let loadingEllipsis = createLoadingEllipse();
|
||||
|
||||
newResponseText.appendChild(loadingEllipsis);
|
||||
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||
|
||||
let chatTooltip = document.getElementById("chat-tooltip");
|
||||
chatTooltip.style.display = "none";
|
||||
|
||||
let chatInput = document.getElementById("chat-input");
|
||||
chatInput.classList.remove("option-enabled");
|
||||
|
||||
// Generate backend API URL to execute query
|
||||
let url = `/api/chat?q=${encodeURIComponent(query)}&n=${resultsCount}&client=web&stream=true&conversation_id=${conversationID}®ion=${region}&city=${city}&country=${countryName}`;
|
||||
|
||||
// Call specified Khoj API
|
||||
let response = await fetch(url);
|
||||
let rawResponse = "";
|
||||
let references = null;
|
||||
const contentType = response.headers.get("content-type");
|
||||
|
||||
if (contentType === "application/json") {
|
||||
// Handle JSON response
|
||||
try {
|
||||
const responseAsJson = await response.json();
|
||||
if (responseAsJson.image || responseAsJson.detail) {
|
||||
({rawResponse, references } = handleImageResponse(responseAsJson, rawResponse));
|
||||
} else {
|
||||
rawResponse = responseAsJson.response;
|
||||
}
|
||||
} catch (error) {
|
||||
// If the chunk is not a JSON object, just display it as is
|
||||
rawResponse += chunk;
|
||||
} finally {
|
||||
addMessageToChatBody(rawResponse, newResponseText, references);
|
||||
}
|
||||
} else {
|
||||
// Handle streamed response of type text/event-stream or text/plain
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
let references = {};
|
||||
|
||||
readStream();
|
||||
|
||||
function readStream() {
|
||||
reader.read().then(({ done, value }) => {
|
||||
if (done) {
|
||||
// Append any references after all the data has been streamed
|
||||
finalizeChatBodyResponse(references, newResponseText);
|
||||
return;
|
||||
}
|
||||
|
||||
// Decode message chunk from stream
|
||||
const chunk = decoder.decode(value, { stream: true });
|
||||
|
||||
if (chunk.includes("### compiled references:")) {
|
||||
({ rawResponse, references } = handleCompiledReferences(newResponseText, chunk, references, rawResponse));
|
||||
readStream();
|
||||
} else {
|
||||
// If the chunk is not a JSON object, just display it as is
|
||||
rawResponse += chunk;
|
||||
handleStreamResponse(newResponseText, rawResponse, loadingEllipsis);
|
||||
readStream();
|
||||
}
|
||||
});
|
||||
|
||||
// Scroll to bottom of chat window as chat response is streamed
|
||||
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
function createLoadingEllipse() {
|
||||
// Temporary status message to indicate that Khoj is thinking
|
||||
let loadingEllipsis = document.createElement("div");
|
||||
loadingEllipsis.classList.add("lds-ellipsis");
|
||||
|
@ -473,115 +560,80 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||
loadingEllipsis.appendChild(thirdEllipsis);
|
||||
loadingEllipsis.appendChild(fourthEllipsis);
|
||||
|
||||
newResponseText.appendChild(loadingEllipsis);
|
||||
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||
return loadingEllipsis;
|
||||
}
|
||||
|
||||
let chatTooltip = document.getElementById("chat-tooltip");
|
||||
chatTooltip.style.display = "none";
|
||||
|
||||
let chatInput = document.getElementById("chat-input");
|
||||
chatInput.classList.remove("option-enabled");
|
||||
|
||||
// Call specified Khoj API
|
||||
let response = await fetch(url);
|
||||
let rawResponse = "";
|
||||
let references = null;
|
||||
const contentType = response.headers.get("content-type");
|
||||
|
||||
if (contentType === "application/json") {
|
||||
// Handle JSON response
|
||||
try {
|
||||
const responseAsJson = await response.json();
|
||||
if (responseAsJson.image) {
|
||||
// If response has image field, response is a generated image.
|
||||
if (responseAsJson.intentType === "text-to-image") {
|
||||
rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`;
|
||||
} else if (responseAsJson.intentType === "text-to-image2") {
|
||||
rawResponse += `![${query}](${responseAsJson.image})`;
|
||||
}
|
||||
const inferredQuery = responseAsJson.inferredQueries?.[0];
|
||||
if (inferredQuery) {
|
||||
rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
|
||||
}
|
||||
}
|
||||
if (responseAsJson.context && responseAsJson.context.length > 0) {
|
||||
const rawReferenceAsJson = responseAsJson.context;
|
||||
references = createReferenceSection(rawReferenceAsJson);
|
||||
}
|
||||
if (responseAsJson.detail) {
|
||||
// If response has detail field, response is an error message.
|
||||
rawResponse += responseAsJson.detail;
|
||||
}
|
||||
} catch (error) {
|
||||
// If the chunk is not a JSON object, just display it as is
|
||||
rawResponse += chunk;
|
||||
} finally {
|
||||
newResponseText.innerHTML = "";
|
||||
newResponseText.appendChild(formatHTMLMessage(rawResponse));
|
||||
|
||||
if (references != null) {
|
||||
newResponseText.appendChild(references);
|
||||
}
|
||||
|
||||
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||
document.getElementById("chat-input").removeAttribute("disabled");
|
||||
}
|
||||
} else {
|
||||
// Handle streamed response of type text/event-stream or text/plain
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
let references = {};
|
||||
|
||||
readStream();
|
||||
|
||||
function readStream() {
|
||||
reader.read().then(({ done, value }) => {
|
||||
if (done) {
|
||||
// Append any references after all the data has been streamed
|
||||
if (references != {}) {
|
||||
newResponseText.appendChild(createReferenceSection(references));
|
||||
}
|
||||
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||
document.getElementById("chat-input").removeAttribute("disabled");
|
||||
return;
|
||||
}
|
||||
|
||||
// Decode message chunk from stream
|
||||
const chunk = decoder.decode(value, { stream: true });
|
||||
|
||||
if (chunk.includes("### compiled references:")) {
|
||||
const additionalResponse = chunk.split("### compiled references:")[0];
|
||||
rawResponse += additionalResponse;
|
||||
newResponseText.innerHTML = "";
|
||||
newResponseText.appendChild(formatHTMLMessage(rawResponse));
|
||||
|
||||
const rawReference = chunk.split("### compiled references:")[1];
|
||||
const rawReferenceAsJson = JSON.parse(rawReference);
|
||||
if (rawReferenceAsJson instanceof Array) {
|
||||
references["notes"] = rawReferenceAsJson;
|
||||
} else if (typeof rawReferenceAsJson === "object" && rawReferenceAsJson !== null) {
|
||||
references["online"] = rawReferenceAsJson;
|
||||
}
|
||||
readStream();
|
||||
} else {
|
||||
// Display response from Khoj
|
||||
if (newResponseText.getElementsByClassName("lds-ellipsis").length > 0) {
|
||||
newResponseText.removeChild(loadingEllipsis);
|
||||
}
|
||||
|
||||
// If the chunk is not a JSON object, just display it as is
|
||||
rawResponse += chunk;
|
||||
newResponseText.innerHTML = "";
|
||||
newResponseText.appendChild(formatHTMLMessage(rawResponse));
|
||||
readStream();
|
||||
}
|
||||
});
|
||||
|
||||
// Scroll to bottom of chat window as chat response is streamed
|
||||
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||
};
|
||||
function handleStreamResponse(newResponseElement, rawResponse, loadingEllipsis, replace=true) {
|
||||
if (newResponseElement.getElementsByClassName("lds-ellipsis").length > 0 && loadingEllipsis) {
|
||||
newResponseElement.removeChild(loadingEllipsis);
|
||||
}
|
||||
};
|
||||
if (replace) {
|
||||
newResponseElement.innerHTML = "";
|
||||
}
|
||||
newResponseElement.appendChild(formatHTMLMessage(rawResponse));
|
||||
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||
}
|
||||
|
||||
function handleCompiledReferences(rawResponseElement, chunk, references, rawResponse) {
|
||||
const additionalResponse = chunk.split("### compiled references:")[0];
|
||||
rawResponse += additionalResponse;
|
||||
rawResponseElement.innerHTML = "";
|
||||
rawResponseElement.appendChild(formatHTMLMessage(rawResponse));
|
||||
|
||||
const rawReference = chunk.split("### compiled references:")[1];
|
||||
const rawReferenceAsJson = JSON.parse(rawReference);
|
||||
if (rawReferenceAsJson instanceof Array) {
|
||||
references["notes"] = rawReferenceAsJson;
|
||||
} else if (typeof rawReferenceAsJson === "object" && rawReferenceAsJson !== null) {
|
||||
references["online"] = rawReferenceAsJson;
|
||||
}
|
||||
return { rawResponse, references };
|
||||
}
|
||||
|
||||
function handleImageResponse(imageJson, rawResponse) {
|
||||
if (imageJson.image) {
|
||||
const inferredQuery = imageJson.inferredQueries?.[0] ?? "generated image";
|
||||
|
||||
// If response has image field, response is a generated image.
|
||||
if (imageJson.intentType === "text-to-image") {
|
||||
rawResponse += `![generated_image](data:image/png;base64,${imageJson.image})`;
|
||||
} else if (imageJson.intentType === "text-to-image2") {
|
||||
rawResponse += `![generated_image](${imageJson.image})`;
|
||||
}
|
||||
if (inferredQuery) {
|
||||
rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
|
||||
}
|
||||
}
|
||||
let references = {};
|
||||
if (imageJson.context && imageJson.context.length > 0) {
|
||||
const rawReferenceAsJson = imageJson.context;
|
||||
if (rawReferenceAsJson instanceof Array) {
|
||||
references["notes"] = rawReferenceAsJson;
|
||||
} else if (typeof rawReferenceAsJson === "object" && rawReferenceAsJson !== null) {
|
||||
references["online"] = rawReferenceAsJson;
|
||||
}
|
||||
}
|
||||
if (imageJson.detail) {
|
||||
// If response has detail field, response is an error message.
|
||||
rawResponse += imageJson.detail;
|
||||
}
|
||||
return { rawResponse, references };
|
||||
}
|
||||
|
||||
function addMessageToChatBody(rawResponse, newResponseElement, references) {
|
||||
newResponseElement.innerHTML = "";
|
||||
newResponseElement.appendChild(formatHTMLMessage(rawResponse));
|
||||
|
||||
finalizeChatBodyResponse(references, newResponseElement);
|
||||
}
|
||||
|
||||
function finalizeChatBodyResponse(references, newResponseElement) {
|
||||
if (references != null && Object.keys(references).length > 0) {
|
||||
newResponseElement.appendChild(createReferenceSection(references));
|
||||
}
|
||||
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||
document.getElementById("chat-input").removeAttribute("disabled");
|
||||
}
|
||||
|
||||
function incrementalChat(event) {
|
||||
if (!event.shiftKey && event.key === 'Enter') {
|
||||
|
@ -798,6 +850,180 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||
|
||||
window.onload = loadChat;
|
||||
|
||||
function setupWebSocket() {
|
||||
let chatBody = document.getElementById("chat-body");
|
||||
let wsProtocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
|
||||
let webSocketUrl = `${wsProtocol}//${window.location.host}/api/chat/ws`;
|
||||
|
||||
websocketState = {
|
||||
newResponseText: null,
|
||||
newResponseElement: null,
|
||||
loadingEllipsis: null,
|
||||
references: {},
|
||||
rawResponse: "",
|
||||
}
|
||||
|
||||
function resetTimeout() {
|
||||
if (timeout) {
|
||||
clearTimeout(timeout);
|
||||
}
|
||||
|
||||
timeout = setTimeout(function() {
|
||||
if (websocket) {
|
||||
websocket.close();
|
||||
}
|
||||
}, timeoutDuration);
|
||||
}
|
||||
|
||||
if (chatBody.dataset.conversationId) {
|
||||
webSocketUrl += `?conversation_id=${chatBody.dataset.conversationId}`;
|
||||
webSocketUrl += `®ion=${region}&city=${city}&country=${countryName}`;
|
||||
|
||||
websocket = new WebSocket(webSocketUrl);
|
||||
websocket.onmessage = function(event) {
|
||||
resetTimeout();
|
||||
|
||||
// Get the last element in the chat-body
|
||||
let chunk = event.data;
|
||||
if (chunk == "start_llm_response") {
|
||||
console.log("Started streaming", new Date());
|
||||
} else if(chunk == "end_llm_response") {
|
||||
console.log("Stopped streaming", new Date());
|
||||
// Append any references after all the data has been streamed
|
||||
finalizeChatBodyResponse(websocketState.references, websocketState.newResponseText);
|
||||
|
||||
// Reset variables
|
||||
websocketState = {
|
||||
newResponseText: null,
|
||||
newResponseElement: null,
|
||||
loadingEllipsis: null,
|
||||
references: {},
|
||||
rawResponse: "",
|
||||
}
|
||||
} else {
|
||||
try {
|
||||
if (chunk.includes("application/json"))
|
||||
{
|
||||
chunk = JSON.parse(chunk);
|
||||
}
|
||||
} catch (error) {
|
||||
// If the chunk is not a JSON object, continue.
|
||||
}
|
||||
|
||||
const contentType = chunk["content-type"]
|
||||
|
||||
if (contentType === "application/json") {
|
||||
// Handle JSON response
|
||||
try {
|
||||
if (chunk.image || chunk.detail) {
|
||||
({rawResponse, references } = handleImageResponse(chunk, websocketState.rawResponse));
|
||||
websocketState.rawResponse = rawResponse;
|
||||
websocketState.references = references;
|
||||
} else if (chunk.type == "status") {
|
||||
handleStreamResponse(websocketState.newResponseText, chunk.message, null, false);
|
||||
} else {
|
||||
rawResponse = chunk.response;
|
||||
}
|
||||
} catch (error) {
|
||||
// If the chunk is not a JSON object, just display it as is
|
||||
websocketState.rawResponse += chunk;
|
||||
} finally {
|
||||
if (chunk.type != "status") {
|
||||
addMessageToChatBody(websocketState.rawResponse, websocketState.newResponseText, websocketState.references);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
||||
// Handle streamed response of type text/event-stream or text/plain
|
||||
if (chunk && chunk.includes("### compiled references:")) {
|
||||
({ rawResponse, references } = handleCompiledReferences(websocketState.newResponseText, chunk, websocketState.references, websocketState.rawResponse));
|
||||
websocketState.rawResponse = rawResponse;
|
||||
websocketState.references = references;
|
||||
} else {
|
||||
// If the chunk is not a JSON object, just display it as is
|
||||
websocketState.rawResponse += chunk;
|
||||
if (websocketState.newResponseText) {
|
||||
handleStreamResponse(websocketState.newResponseText, websocketState.rawResponse, websocketState.loadingEllipsis);
|
||||
}
|
||||
}
|
||||
|
||||
// Scroll to bottom of chat window as chat response is streamed
|
||||
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||
};
|
||||
}
|
||||
}
|
||||
};
|
||||
websocket.onclose = function(event) {
|
||||
websocket = null;
|
||||
console.log("WebSocket is closed now.");
|
||||
let greenDot = document.getElementById("connected-green-dot");
|
||||
greenDot.style.display = "none";
|
||||
}
|
||||
websocket.onerror = function(event) {
|
||||
console.log("WebSocket error observed:", event);
|
||||
}
|
||||
|
||||
websocket.onopen = function(event) {
|
||||
console.log("WebSocket is open now.")
|
||||
let greenDot = document.getElementById("connected-green-dot");
|
||||
greenDot.style.display = "flex";
|
||||
|
||||
// Setup the timeout to close the connection after inactivity.
|
||||
resetTimeout();
|
||||
}
|
||||
}
|
||||
|
||||
function sendMessageViaWebSocket(event) {
|
||||
if (event) {
|
||||
event.preventDefault();
|
||||
}
|
||||
|
||||
let chatBody = document.getElementById("chat-body");
|
||||
|
||||
var query = document.getElementById("chat-input").value.trim();
|
||||
console.log(`Query: ${query}`);
|
||||
|
||||
// Add message by user to chat body
|
||||
renderMessage(query, "you");
|
||||
document.getElementById("chat-input").value = "";
|
||||
autoResize();
|
||||
document.getElementById("chat-input").setAttribute("disabled", "disabled");
|
||||
|
||||
let newResponseElement = document.createElement("div");
|
||||
newResponseElement.classList.add("chat-message", "khoj");
|
||||
newResponseElement.attributes["data-meta"] = "🏮 Khoj at " + formatDate(new Date());
|
||||
chatBody.appendChild(newResponseElement);
|
||||
|
||||
let newResponseText = document.createElement("div");
|
||||
newResponseText.classList.add("chat-message-text", "khoj");
|
||||
newResponseElement.appendChild(newResponseText);
|
||||
|
||||
// Temporary status message to indicate that Khoj is thinking
|
||||
let loadingEllipsis = createLoadingEllipse();
|
||||
|
||||
newResponseText.appendChild(loadingEllipsis);
|
||||
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||
|
||||
let chatTooltip = document.getElementById("chat-tooltip");
|
||||
chatTooltip.style.display = "none";
|
||||
|
||||
let chatInput = document.getElementById("chat-input");
|
||||
chatInput.classList.remove("option-enabled");
|
||||
|
||||
// Call specified Khoj API
|
||||
websocket.send(query);
|
||||
let rawResponse = "";
|
||||
let references = {};
|
||||
|
||||
websocketState = {
|
||||
newResponseText,
|
||||
newResponseElement,
|
||||
loadingEllipsis,
|
||||
references,
|
||||
rawResponse,
|
||||
}
|
||||
}
|
||||
|
||||
function loadChat() {
|
||||
let chatBody = document.getElementById("chat-body");
|
||||
chatBody.innerHTML = "";
|
||||
|
@ -805,6 +1031,7 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||
let chatHistoryUrl = `/api/chat/history?client=web`;
|
||||
if (chatBody.dataset.conversationId) {
|
||||
chatHistoryUrl += `&conversation_id=${chatBody.dataset.conversationId}`;
|
||||
setupWebSocket();
|
||||
}
|
||||
|
||||
if (window.screen.width < 700) {
|
||||
|
@ -841,6 +1068,7 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||
// Render conversation history, if any
|
||||
let chatBody = document.getElementById("chat-body");
|
||||
chatBody.dataset.conversationId = response.conversation_id;
|
||||
setupWebSocket();
|
||||
chatBody.dataset.conversationTitle = response.slug || `New conversation 🌱`;
|
||||
|
||||
let agentMetadata = response.agent;
|
||||
|
@ -1323,6 +1551,10 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||
<div id="side-panel-wrapper">
|
||||
<div id="side-panel">
|
||||
<div id="new-conversation">
|
||||
<div id="connected-green-dot" style="display: none; align-items: center; margin-bottom: 10px;">
|
||||
<div style="width: 10px; height: 10px; background-color: green; border-radius: 50%; margin-right: 5px;"></div>
|
||||
<div>Connected</div>
|
||||
</div>
|
||||
<button class="side-panel-button" id="new-conversation-button" onclick="createNewConversation()">
|
||||
New Topic
|
||||
<svg class="new-convo-button" viewBox="0 0 35 35" fill="#000000" viewBox="0 0 32 32" version="1.1" xmlns="http://www.w3.org/2000/svg">
|
||||
|
|
|
@ -61,6 +61,36 @@ async def search(
|
|||
dedupe: Optional[bool] = True,
|
||||
):
|
||||
user = request.user.object
|
||||
|
||||
results = await execute_search(
|
||||
user=user,
|
||||
q=q,
|
||||
n=n,
|
||||
t=t,
|
||||
r=r,
|
||||
max_distance=max_distance,
|
||||
dedupe=dedupe,
|
||||
)
|
||||
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
telemetry_type="api",
|
||||
api="search",
|
||||
**common.__dict__,
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def execute_search(
|
||||
user: KhojUser,
|
||||
q: str,
|
||||
n: Optional[int] = 5,
|
||||
t: Optional[SearchType] = SearchType.All,
|
||||
r: Optional[bool] = False,
|
||||
max_distance: Optional[Union[float, None]] = None,
|
||||
dedupe: Optional[bool] = True,
|
||||
):
|
||||
start_time = time.time()
|
||||
|
||||
# Run validation checks
|
||||
|
@ -155,13 +185,6 @@ async def search(
|
|||
if user:
|
||||
state.query_cache[user.uuid][query_cache_key] = results
|
||||
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
telemetry_type="api",
|
||||
api="search",
|
||||
**common.__dict__,
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
logger.debug(f"🔍 Search took: {end_time - start_time:.3f} seconds")
|
||||
|
||||
|
@ -350,14 +373,14 @@ async def extract_references_and_questions(
|
|||
for query in inferred_queries:
|
||||
n_items = min(n, 3) if using_offline_chat else n
|
||||
result_list.extend(
|
||||
await search(
|
||||
await execute_search(
|
||||
user,
|
||||
f"{query} {filters_in_query}",
|
||||
request=request,
|
||||
n=n_items,
|
||||
t=SearchType.All,
|
||||
r=True,
|
||||
max_distance=d,
|
||||
dedupe=False,
|
||||
common=common,
|
||||
)
|
||||
)
|
||||
result_list = text_search.deduplicated_search_responses(result_list)
|
||||
|
|
|
@ -5,10 +5,12 @@ from typing import Dict, Optional
|
|||
from urllib.parse import unquote
|
||||
|
||||
from asgiref.sync import sync_to_async
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from fastapi import APIRouter, Depends, Request, WebSocket
|
||||
from fastapi.requests import Request
|
||||
from fastapi.responses import Response, StreamingResponse
|
||||
from starlette.authentication import requires
|
||||
from starlette.websockets import WebSocketDisconnect
|
||||
from websockets import ConnectionClosedOK
|
||||
|
||||
from khoj.database.adapters import ConversationAdapters, EntryAdapters, aget_user_name
|
||||
from khoj.database.models import KhojUser
|
||||
|
@ -242,6 +244,230 @@ async def set_conversation_title(
|
|||
)
|
||||
|
||||
|
||||
@api_chat.websocket("/ws")
|
||||
async def websocket_endpoint(
|
||||
websocket: WebSocket,
|
||||
conversation_id: int,
|
||||
city: Optional[str] = None,
|
||||
region: Optional[str] = None,
|
||||
country: Optional[str] = None,
|
||||
):
|
||||
connection_alive = True
|
||||
|
||||
async def send_status_update(message: str):
|
||||
nonlocal connection_alive
|
||||
if not connection_alive:
|
||||
return
|
||||
|
||||
status_packet = {
|
||||
"type": "status",
|
||||
"message": message,
|
||||
"content-type": "application/json",
|
||||
}
|
||||
try:
|
||||
await websocket.send_text(json.dumps(status_packet))
|
||||
except ConnectionClosedOK:
|
||||
connection_alive = False
|
||||
logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")
|
||||
|
||||
async def send_complete_llm_response(llm_response: str):
|
||||
nonlocal connection_alive
|
||||
if not connection_alive:
|
||||
return
|
||||
try:
|
||||
await websocket.send_text("start_llm_response")
|
||||
await websocket.send_text(llm_response)
|
||||
await websocket.send_text("end_llm_response")
|
||||
except ConnectionClosedOK:
|
||||
connection_alive = False
|
||||
logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")
|
||||
|
||||
async def send_message(message: str):
|
||||
nonlocal connection_alive
|
||||
if not connection_alive:
|
||||
return
|
||||
try:
|
||||
await websocket.send_text(message)
|
||||
except ConnectionClosedOK:
|
||||
connection_alive = False
|
||||
logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")
|
||||
|
||||
user: KhojUser = websocket.user.object
|
||||
conversation = await ConversationAdapters.aget_conversation_by_user(
|
||||
user, client_application=websocket.user.client_app, conversation_id=conversation_id
|
||||
)
|
||||
|
||||
hourly_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute")
|
||||
|
||||
daily_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
|
||||
|
||||
await is_ready_to_chat(user)
|
||||
|
||||
user_name = await aget_user_name(user)
|
||||
|
||||
location = None
|
||||
|
||||
if city or region or country:
|
||||
location = LocationData(city=city, region=region, country=country)
|
||||
|
||||
await websocket.accept()
|
||||
while connection_alive:
|
||||
try:
|
||||
q = await websocket.receive_text()
|
||||
except WebSocketDisconnect:
|
||||
logger.debug(f"User {user} disconnected web socket")
|
||||
break
|
||||
|
||||
await sync_to_async(hourly_limiter)(websocket)
|
||||
await sync_to_async(daily_limiter)(websocket)
|
||||
|
||||
conversation_commands = [get_conversation_command(query=q, any_references=True)]
|
||||
|
||||
await send_status_update(f"**Processing query**: {q}")
|
||||
|
||||
if conversation_commands == [ConversationCommand.Help]:
|
||||
conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
|
||||
if conversation_config == None:
|
||||
conversation_config = await ConversationAdapters.aget_default_conversation_config()
|
||||
model_type = conversation_config.model_type
|
||||
formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
|
||||
await send_complete_llm_response(formatted_help)
|
||||
continue
|
||||
|
||||
meta_log = conversation.conversation_log
|
||||
|
||||
if conversation_commands == [ConversationCommand.Default]:
|
||||
conversation_commands = await aget_relevant_information_sources(q, meta_log)
|
||||
mode = await aget_relevant_output_modes(q, meta_log)
|
||||
if mode not in conversation_commands:
|
||||
conversation_commands.append(mode)
|
||||
|
||||
for cmd in conversation_commands:
|
||||
await conversation_command_rate_limiter.update_and_check_if_valid(websocket, cmd)
|
||||
q = q.replace(f"/{cmd.value}", "").strip()
|
||||
|
||||
await send_status_update(
|
||||
f"**Using conversation commands:** {', '.join([cmd.value for cmd in conversation_commands])}"
|
||||
)
|
||||
|
||||
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
|
||||
websocket, None, meta_log, q, 7, 0.18, conversation_commands, location
|
||||
)
|
||||
|
||||
if compiled_references:
|
||||
headings = set([c.split("\n")[0] for c in compiled_references])
|
||||
await send_status_update(f"**Searching references**: {headings}")
|
||||
|
||||
online_results: Dict = dict()
|
||||
|
||||
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
|
||||
await send_complete_llm_response(f"{no_entries_found.format()}")
|
||||
continue
|
||||
|
||||
if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
|
||||
conversation_commands.remove(ConversationCommand.Notes)
|
||||
|
||||
if ConversationCommand.Online in conversation_commands:
|
||||
if not online_search_enabled():
|
||||
conversation_commands.remove(ConversationCommand.Online)
|
||||
# If online search is not enabled, try to read webpages directly
|
||||
if ConversationCommand.Webpage not in conversation_commands:
|
||||
conversation_commands.append(ConversationCommand.Webpage)
|
||||
else:
|
||||
try:
|
||||
await send_status_update("**Operation**: Searching the web for relevant information...")
|
||||
online_results = await search_online(defiltered_query, meta_log, location)
|
||||
online_searches = ", ".join([f"{query}" for query in online_results.keys()])
|
||||
await send_status_update(f"**Online searches**: {online_searches}")
|
||||
except ValueError as e:
|
||||
logger.warning(f"Error searching online: {e}. Attempting to respond without online results")
|
||||
await send_complete_llm_response(
|
||||
f"Error searching online: {e}. Attempting to respond without online results"
|
||||
)
|
||||
continue
|
||||
|
||||
if ConversationCommand.Image in conversation_commands:
|
||||
update_telemetry_state(
|
||||
request=websocket,
|
||||
telemetry_type="api",
|
||||
api="chat",
|
||||
metadata={"conversation_command": conversation_commands[0].value},
|
||||
)
|
||||
await send_status_update("**Operation**: Augmenting your query and generating a superb image...")
|
||||
intent_type = "text-to-image"
|
||||
image, status_code, improved_image_prompt, image_url = await text_to_image(
|
||||
q, user, meta_log, location_data=location, references=compiled_references, online_results=online_results
|
||||
)
|
||||
if image is None or status_code != 200:
|
||||
content_obj = {
|
||||
"image": image,
|
||||
"intentType": intent_type,
|
||||
"detail": improved_image_prompt,
|
||||
"content-type": "application/json",
|
||||
}
|
||||
await send_complete_llm_response(json.dumps(content_obj))
|
||||
continue
|
||||
|
||||
if image_url:
|
||||
intent_type = "text-to-image2"
|
||||
image = image_url
|
||||
await sync_to_async(save_to_conversation_log)(
|
||||
q,
|
||||
image,
|
||||
user,
|
||||
meta_log,
|
||||
intent_type=intent_type,
|
||||
inferred_queries=[improved_image_prompt],
|
||||
client_application=websocket.user.client_app,
|
||||
conversation_id=conversation_id,
|
||||
compiled_references=compiled_references,
|
||||
online_results=online_results,
|
||||
)
|
||||
content_obj = {"image": image, "intentType": intent_type, "inferredQueries": [improved_image_prompt], "context": compiled_references, "content-type": "application/json", "online_results": online_results} # type: ignore
|
||||
|
||||
await send_complete_llm_response(json.dumps(content_obj))
|
||||
continue
|
||||
|
||||
llm_response, chat_metadata = await agenerate_chat_response(
|
||||
defiltered_query,
|
||||
meta_log,
|
||||
conversation,
|
||||
compiled_references,
|
||||
online_results,
|
||||
inferred_queries,
|
||||
conversation_commands,
|
||||
user,
|
||||
websocket.user.client_app,
|
||||
conversation_id,
|
||||
location,
|
||||
user_name,
|
||||
)
|
||||
|
||||
chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None
|
||||
|
||||
update_telemetry_state(
|
||||
request=websocket,
|
||||
telemetry_type="api",
|
||||
api="chat",
|
||||
metadata=chat_metadata,
|
||||
)
|
||||
iterator = AsyncIteratorWrapper(llm_response)
|
||||
|
||||
await send_message("start_llm_response")
|
||||
|
||||
async for item in iterator:
|
||||
if item is None:
|
||||
break
|
||||
if connection_alive:
|
||||
try:
|
||||
await send_message(f"{item}")
|
||||
except ConnectionClosedOK:
|
||||
connection_alive = False
|
||||
logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")
|
||||
|
||||
await send_message("end_llm_response")
|
||||
|
||||
|
||||
@api_chat.get("", response_class=Response)
|
||||
@requires(["authenticated"])
|
||||
async def chat(
|
||||
|
|
Loading…
Add table
Reference in a new issue