Convert WebSocket chat endpoint into a Server-Sent Events (SSE) API endpoint

- Convert functions on the SSE API path into async generators that yield
  their results incrementally (see the sketch below)
- Validate that the image generation, online search, notes lookup and general
  chat request paths are handled correctly by the web client and server API
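A minimal sketch of the server-side pattern this moves to, assuming only FastAPI
and the newly added sse-starlette dependency (the route path, helper and event
names below are illustrative, not the actual Khoj ones): helpers become async
generators that yield intermediate events, and the endpoint wraps an outer
generator in an EventSourceResponse.

```python
# Illustrative sketch only: route path, helper and event names are hypothetical.
import asyncio

from fastapi import FastAPI, Request
from sse_starlette import EventSourceResponse

app = FastAPI()


async def answer_query(query: str):
    """Helper converted from a plain coroutine into an async generator:
    it yields progress events instead of returning a single value."""
    yield {"event": "status", "data": f"Understanding query: {query}"}
    await asyncio.sleep(0.1)  # stand-in for document search / LLM calls
    yield {"event": "message", "data": f"Answer to: {query}"}


@app.get("/demo/stream")
async def stream(request: Request, q: str):
    async def event_generator():
        async for event in answer_query(q):
            if await request.is_disconnected():
                break  # stop emitting once the client goes away
            yield event
        yield {"event": "end_llm_response", "data": ""}

    return EventSourceResponse(event_generator())
```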
Debanjum Singh Solanky 2024-07-21 12:10:13 +05:30
parent e694c82343
commit 91fe41106e
6 changed files with 577 additions and 489 deletions
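The converted helpers (search_online, read_webpages, extract_references_and_questions,
text_to_image) interleave {"status": ...} dictionaries with their final payload; the
chat endpoint forwards the statuses to the client and keeps the last non-status item
as the result. A self-contained sketch of that convention, with illustrative names:

```python
# Sketch of the status-interleaving convention; names are illustrative,
# not the exact Khoj signatures.
import asyncio
from typing import Any, AsyncGenerator, Dict


async def send_status(message: str) -> AsyncGenerator[str, None]:
    # Stand-in for partial(send_event, "status") in the real endpoint
    yield message


async def search_helper(query: str) -> AsyncGenerator[Any, None]:
    # Progress updates are wrapped in {"status": ...} dicts ...
    async for event in send_status(f"Searching for: {query}"):
        yield {"status": event}
    await asyncio.sleep(0.1)  # stand-in for the actual search
    # ... while the final payload is the last item yielded without a "status" key
    yield {"results": ["result 1", "result 2"]}


async def chat_stream(query: str) -> AsyncGenerator[Any, None]:
    final: Dict = {}
    async for item in search_helper(query):
        if isinstance(item, dict) and "status" in item:
            yield item["status"]  # forward progress updates to the SSE stream
        else:
            final = item          # capture the helper's final result
    yield f"Done: {len(final.get('results', []))} results found"


if __name__ == "__main__":
    async def main():
        async for chunk in chat_stream("python async generators"):
            print(chunk)

    asyncio.run(main())
```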

View file

@@ -40,6 +40,7 @@ dependencies = [
"dateparser >= 1.1.1",
"defusedxml == 0.7.1",
"fastapi >= 0.104.1",
"sse-starlette ~= 2.1.0",
"python-multipart >= 0.0.7",
"jinja2 == 3.1.4",
"openai >= 1.0.0",

View file

@@ -74,14 +74,14 @@ To get started, just start typing below. You can also type / to see a list of co
}, 1000);
});
}
var websocket = null;
var sseConnection = null;
let region = null;
let city = null;
let countryName = null;
let timezone = null;
let waitingForLocation = true;
let websocketState = {
let chatMessageState = {
newResponseTextEl: null,
newResponseEl: null,
loadingEllipsis: null,
@@ -105,7 +105,7 @@ To get started, just start typing below. You can also type / to see a list of co
.finally(() => {
console.debug("Region:", region, "City:", city, "Country:", countryName, "Timezone:", timezone);
waitingForLocation = false;
setupWebSocket();
initializeSSE();
});
function formatDate(date) {
@@ -599,10 +599,8 @@ To get started, just start typing below. You can also type / to see a list of co
}
async function chat(isVoice=false) {
if (websocket) {
sendMessageViaWebSocket(isVoice);
return;
}
sendMessageViaSSE(isVoice);
return;
let query = document.getElementById("chat-input").value.trim();
let resultsCount = localStorage.getItem("khojResultsCount") || 5;
@@ -1069,17 +1067,13 @@ To get started, just start typing below. You can also type / to see a list of co
window.onload = loadChat;
function setupWebSocket(isVoice=false) {
let chatBody = document.getElementById("chat-body");
let wsProtocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
let webSocketUrl = `${wsProtocol}//${window.location.host}/api/chat/ws`;
function initializeSSE(isVoice=false) {
if (waitingForLocation) {
console.debug("Waiting for location data to be fetched. Will setup WebSocket once location data is available.");
return;
}
websocketState = {
chatMessageState = {
newResponseTextEl: null,
newResponseEl: null,
loadingEllipsis: null,
@@ -1088,121 +1082,138 @@ To get started, just start typing below. You can also type / to see a list of co
rawQuery: "",
isVoice: isVoice,
}
}
function sendSSEMessage(query) {
let chatBody = document.getElementById("chat-body");
let sseProtocol = window.location.protocol;
let sseUrl = `/api/chat/stream?q=${query}`;
if (chatBody.dataset.conversationId) {
webSocketUrl += `?conversation_id=${chatBody.dataset.conversationId}`;
webSocketUrl += (!!region && !!city && !!countryName) && !!timezone ? `&region=${region}&city=${city}&country=${countryName}&timezone=${timezone}` : '';
websocket = new WebSocket(webSocketUrl);
websocket.onmessage = function(event) {
sseUrl += `&conversation_id=${chatBody.dataset.conversationId}`;
sseUrl += (!!region && !!city && !!countryName) && !!timezone ? `&region=${region}&city=${city}&country=${countryName}&timezone=${timezone}` : '';
function handleChatResponse(event) {
// Get the last element in the chat-body
let chunk = event.data;
if (chunk == "start_llm_response") {
console.log("Started streaming", new Date());
} else if (chunk == "end_llm_response") {
console.log("Stopped streaming", new Date());
try {
if (chunk.includes("application/json"))
chunk = JSON.parse(chunk);
} catch (error) {
// If the chunk is not a JSON object, continue.
}
// Automatically respond with voice if the subscribed user has sent voice message
if (websocketState.isVoice && "{{ is_active }}" == "True")
textToSpeech(websocketState.rawResponse);
// Append any references after all the data has been streamed
finalizeChatBodyResponse(websocketState.references, websocketState.newResponseTextEl);
const liveQuery = websocketState.rawQuery;
// Reset variables
websocketState = {
newResponseTextEl: null,
newResponseEl: null,
loadingEllipsis: null,
references: {},
rawResponse: "",
rawQuery: liveQuery,
isVoice: false,
}
} else {
const contentType = chunk["content-type"]
if (contentType === "application/json") {
// Handle JSON response
try {
if (chunk.includes("application/json"))
{
chunk = JSON.parse(chunk);
if (chunk.image || chunk.detail) {
({rawResponse, references } = handleImageResponse(chunk, chatMessageState.rawResponse));
chatMessageState.rawResponse = rawResponse;
chatMessageState.references = references;
} else {
rawResponse = chunk.response;
}
} catch (error) {
// If the chunk is not a JSON object, continue.
// If the chunk is not a JSON object, just display it as is
chatMessageState.rawResponse += chunk;
} finally {
addMessageToChatBody(chatMessageState.rawResponse, chatMessageState.newResponseTextEl, chatMessageState.references);
}
} else {
// Handle streamed response of type text/event-stream or text/plain
if (chunk && chunk.includes("### compiled references:")) {
({ rawResponse, references } = handleCompiledReferences(chatMessageState.newResponseTextEl, chunk, chatMessageState.references, chatMessageState.rawResponse));
chatMessageState.rawResponse = rawResponse;
chatMessageState.references = references;
} else {
// If the chunk is not a JSON object, just display it as is
chatMessageState.rawResponse += chunk;
if (chatMessageState.newResponseTextEl) {
handleStreamResponse(chatMessageState.newResponseTextEl, chatMessageState.rawResponse, chatMessageState.rawQuery, chatMessageState.loadingEllipsis);
}
}
const contentType = chunk["content-type"]
if (contentType === "application/json") {
// Handle JSON response
try {
if (chunk.image || chunk.detail) {
({rawResponse, references } = handleImageResponse(chunk, websocketState.rawResponse));
websocketState.rawResponse = rawResponse;
websocketState.references = references;
} else if (chunk.type == "status") {
handleStreamResponse(websocketState.newResponseTextEl, chunk.message, websocketState.rawQuery, null, false);
} else if (chunk.type == "rate_limit") {
handleStreamResponse(websocketState.newResponseTextEl, chunk.message, websocketState.rawQuery, websocketState.loadingEllipsis, true);
} else {
rawResponse = chunk.response;
}
} catch (error) {
// If the chunk is not a JSON object, just display it as is
websocketState.rawResponse += chunk;
} finally {
if (chunk.type != "status" && chunk.type != "rate_limit") {
addMessageToChatBody(websocketState.rawResponse, websocketState.newResponseTextEl, websocketState.references);
}
}
} else {
// Handle streamed response of type text/event-stream or text/plain
if (chunk && chunk.includes("### compiled references:")) {
({ rawResponse, references } = handleCompiledReferences(websocketState.newResponseTextEl, chunk, websocketState.references, websocketState.rawResponse));
websocketState.rawResponse = rawResponse;
websocketState.references = references;
} else {
// If the chunk is not a JSON object, just display it as is
websocketState.rawResponse += chunk;
if (websocketState.newResponseTextEl) {
handleStreamResponse(websocketState.newResponseTextEl, websocketState.rawResponse, websocketState.rawQuery, websocketState.loadingEllipsis);
}
}
// Scroll to bottom of chat window as chat response is streamed
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
};
}
// Scroll to bottom of chat window as chat response is streamed
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
};
}
};
websocket.onclose = function(event) {
websocket = null;
console.log("WebSocket is closed now.");
let setupWebSocketButton = document.createElement("button");
setupWebSocketButton.textContent = "Reconnect to Server";
setupWebSocketButton.onclick = setupWebSocket;
let statusDotIcon = document.getElementById("connection-status-icon");
statusDotIcon.style.backgroundColor = "red";
let statusDotText = document.getElementById("connection-status-text");
statusDotText.innerHTML = "";
statusDotText.style.marginTop = "5px";
statusDotText.appendChild(setupWebSocketButton);
}
websocket.onerror = function(event) {
console.log("WebSocket error observed:", event);
}
websocket.onopen = function(event) {
console.log("WebSocket is open now.")
sseConnection = new EventSource(sseUrl);
sseConnection.onmessage = handleChatResponse;
sseConnection.addEventListener("complete_llm_response", handleChatResponse);
sseConnection.addEventListener("status", (event) => {
console.log(`${event.data}`);
handleStreamResponse(chatMessageState.newResponseTextEl, event.data, chatMessageState.rawQuery, null, false);
});
sseConnection.addEventListener("rate_limit", (event) => {
handleStreamResponse(chatMessageState.newResponseTextEl, event.data, chatMessageState.rawQuery, chatMessageState.loadingEllipsis, true);
});
sseConnection.addEventListener("start_llm_response", (event) => {
console.log("Started streaming", new Date());
});
sseConnection.addEventListener("end_llm_response", (event) => {
sseConnection.close();
console.log("Stopped streaming", new Date());
// Automatically respond with voice if the subscribed user has sent voice message
if (chatMessageState.isVoice && "{{ is_active }}" == "True")
textToSpeech(chatMessageState.rawResponse);
// Append any references after all the data has been streamed
finalizeChatBodyResponse(chatMessageState.references, chatMessageState.newResponseTextEl);
const liveQuery = chatMessageState.rawQuery;
// Reset variables
chatMessageState = {
newResponseTextEl: null,
newResponseEl: null,
loadingEllipsis: null,
references: {},
rawResponse: "",
rawQuery: liveQuery,
}
// Reset status icon
let statusDotIcon = document.getElementById("connection-status-icon");
statusDotIcon.style.backgroundColor = "green";
let statusDotText = document.getElementById("connection-status-text");
statusDotText.textContent = "Connected to Server";
statusDotText.textContent = "Ready";
statusDotText.style.marginTop = "5px";
});
sseConnection.onclose = function(event) {
sseConnection = null;
console.debug("SSE is closed now.");
let statusDotIcon = document.getElementById("connection-status-icon");
statusDotIcon.style.backgroundColor = "green";
let statusDotText = document.getElementById("connection-status-text");
statusDotText.textContent = "Ready";
statusDotText.style.marginTop = "5px";
}
sseConnection.onerror = function(event) {
console.log("SSE error observed:", event);
sseConnection.close();
sseConnection = null;
let statusDotIcon = document.getElementById("connection-status-icon");
statusDotIcon.style.backgroundColor = "red";
let statusDotText = document.getElementById("connection-status-text");
statusDotText.textContent = "Server Error";
if (chatMessageState.newResponseEl.getElementsByClassName("lds-ellipsis").length > 0 && chatMessageState.loadingEllipsis) {
chatMessageState.newResponseTextEl.removeChild(chatMessageState.loadingEllipsis);
}
chatMessageState.newResponseTextEl.textContent += "Failed to get response! Try again or contact developers at team@khoj.dev"
}
sseConnection.onopen = function(event) {
console.debug("SSE is open now.")
let statusDotIcon = document.getElementById("connection-status-icon");
statusDotIcon.style.backgroundColor = "orange";
let statusDotText = document.getElementById("connection-status-text");
statusDotText.textContent = "Processing";
}
}
function sendMessageViaWebSocket(isVoice=false) {
function sendMessageViaSSE(isVoice=false) {
let chatBody = document.getElementById("chat-body");
var query = document.getElementById("chat-input").value.trim();
@@ -1242,11 +1253,11 @@ To get started, just start typing below. You can also type / to see a list of co
chatInput.classList.remove("option-enabled");
// Call specified Khoj API
websocket.send(query);
sendSSEMessage(query);
let rawResponse = "";
let references = {};
websocketState = {
chatMessageState = {
newResponseTextEl,
newResponseEl,
loadingEllipsis,
@@ -1265,7 +1276,7 @@ To get started, just start typing below. You can also type / to see a list of co
let chatHistoryUrl = `/api/chat/history?client=web`;
if (chatBody.dataset.conversationId) {
chatHistoryUrl += `&conversation_id=${chatBody.dataset.conversationId}`;
setupWebSocket();
initializeSSE();
loadFileFiltersFromConversation();
}
@@ -1305,7 +1316,7 @@ To get started, just start typing below. You can also type / to see a list of co
let chatBody = document.getElementById("chat-body");
chatBody.dataset.conversationId = response.conversation_id;
loadFileFiltersFromConversation();
setupWebSocket();
initializeSSE();
chatBody.dataset.conversationTitle = response.slug || `New conversation 🌱`;
let agentMetadata = response.agent;

View file

@@ -56,7 +56,8 @@ async def search_online(
query += " ".join(custom_filters)
if not is_internet_connected():
logger.warn("Cannot search online as not connected to internet")
return {}
yield {}
return
# Breakdown the query into subqueries to get the correct answer
subqueries = await generate_online_subqueries(query, conversation_history, location)
@@ -66,7 +67,8 @@ async def search_online(
logger.info(f"🌐 Searching the Internet for {list(subqueries)}")
if send_status_func:
subqueries_str = "\n- " + "\n- ".join(list(subqueries))
await send_status_func(f"**🌐 Searching the Internet for**: {subqueries_str}")
async for event in send_status_func(f"**🌐 Searching the Internet for**: {subqueries_str}"):
yield {"status": event}
with timer(f"Internet searches for {list(subqueries)} took", logger):
search_func = search_with_google if SERPER_DEV_API_KEY else search_with_jina
@@ -89,7 +91,8 @@ async def search_online(
logger.info(f"🌐👀 Reading web pages at: {list(webpage_links)}")
if send_status_func:
webpage_links_str = "\n- " + "\n- ".join(list(webpage_links))
await send_status_func(f"**📖 Reading web pages**: {webpage_links_str}")
async for event in send_status_func(f"**📖 Reading web pages**: {webpage_links_str}"):
yield {"status": event}
tasks = [read_webpage_and_extract_content(subquery, link, content) for link, subquery, content in webpages]
results = await asyncio.gather(*tasks)
@@ -98,7 +101,7 @@ async def search_online(
if webpage_extract is not None:
response_dict[subquery]["webpages"] = {"link": url, "snippet": webpage_extract}
return response_dict
yield response_dict
async def search_with_google(query: str) -> Tuple[str, Dict[str, List[Dict]]]:
@@ -127,13 +130,15 @@ async def read_webpages(
"Infer web pages to read from the query and extract relevant information from them"
logger.info(f"Inferring web pages to read")
if send_status_func:
await send_status_func(f"**🧐 Inferring web pages to read**")
async for event in send_status_func(f"**🧐 Inferring web pages to read**"):
yield {"status": event}
urls = await infer_webpage_urls(query, conversation_history, location)
logger.info(f"Reading web pages at: {urls}")
if send_status_func:
webpage_links_str = "\n- " + "\n- ".join(list(urls))
await send_status_func(f"**📖 Reading web pages**: {webpage_links_str}")
async for event in send_status_func(f"**📖 Reading web pages**: {webpage_links_str}"):
yield {"status": event}
tasks = [read_webpage_and_extract_content(query, url) for url in urls]
results = await asyncio.gather(*tasks)
@@ -141,7 +146,7 @@ async def read_webpages(
response[query]["webpages"] = [
{"query": q, "link": url, "snippet": web_extract} for q, web_extract, url in results if web_extract is not None
]
return response
yield response
async def read_webpage_and_extract_content(

View file

@@ -6,7 +6,6 @@ import os
import threading
import time
import uuid
from random import random
from typing import Any, Callable, List, Optional, Union
import cron_descriptor
@@ -298,11 +297,13 @@ async def extract_references_and_questions(
not ConversationCommand.Notes in conversation_commands
and not ConversationCommand.Default in conversation_commands
):
return compiled_references, inferred_queries, q
yield compiled_references, inferred_queries, q
return
if not await sync_to_async(EntryAdapters.user_has_entries)(user=user):
logger.debug("No documents in knowledge base. Use a Khoj client to sync and chat with your docs.")
return compiled_references, inferred_queries, q
yield compiled_references, inferred_queries, q
return
# Extract filter terms from user message
defiltered_query = q
@@ -313,7 +314,8 @@ async def extract_references_and_questions(
if not conversation:
logger.error(f"Conversation with id {conversation_id} not found.")
return compiled_references, inferred_queries, defiltered_query
yield compiled_references, inferred_queries, defiltered_query
return
filters_in_query += " ".join([f'file:"{filter}"' for filter in conversation.file_filters])
using_offline_chat = False
@@ -372,7 +374,8 @@ async def extract_references_and_questions(
logger.info(f"🔍 Searching knowledge base with queries: {inferred_queries}")
if send_status_func:
inferred_queries_str = "\n- " + "\n- ".join(inferred_queries)
await send_status_func(f"**🔍 Searching Documents for:** {inferred_queries_str}")
async for event in send_status_func(f"**🔍 Searching Documents for:** {inferred_queries_str}"):
yield {"status": event}
for query in inferred_queries:
n_items = min(n, 3) if using_offline_chat else n
search_results.extend(
@@ -391,7 +394,7 @@ async def extract_references_and_questions(
{"compiled": item.additional["compiled"], "file": item.additional["file"]} for item in search_results
]
return compiled_references, inferred_queries, defiltered_query
yield compiled_references, inferred_queries, defiltered_query
@api.get("/health", response_class=Response)

View file

@@ -1,17 +1,18 @@
import asyncio
import json
import logging
import math
from datetime import datetime
from functools import partial
from typing import Any, Dict, List, Optional
from urllib.parse import unquote
from asgiref.sync import sync_to_async
from fastapi import APIRouter, Depends, HTTPException, Request, WebSocket
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.requests import Request
from fastapi.responses import Response, StreamingResponse
from sse_starlette import EventSourceResponse
from starlette.authentication import requires
from starlette.websockets import WebSocketDisconnect
from websockets import ConnectionClosedOK
from khoj.app.settings import ALLOWED_HOSTS
from khoj.database.adapters import (
@@ -526,380 +527,441 @@ async def set_conversation_title(
)
@api_chat.websocket("/ws")
async def websocket_endpoint(
websocket: WebSocket,
@api_chat.get("/stream")
async def stream_chat(
request: Request,
q: str,
conversation_id: int,
city: Optional[str] = None,
region: Optional[str] = None,
country: Optional[str] = None,
timezone: Optional[str] = None,
):
connection_alive = True
async def event_generator(q: str):
connection_alive = True
async def send_status_update(message: str):
nonlocal connection_alive
if not connection_alive:
return
status_packet = {
"type": "status",
"message": message,
"content-type": "application/json",
}
try:
await websocket.send_text(json.dumps(status_packet))
except ConnectionClosedOK:
connection_alive = False
logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")
async def send_complete_llm_response(llm_response: str):
nonlocal connection_alive
if not connection_alive:
return
try:
await websocket.send_text("start_llm_response")
await websocket.send_text(llm_response)
await websocket.send_text("end_llm_response")
except ConnectionClosedOK:
connection_alive = False
logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")
async def send_message(message: str):
nonlocal connection_alive
if not connection_alive:
return
try:
await websocket.send_text(message)
except ConnectionClosedOK:
connection_alive = False
logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")
async def send_rate_limit_message(message: str):
nonlocal connection_alive
if not connection_alive:
return
status_packet = {
"type": "rate_limit",
"message": message,
"content-type": "application/json",
}
try:
await websocket.send_text(json.dumps(status_packet))
except ConnectionClosedOK:
connection_alive = False
logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")
user: KhojUser = websocket.user.object
conversation = await ConversationAdapters.aget_conversation_by_user(
user, client_application=websocket.user.client_app, conversation_id=conversation_id
)
hourly_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute")
daily_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
await is_ready_to_chat(user)
user_name = await aget_user_name(user)
location = None
if city or region or country:
location = LocationData(city=city, region=region, country=country)
await websocket.accept()
while connection_alive:
try:
if conversation:
await sync_to_async(conversation.refresh_from_db)(fields=["conversation_log"])
q = await websocket.receive_text()
# Refresh these because the connection to the database might have been closed
await conversation.arefresh_from_db()
except WebSocketDisconnect:
logger.debug(f"User {user} disconnected web socket")
break
try:
await sync_to_async(hourly_limiter)(websocket)
await sync_to_async(daily_limiter)(websocket)
except HTTPException as e:
await send_rate_limit_message(e.detail)
break
if is_query_empty(q):
await send_message("start_llm_response")
await send_message(
"It seems like your query is incomplete. Could you please provide more details or specify what you need help with?"
)
await send_message("end_llm_response")
continue
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
conversation_commands = [get_conversation_command(query=q, any_references=True)]
await send_status_update(f"**👀 Understanding Query**: {q}")
meta_log = conversation.conversation_log
is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask]
used_slash_summarize = conversation_commands == [ConversationCommand.Summarize]
if conversation_commands == [ConversationCommand.Default] or is_automated_task:
conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task)
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
await send_status_update(f"**🗃️ Chose Data Sources to Search:** {conversation_commands_str}")
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task)
await send_status_update(f"**🧑🏾‍💻 Decided Response Mode:** {mode.value}")
if mode not in conversation_commands:
conversation_commands.append(mode)
for cmd in conversation_commands:
await conversation_command_rate_limiter.update_and_check_if_valid(websocket, cmd)
q = q.replace(f"/{cmd.value}", "").strip()
file_filters = conversation.file_filters if conversation else []
# Skip trying to summarize if
if (
# summarization intent was inferred
ConversationCommand.Summarize in conversation_commands
# and not triggered via slash command
and not used_slash_summarize
# but we can't actually summarize
and len(file_filters) != 1
):
conversation_commands.remove(ConversationCommand.Summarize)
elif ConversationCommand.Summarize in conversation_commands:
response_log = ""
if len(file_filters) == 0:
response_log = "No files selected for summarization. Please add files using the section on the left."
await send_complete_llm_response(response_log)
elif len(file_filters) > 1:
response_log = "Only one file can be selected for summarization."
await send_complete_llm_response(response_log)
else:
try:
file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
if len(file_object) == 0:
response_log = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again."
await send_complete_llm_response(response_log)
continue
contextual_data = " ".join([file.raw_text for file in file_object])
if not q:
q = "Create a general summary of the file"
await send_status_update(f"**🧑🏾‍💻 Constructing Summary Using:** {file_object[0].file_name}")
response = await extract_relevant_summary(q, contextual_data)
response_log = str(response)
await send_complete_llm_response(response_log)
except Exception as e:
response_log = "Error summarizing file."
logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)
await send_complete_llm_response(response_log)
await sync_to_async(save_to_conversation_log)(
q,
response_log,
user,
meta_log,
user_message_time,
intent_type="summarize",
client_application=websocket.user.client_app,
conversation_id=conversation_id,
)
update_telemetry_state(
request=websocket,
telemetry_type="api",
api="chat",
metadata={"conversation_command": conversation_commands[0].value},
)
continue
custom_filters = []
if conversation_commands == [ConversationCommand.Help]:
if not q:
conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
if conversation_config == None:
conversation_config = await ConversationAdapters.aget_default_conversation_config()
model_type = conversation_config.model_type
formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
await send_complete_llm_response(formatted_help)
continue
# Adding specification to search online specifically on khoj.dev pages.
custom_filters.append("site:khoj.dev")
conversation_commands.append(ConversationCommand.Online)
if ConversationCommand.Automation in conversation_commands:
async def send_event(event_type: str, data: str):
nonlocal connection_alive
if not connection_alive or await request.is_disconnected():
return
try:
automation, crontime, query_to_run, subject = await create_automation(
q, timezone, user, websocket.url, meta_log
)
if event_type == "message":
yield data
else:
yield {"event": event_type, "data": data, "retry": 15000}
except Exception as e:
logger.error(f"Error scheduling task {q} for {user.email}: {e}")
await send_complete_llm_response(
f"Unable to create automation. Ensure the automation doesn't already exist."
)
continue
connection_alive = False
logger.info(f"User {user} disconnected SSE. Emitting rest of responses to clear thread: {e}")
llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject)
await sync_to_async(save_to_conversation_log)(
q,
llm_response,
user,
meta_log,
user_message_time,
intent_type="automation",
client_application=websocket.user.client_app,
conversation_id=conversation_id,
inferred_queries=[query_to_run],
automation_id=automation.id,
)
common = CommonQueryParamsClass(
client=websocket.user.client_app,
user_agent=websocket.headers.get("user-agent"),
host=websocket.headers.get("host"),
)
update_telemetry_state(
request=websocket,
telemetry_type="api",
api="chat",
**common.__dict__,
)
await send_complete_llm_response(llm_response)
continue
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
websocket, meta_log, q, 7, 0.18, conversation_id, conversation_commands, location, send_status_update
user: KhojUser = request.user.object
conversation = await ConversationAdapters.aget_conversation_by_user(
user, client_application=request.user.client_app, conversation_id=conversation_id
)
if compiled_references:
headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references]))
await send_status_update(f"**📜 Found Relevant Notes**: {headings}")
hourly_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute")
online_results: Dict = dict()
daily_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
await send_complete_llm_response(f"{no_entries_found.format()}")
continue
await is_ready_to_chat(user)
if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
conversation_commands.remove(ConversationCommand.Notes)
user_name = await aget_user_name(user)
if ConversationCommand.Online in conversation_commands:
location = None
if city or region or country:
location = LocationData(city=city, region=region, country=country)
while connection_alive:
try:
online_results = await search_online(
defiltered_query, meta_log, location, send_status_update, custom_filters
)
except ValueError as e:
logger.warning(f"Error searching online: {e}. Attempting to respond without online results")
await send_complete_llm_response(
f"Error searching online: {e}. Attempting to respond without online results"
)
continue
if conversation:
await sync_to_async(conversation.refresh_from_db)(fields=["conversation_log"])
if ConversationCommand.Webpage in conversation_commands:
try:
direct_web_pages = await read_webpages(defiltered_query, meta_log, location, send_status_update)
webpages = []
for query in direct_web_pages:
if online_results.get(query):
online_results[query]["webpages"] = direct_web_pages[query]["webpages"]
else:
online_results[query] = {"webpages": direct_web_pages[query]["webpages"]}
# Refresh these because the connection to the database might have been closed
await conversation.arefresh_from_db()
for webpage in direct_web_pages[query]["webpages"]:
webpages.append(webpage["link"])
await send_status_update(f"**📚 Read web pages**: {webpages}")
except ValueError as e:
logger.warning(
f"Error directly reading webpages: {e}. Attempting to respond without online results", exc_info=True
)
if ConversationCommand.Image in conversation_commands:
update_telemetry_state(
request=websocket,
telemetry_type="api",
api="chat",
metadata={"conversation_command": conversation_commands[0].value},
)
image, status_code, improved_image_prompt, intent_type = await text_to_image(
q,
user,
meta_log,
location_data=location,
references=compiled_references,
online_results=online_results,
send_status_func=send_status_update,
)
if image is None or status_code != 200:
content_obj = {
"image": image,
"intentType": intent_type,
"detail": improved_image_prompt,
"content-type": "application/json",
}
await send_complete_llm_response(json.dumps(content_obj))
continue
await sync_to_async(save_to_conversation_log)(
q,
image,
user,
meta_log,
user_message_time,
intent_type=intent_type,
inferred_queries=[improved_image_prompt],
client_application=websocket.user.client_app,
conversation_id=conversation_id,
compiled_references=compiled_references,
online_results=online_results,
)
content_obj = {"image": image, "intentType": intent_type, "inferredQueries": [improved_image_prompt], "context": compiled_references, "content-type": "application/json", "online_results": online_results} # type: ignore
await send_complete_llm_response(json.dumps(content_obj))
continue
await send_status_update(f"**💭 Generating a well-informed response**")
llm_response, chat_metadata = await agenerate_chat_response(
defiltered_query,
meta_log,
conversation,
compiled_references,
online_results,
inferred_queries,
conversation_commands,
user,
websocket.user.client_app,
conversation_id,
location,
user_name,
)
chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None
update_telemetry_state(
request=websocket,
telemetry_type="api",
api="chat",
metadata=chat_metadata,
)
iterator = AsyncIteratorWrapper(llm_response)
await send_message("start_llm_response")
async for item in iterator:
if item is None:
break
if connection_alive:
try:
await send_message(f"{item}")
except ConnectionClosedOK:
connection_alive = False
logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")
await sync_to_async(hourly_limiter)(request)
await sync_to_async(daily_limiter)(request)
except HTTPException as e:
async for result in send_event("rate_limit", e.detail):
yield result
break
await send_message("end_llm_response")
if is_query_empty(q):
async for event in send_event("start_llm_response", ""):
yield event
async for event in send_event(
"message",
"It seems like your query is incomplete. Could you please provide more details or specify what you need help with?",
):
yield event
async for event in send_event("end_llm_response", ""):
yield event
return
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
conversation_commands = [get_conversation_command(query=q, any_references=True)]
async for result in send_event("status", f"**👀 Understanding Query**: {q}"):
yield result
meta_log = conversation.conversation_log
is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask]
used_slash_summarize = conversation_commands == [ConversationCommand.Summarize]
if conversation_commands == [ConversationCommand.Default] or is_automated_task:
conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task)
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
async for result in send_event(
"status", f"**🗃️ Chose Data Sources to Search:** {conversation_commands_str}"
):
yield result
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task)
async for result in send_event("status", f"**🧑🏾‍💻 Decided Response Mode:** {mode.value}"):
yield result
if mode not in conversation_commands:
conversation_commands.append(mode)
for cmd in conversation_commands:
await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
q = q.replace(f"/{cmd.value}", "").strip()
file_filters = conversation.file_filters if conversation else []
# Skip trying to summarize if
if (
# summarization intent was inferred
ConversationCommand.Summarize in conversation_commands
# and not triggered via slash command
and not used_slash_summarize
# but we can't actually summarize
and len(file_filters) != 1
):
conversation_commands.remove(ConversationCommand.Summarize)
elif ConversationCommand.Summarize in conversation_commands:
response_log = ""
if len(file_filters) == 0:
response_log = (
"No files selected for summarization. Please add files using the section on the left."
)
async for result in send_event("complete_llm_response", response_log):
yield result
async for event in send_event("end_llm_response", ""):
yield event
elif len(file_filters) > 1:
response_log = "Only one file can be selected for summarization."
async for result in send_event("complete_llm_response", response_log):
yield result
async for event in send_event("end_llm_response", ""):
yield event
else:
try:
file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
if len(file_object) == 0:
response_log = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again."
async for result in send_event("complete_llm_response", response_log):
yield result
async for event in send_event("end_llm_response", ""):
yield event
return
contextual_data = " ".join([file.raw_text for file in file_object])
if not q:
q = "Create a general summary of the file"
async for result in send_event(
"status", f"**🧑🏾‍💻 Constructing Summary Using:** {file_object[0].file_name}"
):
yield result
response = await extract_relevant_summary(q, contextual_data)
response_log = str(response)
async for result in send_event("complete_llm_response", response_log):
yield result
async for event in send_event("end_llm_response", ""):
yield event
except Exception as e:
response_log = "Error summarizing file."
logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)
async for result in send_event("complete_llm_response", response_log):
yield result
async for event in send_event("end_llm_response", ""):
yield event
await sync_to_async(save_to_conversation_log)(
q,
response_log,
user,
meta_log,
user_message_time,
intent_type="summarize",
client_application=request.user.client_app,
conversation_id=conversation_id,
)
update_telemetry_state(
request=request,
telemetry_type="api",
api="chat",
metadata={"conversation_command": conversation_commands[0].value},
)
return
custom_filters = []
if conversation_commands == [ConversationCommand.Help]:
if not q:
conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
if conversation_config == None:
conversation_config = await ConversationAdapters.aget_default_conversation_config()
model_type = conversation_config.model_type
formatted_help = help_message.format(
model=model_type, version=state.khoj_version, device=get_device()
)
async for result in send_event("complete_llm_response", formatted_help):
yield result
async for event in send_event("end_llm_response", ""):
yield event
return
custom_filters.append("site:khoj.dev")
conversation_commands.append(ConversationCommand.Online)
if ConversationCommand.Automation in conversation_commands:
try:
automation, crontime, query_to_run, subject = await create_automation(
q, timezone, user, request.url, meta_log
)
except Exception as e:
logger.error(f"Error scheduling task {q} for {user.email}: {e}")
error_message = f"Unable to create automation. Ensure the automation doesn't already exist."
async for result in send_event("complete_llm_response", error_message):
yield result
async for event in send_event("end_llm_response", ""):
yield event
return
llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject)
await sync_to_async(save_to_conversation_log)(
q,
llm_response,
user,
meta_log,
user_message_time,
intent_type="automation",
client_application=request.user.client_app,
conversation_id=conversation_id,
inferred_queries=[query_to_run],
automation_id=automation.id,
)
common = CommonQueryParamsClass(
client=request.user.client_app,
user_agent=request.headers.get("user-agent"),
host=request.headers.get("host"),
)
update_telemetry_state(
request=request,
telemetry_type="api",
api="chat",
**common.__dict__,
)
async for result in send_event("complete_llm_response", llm_response):
yield result
async for event in send_event("end_llm_response", ""):
yield event
return
compiled_references, inferred_queries, defiltered_query = [], [], None
async for result in extract_references_and_questions(
request,
meta_log,
q,
7,
0.18,
conversation_id,
conversation_commands,
location,
partial(send_event, "status"),
):
if isinstance(result, dict) and "status" in result:
yield result["status"]
else:
compiled_references.extend(result[0])
inferred_queries.extend(result[1])
defiltered_query = result[2]
if not is_none_or_empty(compiled_references):
headings = "\n- " + "\n- ".join(
set([c.get("compiled", c).split("\n")[0] for c in compiled_references])
)
async for result in send_event("status", f"**📜 Found Relevant Notes**: {headings}"):
yield result
online_results: Dict = dict()
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(
user
):
async for result in send_event("complete_llm_response", f"{no_entries_found.format()}"):
yield result
async for event in send_event("end_llm_response", ""):
yield event
return
if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
conversation_commands.remove(ConversationCommand.Notes)
if ConversationCommand.Online in conversation_commands:
try:
async for result in search_online(
defiltered_query, meta_log, location, partial(send_event, "status"), custom_filters
):
if isinstance(result, dict) and "status" in result:
yield result["status"]
else:
online_results = result
except ValueError as e:
error_message = f"Error searching online: {e}. Attempting to respond without online results"
logger.warning(error_message)
async for result in send_event("complete_llm_response", error_message):
yield result
async for event in send_event("end_llm_response", ""):
yield event
return
if ConversationCommand.Webpage in conversation_commands:
try:
async for result in read_webpages(
defiltered_query, meta_log, location, partial(send_event, "status")
):
if isinstance(result, dict) and "status" in result:
yield result["status"]
else:
direct_web_pages = result
webpages = []
for query in direct_web_pages:
if online_results.get(query):
online_results[query]["webpages"] = direct_web_pages[query]["webpages"]
else:
online_results[query] = {"webpages": direct_web_pages[query]["webpages"]}
for webpage in direct_web_pages[query]["webpages"]:
webpages.append(webpage["link"])
async for result in send_event("status", f"**📚 Read web pages**: {webpages}"):
yield result
except ValueError as e:
logger.warning(
f"Error directly reading webpages: {e}. Attempting to respond without online results",
exc_info=True,
)
if ConversationCommand.Image in conversation_commands:
update_telemetry_state(
request=request,
telemetry_type="api",
api="chat",
metadata={"conversation_command": conversation_commands[0].value},
)
async for result in text_to_image(
q,
user,
meta_log,
location_data=location,
references=compiled_references,
online_results=online_results,
send_status_func=partial(send_event, "status"),
):
if isinstance(result, dict) and "status" in result:
yield result["status"]
else:
image, status_code, improved_image_prompt, intent_type = result
if image is None or status_code != 200:
content_obj = {
"image": image,
"intentType": intent_type,
"detail": improved_image_prompt,
"content-type": "application/json",
}
async for result in send_event("complete_llm_response", json.dumps(content_obj)):
yield result
async for event in send_event("end_llm_response", ""):
yield event
return
await sync_to_async(save_to_conversation_log)(
q,
image,
user,
meta_log,
user_message_time,
intent_type=intent_type,
inferred_queries=[improved_image_prompt],
client_application=request.user.client_app,
conversation_id=conversation_id,
compiled_references=compiled_references,
online_results=online_results,
)
content_obj = {
"image": image,
"intentType": intent_type,
"inferredQueries": [improved_image_prompt],
"context": compiled_references,
"content-type": "application/json",
"online_results": online_results,
}
async for result in send_event("complete_llm_response", json.dumps(content_obj)):
yield result
async for event in send_event("end_llm_response", ""):
yield event
return
async for result in send_event("status", f"**💭 Generating a well-informed response**"):
yield result
llm_response, chat_metadata = await agenerate_chat_response(
defiltered_query,
meta_log,
conversation,
compiled_references,
online_results,
inferred_queries,
conversation_commands,
user,
request.user.client_app,
conversation_id,
location,
user_name,
)
chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None
update_telemetry_state(
request=request,
telemetry_type="api",
api="chat",
metadata=chat_metadata,
)
iterator = AsyncIteratorWrapper(llm_response)
async for result in send_event("start_llm_response", ""):
yield result
async for item in iterator:
if item is None:
break
if connection_alive:
try:
async for result in send_event("message", f"{item}"):
yield result
except Exception as e:
connection_alive = False
logger.info(
f"User {user} disconnected SSE. Emitting rest of responses to clear thread: {e}"
)
async for result in send_event("end_llm_response", ""):
yield result
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Error in SSE endpoint: {e}", exc_info=True)
break
return EventSourceResponse(event_generator(q))
@api_chat.get("", response_class=Response)

View file

@@ -755,7 +755,7 @@ async def text_to_image(
references: List[Dict[str, Any]],
online_results: Dict[str, Any],
send_status_func: Optional[Callable] = None,
) -> Tuple[Optional[str], int, Optional[str], str]:
):
status_code = 200
image = None
response = None
@@ -767,7 +767,8 @@ async def text_to_image(
# If the user has not configured a text to image model, return an unsupported on server error
status_code = 501
message = "Failed to generate image. Setup image generation on the server."
return image_url or image, status_code, message, intent_type.value
yield image_url or image, status_code, message, intent_type.value
return
text2image_model = text_to_image_config.model_name
chat_history = ""
@@ -781,7 +782,8 @@ async def text_to_image(
with timer("Improve the original user query", logger):
if send_status_func:
await send_status_func("**✍🏽 Enhancing the Painting Prompt**")
async for event in send_status_func("**✍🏽 Enhancing the Painting Prompt**"):
yield {"status": event}
improved_image_prompt = await generate_better_image_prompt(
message,
chat_history,
@@ -792,7 +794,8 @@ async def text_to_image(
)
if send_status_func:
await send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}")
async for event in send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}"):
yield {"status": event}
if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
with timer("Generate image with OpenAI", logger):
@@ -817,12 +820,14 @@ async def text_to_image(
logger.error(f"Image Generation blocked by OpenAI: {e}")
status_code = e.status_code # type: ignore
message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore
return image_url or image, status_code, message, intent_type.value
yield image_url or image, status_code, message, intent_type.value
return
else:
logger.error(f"Image Generation failed with {e}", exc_info=True)
message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore
status_code = e.status_code # type: ignore
return image_url or image, status_code, message, intent_type.value
yield image_url or image, status_code, message, intent_type.value
return
elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.STABILITYAI:
with timer("Generate image with Stability AI", logger):
@@ -844,7 +849,8 @@ async def text_to_image(
logger.error(f"Image Generation failed with {e}", exc_info=True)
message = f"Image generation failed with Stability AI error: {e}"
status_code = e.status_code # type: ignore
return image_url or image, status_code, message, intent_type.value
yield image_url or image, status_code, message, intent_type.value
return
with timer("Convert image to webp", logger):
# Convert png to webp for faster loading
@@ -864,7 +870,7 @@ async def text_to_image(
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
image = base64.b64encode(webp_image_bytes).decode("utf-8")
return image_url or image, status_code, improved_image_prompt, intent_type.value
yield image_url or image, status_code, improved_image_prompt, intent_type.value
class ApiUserRateLimiter: