mirror of
https://github.com/khoj-ai/khoj.git
synced 2025-02-17 08:04:21 +00:00
Convert Websocket into Server Side Event (SSE) API endpoint
- Convert functions in SSE API path into async generators using yields - Validate image generation, online, notes lookup and general paths of chat request are handled fine by the web client and server API
This commit is contained in:
parent
e694c82343
commit
91fe41106e
6 changed files with 577 additions and 489 deletions
|
@ -40,6 +40,7 @@ dependencies = [
|
|||
"dateparser >= 1.1.1",
|
||||
"defusedxml == 0.7.1",
|
||||
"fastapi >= 0.104.1",
|
||||
"sse-starlette ~= 2.1.0",
|
||||
"python-multipart >= 0.0.7",
|
||||
"jinja2 == 3.1.4",
|
||||
"openai >= 1.0.0",
|
||||
|
|
|
@ -74,14 +74,14 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||
}, 1000);
|
||||
});
|
||||
}
|
||||
var websocket = null;
|
||||
var sseConnection = null;
|
||||
let region = null;
|
||||
let city = null;
|
||||
let countryName = null;
|
||||
let timezone = null;
|
||||
let waitingForLocation = true;
|
||||
|
||||
let websocketState = {
|
||||
let chatMessageState = {
|
||||
newResponseTextEl: null,
|
||||
newResponseEl: null,
|
||||
loadingEllipsis: null,
|
||||
|
@ -105,7 +105,7 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||
.finally(() => {
|
||||
console.debug("Region:", region, "City:", city, "Country:", countryName, "Timezone:", timezone);
|
||||
waitingForLocation = false;
|
||||
setupWebSocket();
|
||||
initializeSSE();
|
||||
});
|
||||
|
||||
function formatDate(date) {
|
||||
|
@ -599,10 +599,8 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||
}
|
||||
|
||||
async function chat(isVoice=false) {
|
||||
if (websocket) {
|
||||
sendMessageViaWebSocket(isVoice);
|
||||
return;
|
||||
}
|
||||
sendMessageViaSSE(isVoice);
|
||||
return;
|
||||
|
||||
let query = document.getElementById("chat-input").value.trim();
|
||||
let resultsCount = localStorage.getItem("khojResultsCount") || 5;
|
||||
|
@ -1069,17 +1067,13 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||
|
||||
window.onload = loadChat;
|
||||
|
||||
function setupWebSocket(isVoice=false) {
|
||||
let chatBody = document.getElementById("chat-body");
|
||||
let wsProtocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
|
||||
let webSocketUrl = `${wsProtocol}//${window.location.host}/api/chat/ws`;
|
||||
|
||||
function initializeSSE(isVoice=false) {
|
||||
if (waitingForLocation) {
|
||||
console.debug("Waiting for location data to be fetched. Will setup WebSocket once location data is available.");
|
||||
return;
|
||||
}
|
||||
|
||||
websocketState = {
|
||||
chatMessageState = {
|
||||
newResponseTextEl: null,
|
||||
newResponseEl: null,
|
||||
loadingEllipsis: null,
|
||||
|
@ -1088,121 +1082,138 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||
rawQuery: "",
|
||||
isVoice: isVoice,
|
||||
}
|
||||
}
|
||||
|
||||
function sendSSEMessage(query) {
|
||||
let chatBody = document.getElementById("chat-body");
|
||||
let sseProtocol = window.location.protocol;
|
||||
let sseUrl = `/api/chat/stream?q=${query}`;
|
||||
|
||||
if (chatBody.dataset.conversationId) {
|
||||
webSocketUrl += `?conversation_id=${chatBody.dataset.conversationId}`;
|
||||
webSocketUrl += (!!region && !!city && !!countryName) && !!timezone ? `®ion=${region}&city=${city}&country=${countryName}&timezone=${timezone}` : '';
|
||||
|
||||
websocket = new WebSocket(webSocketUrl);
|
||||
websocket.onmessage = function(event) {
|
||||
sseUrl += `&conversation_id=${chatBody.dataset.conversationId}`;
|
||||
sseUrl += (!!region && !!city && !!countryName) && !!timezone ? `®ion=${region}&city=${city}&country=${countryName}&timezone=${timezone}` : '';
|
||||
|
||||
function handleChatResponse(event) {
|
||||
// Get the last element in the chat-body
|
||||
let chunk = event.data;
|
||||
if (chunk == "start_llm_response") {
|
||||
console.log("Started streaming", new Date());
|
||||
} else if (chunk == "end_llm_response") {
|
||||
console.log("Stopped streaming", new Date());
|
||||
try {
|
||||
if (chunk.includes("application/json"))
|
||||
chunk = JSON.parse(chunk);
|
||||
} catch (error) {
|
||||
// If the chunk is not a JSON object, continue.
|
||||
}
|
||||
|
||||
// Automatically respond with voice if the subscribed user has sent voice message
|
||||
if (websocketState.isVoice && "{{ is_active }}" == "True")
|
||||
textToSpeech(websocketState.rawResponse);
|
||||
|
||||
// Append any references after all the data has been streamed
|
||||
finalizeChatBodyResponse(websocketState.references, websocketState.newResponseTextEl);
|
||||
|
||||
const liveQuery = websocketState.rawQuery;
|
||||
// Reset variables
|
||||
websocketState = {
|
||||
newResponseTextEl: null,
|
||||
newResponseEl: null,
|
||||
loadingEllipsis: null,
|
||||
references: {},
|
||||
rawResponse: "",
|
||||
rawQuery: liveQuery,
|
||||
isVoice: false,
|
||||
}
|
||||
} else {
|
||||
const contentType = chunk["content-type"]
|
||||
if (contentType === "application/json") {
|
||||
// Handle JSON response
|
||||
try {
|
||||
if (chunk.includes("application/json"))
|
||||
{
|
||||
chunk = JSON.parse(chunk);
|
||||
if (chunk.image || chunk.detail) {
|
||||
({rawResponse, references } = handleImageResponse(chunk, chatMessageState.rawResponse));
|
||||
chatMessageState.rawResponse = rawResponse;
|
||||
chatMessageState.references = references;
|
||||
} else {
|
||||
rawResponse = chunk.response;
|
||||
}
|
||||
} catch (error) {
|
||||
// If the chunk is not a JSON object, continue.
|
||||
// If the chunk is not a JSON object, just display it as is
|
||||
chatMessageState.rawResponse += chunk;
|
||||
} finally {
|
||||
addMessageToChatBody(chatMessageState.rawResponse, chatMessageState.newResponseTextEl, chatMessageState.references);
|
||||
}
|
||||
} else {
|
||||
// Handle streamed response of type text/event-stream or text/plain
|
||||
if (chunk && chunk.includes("### compiled references:")) {
|
||||
({ rawResponse, references } = handleCompiledReferences(chatMessageState.newResponseTextEl, chunk, chatMessageState.references, chatMessageState.rawResponse));
|
||||
chatMessageState.rawResponse = rawResponse;
|
||||
chatMessageState.references = references;
|
||||
} else {
|
||||
// If the chunk is not a JSON object, just display it as is
|
||||
chatMessageState.rawResponse += chunk;
|
||||
if (chatMessageState.newResponseTextEl) {
|
||||
handleStreamResponse(chatMessageState.newResponseTextEl, chatMessageState.rawResponse, chatMessageState.rawQuery, chatMessageState.loadingEllipsis);
|
||||
}
|
||||
}
|
||||
|
||||
const contentType = chunk["content-type"]
|
||||
|
||||
if (contentType === "application/json") {
|
||||
// Handle JSON response
|
||||
try {
|
||||
if (chunk.image || chunk.detail) {
|
||||
({rawResponse, references } = handleImageResponse(chunk, websocketState.rawResponse));
|
||||
websocketState.rawResponse = rawResponse;
|
||||
websocketState.references = references;
|
||||
} else if (chunk.type == "status") {
|
||||
handleStreamResponse(websocketState.newResponseTextEl, chunk.message, websocketState.rawQuery, null, false);
|
||||
} else if (chunk.type == "rate_limit") {
|
||||
handleStreamResponse(websocketState.newResponseTextEl, chunk.message, websocketState.rawQuery, websocketState.loadingEllipsis, true);
|
||||
} else {
|
||||
rawResponse = chunk.response;
|
||||
}
|
||||
} catch (error) {
|
||||
// If the chunk is not a JSON object, just display it as is
|
||||
websocketState.rawResponse += chunk;
|
||||
} finally {
|
||||
if (chunk.type != "status" && chunk.type != "rate_limit") {
|
||||
addMessageToChatBody(websocketState.rawResponse, websocketState.newResponseTextEl, websocketState.references);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
||||
// Handle streamed response of type text/event-stream or text/plain
|
||||
if (chunk && chunk.includes("### compiled references:")) {
|
||||
({ rawResponse, references } = handleCompiledReferences(websocketState.newResponseTextEl, chunk, websocketState.references, websocketState.rawResponse));
|
||||
websocketState.rawResponse = rawResponse;
|
||||
websocketState.references = references;
|
||||
} else {
|
||||
// If the chunk is not a JSON object, just display it as is
|
||||
websocketState.rawResponse += chunk;
|
||||
if (websocketState.newResponseTextEl) {
|
||||
handleStreamResponse(websocketState.newResponseTextEl, websocketState.rawResponse, websocketState.rawQuery, websocketState.loadingEllipsis);
|
||||
}
|
||||
}
|
||||
|
||||
// Scroll to bottom of chat window as chat response is streamed
|
||||
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||
};
|
||||
}
|
||||
// Scroll to bottom of chat window as chat response is streamed
|
||||
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||
};
|
||||
}
|
||||
};
|
||||
websocket.onclose = function(event) {
|
||||
websocket = null;
|
||||
console.log("WebSocket is closed now.");
|
||||
let setupWebSocketButton = document.createElement("button");
|
||||
setupWebSocketButton.textContent = "Reconnect to Server";
|
||||
setupWebSocketButton.onclick = setupWebSocket;
|
||||
let statusDotIcon = document.getElementById("connection-status-icon");
|
||||
statusDotIcon.style.backgroundColor = "red";
|
||||
let statusDotText = document.getElementById("connection-status-text");
|
||||
statusDotText.innerHTML = "";
|
||||
statusDotText.style.marginTop = "5px";
|
||||
statusDotText.appendChild(setupWebSocketButton);
|
||||
}
|
||||
websocket.onerror = function(event) {
|
||||
console.log("WebSocket error observed:", event);
|
||||
}
|
||||
|
||||
websocket.onopen = function(event) {
|
||||
console.log("WebSocket is open now.")
|
||||
sseConnection = new EventSource(sseUrl);
|
||||
sseConnection.onmessage = handleChatResponse;
|
||||
sseConnection.addEventListener("complete_llm_response", handleChatResponse);
|
||||
sseConnection.addEventListener("status", (event) => {
|
||||
console.log(`${event.data}`);
|
||||
handleStreamResponse(chatMessageState.newResponseTextEl, event.data, chatMessageState.rawQuery, null, false);
|
||||
});
|
||||
sseConnection.addEventListener("rate_limit", (event) => {
|
||||
handleStreamResponse(chatMessageState.newResponseTextEl, event.data, chatMessageState.rawQuery, chatMessageState.loadingEllipsis, true);
|
||||
});
|
||||
sseConnection.addEventListener("start_llm_response", (event) => {
|
||||
console.log("Started streaming", new Date());
|
||||
});
|
||||
sseConnection.addEventListener("end_llm_response", (event) => {
|
||||
sseConnection.close();
|
||||
console.log("Stopped streaming", new Date());
|
||||
|
||||
// Automatically respond with voice if the subscribed user has sent voice message
|
||||
if (chatMessageState.isVoice && "{{ is_active }}" == "True")
|
||||
textToSpeech(chatMessageState.rawResponse);
|
||||
|
||||
// Append any references after all the data has been streamed
|
||||
finalizeChatBodyResponse(chatMessageState.references, chatMessageState.newResponseTextEl);
|
||||
|
||||
const liveQuery = chatMessageState.rawQuery;
|
||||
// Reset variables
|
||||
chatMessageState = {
|
||||
newResponseTextEl: null,
|
||||
newResponseEl: null,
|
||||
loadingEllipsis: null,
|
||||
references: {},
|
||||
rawResponse: "",
|
||||
rawQuery: liveQuery,
|
||||
}
|
||||
|
||||
// Reset status icon
|
||||
let statusDotIcon = document.getElementById("connection-status-icon");
|
||||
statusDotIcon.style.backgroundColor = "green";
|
||||
let statusDotText = document.getElementById("connection-status-text");
|
||||
statusDotText.textContent = "Connected to Server";
|
||||
statusDotText.textContent = "Ready";
|
||||
statusDotText.style.marginTop = "5px";
|
||||
});
|
||||
sseConnection.onclose = function(event) {
|
||||
sseConnection = null;
|
||||
console.debug("SSE is closed now.");
|
||||
let statusDotIcon = document.getElementById("connection-status-icon");
|
||||
statusDotIcon.style.backgroundColor = "green";
|
||||
let statusDotText = document.getElementById("connection-status-text");
|
||||
statusDotText.textContent = "Ready";
|
||||
statusDotText.style.marginTop = "5px";
|
||||
}
|
||||
sseConnection.onerror = function(event) {
|
||||
console.log("SSE error observed:", event);
|
||||
sseConnection.close();
|
||||
sseConnection = null;
|
||||
let statusDotIcon = document.getElementById("connection-status-icon");
|
||||
statusDotIcon.style.backgroundColor = "red";
|
||||
let statusDotText = document.getElementById("connection-status-text");
|
||||
statusDotText.textContent = "Server Error";
|
||||
if (chatMessageState.newResponseEl.getElementsByClassName("lds-ellipsis").length > 0 && chatMessageState.loadingEllipsis) {
|
||||
chatMessageState.newResponseTextEl.removeChild(chatMessageState.loadingEllipsis);
|
||||
}
|
||||
chatMessageState.newResponseTextEl.textContent += "Failed to get response! Try again or contact developers at team@khoj.dev"
|
||||
}
|
||||
sseConnection.onopen = function(event) {
|
||||
console.debug("SSE is open now.")
|
||||
let statusDotIcon = document.getElementById("connection-status-icon");
|
||||
statusDotIcon.style.backgroundColor = "orange";
|
||||
let statusDotText = document.getElementById("connection-status-text");
|
||||
statusDotText.textContent = "Processing";
|
||||
}
|
||||
}
|
||||
|
||||
function sendMessageViaWebSocket(isVoice=false) {
|
||||
function sendMessageViaSSE(isVoice=false) {
|
||||
let chatBody = document.getElementById("chat-body");
|
||||
|
||||
var query = document.getElementById("chat-input").value.trim();
|
||||
|
@ -1242,11 +1253,11 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||
chatInput.classList.remove("option-enabled");
|
||||
|
||||
// Call specified Khoj API
|
||||
websocket.send(query);
|
||||
sendSSEMessage(query);
|
||||
let rawResponse = "";
|
||||
let references = {};
|
||||
|
||||
websocketState = {
|
||||
chatMessageState = {
|
||||
newResponseTextEl,
|
||||
newResponseEl,
|
||||
loadingEllipsis,
|
||||
|
@ -1265,7 +1276,7 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||
let chatHistoryUrl = `/api/chat/history?client=web`;
|
||||
if (chatBody.dataset.conversationId) {
|
||||
chatHistoryUrl += `&conversation_id=${chatBody.dataset.conversationId}`;
|
||||
setupWebSocket();
|
||||
initializeSSE();
|
||||
loadFileFiltersFromConversation();
|
||||
}
|
||||
|
||||
|
@ -1305,7 +1316,7 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||
let chatBody = document.getElementById("chat-body");
|
||||
chatBody.dataset.conversationId = response.conversation_id;
|
||||
loadFileFiltersFromConversation();
|
||||
setupWebSocket();
|
||||
initializeSSE();
|
||||
chatBody.dataset.conversationTitle = response.slug || `New conversation 🌱`;
|
||||
|
||||
let agentMetadata = response.agent;
|
||||
|
|
|
@ -56,7 +56,8 @@ async def search_online(
|
|||
query += " ".join(custom_filters)
|
||||
if not is_internet_connected():
|
||||
logger.warn("Cannot search online as not connected to internet")
|
||||
return {}
|
||||
yield {}
|
||||
return
|
||||
|
||||
# Breakdown the query into subqueries to get the correct answer
|
||||
subqueries = await generate_online_subqueries(query, conversation_history, location)
|
||||
|
@ -66,7 +67,8 @@ async def search_online(
|
|||
logger.info(f"🌐 Searching the Internet for {list(subqueries)}")
|
||||
if send_status_func:
|
||||
subqueries_str = "\n- " + "\n- ".join(list(subqueries))
|
||||
await send_status_func(f"**🌐 Searching the Internet for**: {subqueries_str}")
|
||||
async for event in send_status_func(f"**🌐 Searching the Internet for**: {subqueries_str}"):
|
||||
yield {"status": event}
|
||||
|
||||
with timer(f"Internet searches for {list(subqueries)} took", logger):
|
||||
search_func = search_with_google if SERPER_DEV_API_KEY else search_with_jina
|
||||
|
@ -89,7 +91,8 @@ async def search_online(
|
|||
logger.info(f"🌐👀 Reading web pages at: {list(webpage_links)}")
|
||||
if send_status_func:
|
||||
webpage_links_str = "\n- " + "\n- ".join(list(webpage_links))
|
||||
await send_status_func(f"**📖 Reading web pages**: {webpage_links_str}")
|
||||
async for event in send_status_func(f"**📖 Reading web pages**: {webpage_links_str}"):
|
||||
yield {"status": event}
|
||||
tasks = [read_webpage_and_extract_content(subquery, link, content) for link, subquery, content in webpages]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
|
@ -98,7 +101,7 @@ async def search_online(
|
|||
if webpage_extract is not None:
|
||||
response_dict[subquery]["webpages"] = {"link": url, "snippet": webpage_extract}
|
||||
|
||||
return response_dict
|
||||
yield response_dict
|
||||
|
||||
|
||||
async def search_with_google(query: str) -> Tuple[str, Dict[str, List[Dict]]]:
|
||||
|
@ -127,13 +130,15 @@ async def read_webpages(
|
|||
"Infer web pages to read from the query and extract relevant information from them"
|
||||
logger.info(f"Inferring web pages to read")
|
||||
if send_status_func:
|
||||
await send_status_func(f"**🧐 Inferring web pages to read**")
|
||||
async for event in send_status_func(f"**🧐 Inferring web pages to read**"):
|
||||
yield {"status": event}
|
||||
urls = await infer_webpage_urls(query, conversation_history, location)
|
||||
|
||||
logger.info(f"Reading web pages at: {urls}")
|
||||
if send_status_func:
|
||||
webpage_links_str = "\n- " + "\n- ".join(list(urls))
|
||||
await send_status_func(f"**📖 Reading web pages**: {webpage_links_str}")
|
||||
async for event in send_status_func(f"**📖 Reading web pages**: {webpage_links_str}"):
|
||||
yield {"status": event}
|
||||
tasks = [read_webpage_and_extract_content(query, url) for url in urls]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
|
@ -141,7 +146,7 @@ async def read_webpages(
|
|||
response[query]["webpages"] = [
|
||||
{"query": q, "link": url, "snippet": web_extract} for q, web_extract, url in results if web_extract is not None
|
||||
]
|
||||
return response
|
||||
yield response
|
||||
|
||||
|
||||
async def read_webpage_and_extract_content(
|
||||
|
|
|
@ -6,7 +6,6 @@ import os
|
|||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from random import random
|
||||
from typing import Any, Callable, List, Optional, Union
|
||||
|
||||
import cron_descriptor
|
||||
|
@ -298,11 +297,13 @@ async def extract_references_and_questions(
|
|||
not ConversationCommand.Notes in conversation_commands
|
||||
and not ConversationCommand.Default in conversation_commands
|
||||
):
|
||||
return compiled_references, inferred_queries, q
|
||||
yield compiled_references, inferred_queries, q
|
||||
return
|
||||
|
||||
if not await sync_to_async(EntryAdapters.user_has_entries)(user=user):
|
||||
logger.debug("No documents in knowledge base. Use a Khoj client to sync and chat with your docs.")
|
||||
return compiled_references, inferred_queries, q
|
||||
yield compiled_references, inferred_queries, q
|
||||
return
|
||||
|
||||
# Extract filter terms from user message
|
||||
defiltered_query = q
|
||||
|
@ -313,7 +314,8 @@ async def extract_references_and_questions(
|
|||
|
||||
if not conversation:
|
||||
logger.error(f"Conversation with id {conversation_id} not found.")
|
||||
return compiled_references, inferred_queries, defiltered_query
|
||||
yield compiled_references, inferred_queries, defiltered_query
|
||||
return
|
||||
|
||||
filters_in_query += " ".join([f'file:"{filter}"' for filter in conversation.file_filters])
|
||||
using_offline_chat = False
|
||||
|
@ -372,7 +374,8 @@ async def extract_references_and_questions(
|
|||
logger.info(f"🔍 Searching knowledge base with queries: {inferred_queries}")
|
||||
if send_status_func:
|
||||
inferred_queries_str = "\n- " + "\n- ".join(inferred_queries)
|
||||
await send_status_func(f"**🔍 Searching Documents for:** {inferred_queries_str}")
|
||||
async for event in send_status_func(f"**🔍 Searching Documents for:** {inferred_queries_str}"):
|
||||
yield {"status": event}
|
||||
for query in inferred_queries:
|
||||
n_items = min(n, 3) if using_offline_chat else n
|
||||
search_results.extend(
|
||||
|
@ -391,7 +394,7 @@ async def extract_references_and_questions(
|
|||
{"compiled": item.additional["compiled"], "file": item.additional["file"]} for item in search_results
|
||||
]
|
||||
|
||||
return compiled_references, inferred_queries, defiltered_query
|
||||
yield compiled_references, inferred_queries, defiltered_query
|
||||
|
||||
|
||||
@api.get("/health", response_class=Response)
|
||||
|
|
|
@ -1,17 +1,18 @@
|
|||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import unquote
|
||||
|
||||
from asgiref.sync import sync_to_async
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, WebSocket
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.requests import Request
|
||||
from fastapi.responses import Response, StreamingResponse
|
||||
from sse_starlette import EventSourceResponse
|
||||
from starlette.authentication import requires
|
||||
from starlette.websockets import WebSocketDisconnect
|
||||
from websockets import ConnectionClosedOK
|
||||
|
||||
from khoj.app.settings import ALLOWED_HOSTS
|
||||
from khoj.database.adapters import (
|
||||
|
@ -526,380 +527,441 @@ async def set_conversation_title(
|
|||
)
|
||||
|
||||
|
||||
@api_chat.websocket("/ws")
|
||||
async def websocket_endpoint(
|
||||
websocket: WebSocket,
|
||||
@api_chat.get("/stream")
|
||||
async def stream_chat(
|
||||
request: Request,
|
||||
q: str,
|
||||
conversation_id: int,
|
||||
city: Optional[str] = None,
|
||||
region: Optional[str] = None,
|
||||
country: Optional[str] = None,
|
||||
timezone: Optional[str] = None,
|
||||
):
|
||||
connection_alive = True
|
||||
async def event_generator(q: str):
|
||||
connection_alive = True
|
||||
|
||||
async def send_status_update(message: str):
|
||||
nonlocal connection_alive
|
||||
if not connection_alive:
|
||||
return
|
||||
|
||||
status_packet = {
|
||||
"type": "status",
|
||||
"message": message,
|
||||
"content-type": "application/json",
|
||||
}
|
||||
try:
|
||||
await websocket.send_text(json.dumps(status_packet))
|
||||
except ConnectionClosedOK:
|
||||
connection_alive = False
|
||||
logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")
|
||||
|
||||
async def send_complete_llm_response(llm_response: str):
|
||||
nonlocal connection_alive
|
||||
if not connection_alive:
|
||||
return
|
||||
try:
|
||||
await websocket.send_text("start_llm_response")
|
||||
await websocket.send_text(llm_response)
|
||||
await websocket.send_text("end_llm_response")
|
||||
except ConnectionClosedOK:
|
||||
connection_alive = False
|
||||
logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")
|
||||
|
||||
async def send_message(message: str):
|
||||
nonlocal connection_alive
|
||||
if not connection_alive:
|
||||
return
|
||||
try:
|
||||
await websocket.send_text(message)
|
||||
except ConnectionClosedOK:
|
||||
connection_alive = False
|
||||
logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")
|
||||
|
||||
async def send_rate_limit_message(message: str):
|
||||
nonlocal connection_alive
|
||||
if not connection_alive:
|
||||
return
|
||||
|
||||
status_packet = {
|
||||
"type": "rate_limit",
|
||||
"message": message,
|
||||
"content-type": "application/json",
|
||||
}
|
||||
try:
|
||||
await websocket.send_text(json.dumps(status_packet))
|
||||
except ConnectionClosedOK:
|
||||
connection_alive = False
|
||||
logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")
|
||||
|
||||
user: KhojUser = websocket.user.object
|
||||
conversation = await ConversationAdapters.aget_conversation_by_user(
|
||||
user, client_application=websocket.user.client_app, conversation_id=conversation_id
|
||||
)
|
||||
|
||||
hourly_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute")
|
||||
|
||||
daily_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
|
||||
|
||||
await is_ready_to_chat(user)
|
||||
|
||||
user_name = await aget_user_name(user)
|
||||
|
||||
location = None
|
||||
|
||||
if city or region or country:
|
||||
location = LocationData(city=city, region=region, country=country)
|
||||
|
||||
await websocket.accept()
|
||||
while connection_alive:
|
||||
try:
|
||||
if conversation:
|
||||
await sync_to_async(conversation.refresh_from_db)(fields=["conversation_log"])
|
||||
q = await websocket.receive_text()
|
||||
|
||||
# Refresh these because the connection to the database might have been closed
|
||||
await conversation.arefresh_from_db()
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.debug(f"User {user} disconnected web socket")
|
||||
break
|
||||
|
||||
try:
|
||||
await sync_to_async(hourly_limiter)(websocket)
|
||||
await sync_to_async(daily_limiter)(websocket)
|
||||
except HTTPException as e:
|
||||
await send_rate_limit_message(e.detail)
|
||||
break
|
||||
|
||||
if is_query_empty(q):
|
||||
await send_message("start_llm_response")
|
||||
await send_message(
|
||||
"It seems like your query is incomplete. Could you please provide more details or specify what you need help with?"
|
||||
)
|
||||
await send_message("end_llm_response")
|
||||
continue
|
||||
|
||||
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
conversation_commands = [get_conversation_command(query=q, any_references=True)]
|
||||
|
||||
await send_status_update(f"**👀 Understanding Query**: {q}")
|
||||
|
||||
meta_log = conversation.conversation_log
|
||||
is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask]
|
||||
used_slash_summarize = conversation_commands == [ConversationCommand.Summarize]
|
||||
|
||||
if conversation_commands == [ConversationCommand.Default] or is_automated_task:
|
||||
conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task)
|
||||
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
|
||||
await send_status_update(f"**🗃️ Chose Data Sources to Search:** {conversation_commands_str}")
|
||||
|
||||
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task)
|
||||
await send_status_update(f"**🧑🏾💻 Decided Response Mode:** {mode.value}")
|
||||
if mode not in conversation_commands:
|
||||
conversation_commands.append(mode)
|
||||
|
||||
for cmd in conversation_commands:
|
||||
await conversation_command_rate_limiter.update_and_check_if_valid(websocket, cmd)
|
||||
q = q.replace(f"/{cmd.value}", "").strip()
|
||||
|
||||
file_filters = conversation.file_filters if conversation else []
|
||||
# Skip trying to summarize if
|
||||
if (
|
||||
# summarization intent was inferred
|
||||
ConversationCommand.Summarize in conversation_commands
|
||||
# and not triggered via slash command
|
||||
and not used_slash_summarize
|
||||
# but we can't actually summarize
|
||||
and len(file_filters) != 1
|
||||
):
|
||||
conversation_commands.remove(ConversationCommand.Summarize)
|
||||
elif ConversationCommand.Summarize in conversation_commands:
|
||||
response_log = ""
|
||||
if len(file_filters) == 0:
|
||||
response_log = "No files selected for summarization. Please add files using the section on the left."
|
||||
await send_complete_llm_response(response_log)
|
||||
elif len(file_filters) > 1:
|
||||
response_log = "Only one file can be selected for summarization."
|
||||
await send_complete_llm_response(response_log)
|
||||
else:
|
||||
try:
|
||||
file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
|
||||
if len(file_object) == 0:
|
||||
response_log = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again."
|
||||
await send_complete_llm_response(response_log)
|
||||
continue
|
||||
contextual_data = " ".join([file.raw_text for file in file_object])
|
||||
if not q:
|
||||
q = "Create a general summary of the file"
|
||||
await send_status_update(f"**🧑🏾💻 Constructing Summary Using:** {file_object[0].file_name}")
|
||||
response = await extract_relevant_summary(q, contextual_data)
|
||||
response_log = str(response)
|
||||
await send_complete_llm_response(response_log)
|
||||
except Exception as e:
|
||||
response_log = "Error summarizing file."
|
||||
logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)
|
||||
await send_complete_llm_response(response_log)
|
||||
await sync_to_async(save_to_conversation_log)(
|
||||
q,
|
||||
response_log,
|
||||
user,
|
||||
meta_log,
|
||||
user_message_time,
|
||||
intent_type="summarize",
|
||||
client_application=websocket.user.client_app,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
update_telemetry_state(
|
||||
request=websocket,
|
||||
telemetry_type="api",
|
||||
api="chat",
|
||||
metadata={"conversation_command": conversation_commands[0].value},
|
||||
)
|
||||
continue
|
||||
|
||||
custom_filters = []
|
||||
if conversation_commands == [ConversationCommand.Help]:
|
||||
if not q:
|
||||
conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
|
||||
if conversation_config == None:
|
||||
conversation_config = await ConversationAdapters.aget_default_conversation_config()
|
||||
model_type = conversation_config.model_type
|
||||
formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
|
||||
await send_complete_llm_response(formatted_help)
|
||||
continue
|
||||
# Adding specification to search online specifically on khoj.dev pages.
|
||||
custom_filters.append("site:khoj.dev")
|
||||
conversation_commands.append(ConversationCommand.Online)
|
||||
|
||||
if ConversationCommand.Automation in conversation_commands:
|
||||
async def send_event(event_type: str, data: str):
|
||||
nonlocal connection_alive
|
||||
if not connection_alive or await request.is_disconnected():
|
||||
return
|
||||
try:
|
||||
automation, crontime, query_to_run, subject = await create_automation(
|
||||
q, timezone, user, websocket.url, meta_log
|
||||
)
|
||||
if event_type == "message":
|
||||
yield data
|
||||
else:
|
||||
yield {"event": event_type, "data": data, "retry": 15000}
|
||||
except Exception as e:
|
||||
logger.error(f"Error scheduling task {q} for {user.email}: {e}")
|
||||
await send_complete_llm_response(
|
||||
f"Unable to create automation. Ensure the automation doesn't already exist."
|
||||
)
|
||||
continue
|
||||
connection_alive = False
|
||||
logger.info(f"User {user} disconnected SSE. Emitting rest of responses to clear thread: {e}")
|
||||
|
||||
llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject)
|
||||
await sync_to_async(save_to_conversation_log)(
|
||||
q,
|
||||
llm_response,
|
||||
user,
|
||||
meta_log,
|
||||
user_message_time,
|
||||
intent_type="automation",
|
||||
client_application=websocket.user.client_app,
|
||||
conversation_id=conversation_id,
|
||||
inferred_queries=[query_to_run],
|
||||
automation_id=automation.id,
|
||||
)
|
||||
common = CommonQueryParamsClass(
|
||||
client=websocket.user.client_app,
|
||||
user_agent=websocket.headers.get("user-agent"),
|
||||
host=websocket.headers.get("host"),
|
||||
)
|
||||
update_telemetry_state(
|
||||
request=websocket,
|
||||
telemetry_type="api",
|
||||
api="chat",
|
||||
**common.__dict__,
|
||||
)
|
||||
await send_complete_llm_response(llm_response)
|
||||
continue
|
||||
|
||||
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
|
||||
websocket, meta_log, q, 7, 0.18, conversation_id, conversation_commands, location, send_status_update
|
||||
user: KhojUser = request.user.object
|
||||
conversation = await ConversationAdapters.aget_conversation_by_user(
|
||||
user, client_application=request.user.client_app, conversation_id=conversation_id
|
||||
)
|
||||
|
||||
if compiled_references:
|
||||
headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references]))
|
||||
await send_status_update(f"**📜 Found Relevant Notes**: {headings}")
|
||||
hourly_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute")
|
||||
|
||||
online_results: Dict = dict()
|
||||
daily_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
|
||||
|
||||
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
|
||||
await send_complete_llm_response(f"{no_entries_found.format()}")
|
||||
continue
|
||||
await is_ready_to_chat(user)
|
||||
|
||||
if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
|
||||
conversation_commands.remove(ConversationCommand.Notes)
|
||||
user_name = await aget_user_name(user)
|
||||
|
||||
if ConversationCommand.Online in conversation_commands:
|
||||
location = None
|
||||
|
||||
if city or region or country:
|
||||
location = LocationData(city=city, region=region, country=country)
|
||||
|
||||
while connection_alive:
|
||||
try:
|
||||
online_results = await search_online(
|
||||
defiltered_query, meta_log, location, send_status_update, custom_filters
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Error searching online: {e}. Attempting to respond without online results")
|
||||
await send_complete_llm_response(
|
||||
f"Error searching online: {e}. Attempting to respond without online results"
|
||||
)
|
||||
continue
|
||||
if conversation:
|
||||
await sync_to_async(conversation.refresh_from_db)(fields=["conversation_log"])
|
||||
|
||||
if ConversationCommand.Webpage in conversation_commands:
|
||||
try:
|
||||
direct_web_pages = await read_webpages(defiltered_query, meta_log, location, send_status_update)
|
||||
webpages = []
|
||||
for query in direct_web_pages:
|
||||
if online_results.get(query):
|
||||
online_results[query]["webpages"] = direct_web_pages[query]["webpages"]
|
||||
else:
|
||||
online_results[query] = {"webpages": direct_web_pages[query]["webpages"]}
|
||||
# Refresh these because the connection to the database might have been closed
|
||||
await conversation.arefresh_from_db()
|
||||
|
||||
for webpage in direct_web_pages[query]["webpages"]:
|
||||
webpages.append(webpage["link"])
|
||||
|
||||
await send_status_update(f"**📚 Read web pages**: {webpages}")
|
||||
except ValueError as e:
|
||||
logger.warning(
|
||||
f"Error directly reading webpages: {e}. Attempting to respond without online results", exc_info=True
|
||||
)
|
||||
|
||||
if ConversationCommand.Image in conversation_commands:
|
||||
update_telemetry_state(
|
||||
request=websocket,
|
||||
telemetry_type="api",
|
||||
api="chat",
|
||||
metadata={"conversation_command": conversation_commands[0].value},
|
||||
)
|
||||
image, status_code, improved_image_prompt, intent_type = await text_to_image(
|
||||
q,
|
||||
user,
|
||||
meta_log,
|
||||
location_data=location,
|
||||
references=compiled_references,
|
||||
online_results=online_results,
|
||||
send_status_func=send_status_update,
|
||||
)
|
||||
if image is None or status_code != 200:
|
||||
content_obj = {
|
||||
"image": image,
|
||||
"intentType": intent_type,
|
||||
"detail": improved_image_prompt,
|
||||
"content-type": "application/json",
|
||||
}
|
||||
await send_complete_llm_response(json.dumps(content_obj))
|
||||
continue
|
||||
|
||||
await sync_to_async(save_to_conversation_log)(
|
||||
q,
|
||||
image,
|
||||
user,
|
||||
meta_log,
|
||||
user_message_time,
|
||||
intent_type=intent_type,
|
||||
inferred_queries=[improved_image_prompt],
|
||||
client_application=websocket.user.client_app,
|
||||
conversation_id=conversation_id,
|
||||
compiled_references=compiled_references,
|
||||
online_results=online_results,
|
||||
)
|
||||
content_obj = {"image": image, "intentType": intent_type, "inferredQueries": [improved_image_prompt], "context": compiled_references, "content-type": "application/json", "online_results": online_results} # type: ignore
|
||||
|
||||
await send_complete_llm_response(json.dumps(content_obj))
|
||||
continue
|
||||
|
||||
await send_status_update(f"**💭 Generating a well-informed response**")
|
||||
llm_response, chat_metadata = await agenerate_chat_response(
|
||||
defiltered_query,
|
||||
meta_log,
|
||||
conversation,
|
||||
compiled_references,
|
||||
online_results,
|
||||
inferred_queries,
|
||||
conversation_commands,
|
||||
user,
|
||||
websocket.user.client_app,
|
||||
conversation_id,
|
||||
location,
|
||||
user_name,
|
||||
)
|
||||
|
||||
chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None
|
||||
|
||||
update_telemetry_state(
|
||||
request=websocket,
|
||||
telemetry_type="api",
|
||||
api="chat",
|
||||
metadata=chat_metadata,
|
||||
)
|
||||
iterator = AsyncIteratorWrapper(llm_response)
|
||||
|
||||
await send_message("start_llm_response")
|
||||
|
||||
async for item in iterator:
|
||||
if item is None:
|
||||
break
|
||||
if connection_alive:
|
||||
try:
|
||||
await send_message(f"{item}")
|
||||
except ConnectionClosedOK:
|
||||
connection_alive = False
|
||||
logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")
|
||||
await sync_to_async(hourly_limiter)(request)
|
||||
await sync_to_async(daily_limiter)(request)
|
||||
except HTTPException as e:
|
||||
async for result in send_event("rate_limit", e.detail):
|
||||
yield result
|
||||
break
|
||||
|
||||
await send_message("end_llm_response")
|
||||
if is_query_empty(q):
|
||||
async for event in send_event("start_llm_response", ""):
|
||||
yield event
|
||||
async for event in send_event(
|
||||
"message",
|
||||
"It seems like your query is incomplete. Could you please provide more details or specify what you need help with?",
|
||||
):
|
||||
yield event
|
||||
async for event in send_event("end_llm_response", ""):
|
||||
yield event
|
||||
return
|
||||
|
||||
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
conversation_commands = [get_conversation_command(query=q, any_references=True)]
|
||||
|
||||
async for result in send_event("status", f"**👀 Understanding Query**: {q}"):
|
||||
yield result
|
||||
|
||||
meta_log = conversation.conversation_log
|
||||
is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask]
|
||||
|
||||
used_slash_summarize = conversation_commands == [ConversationCommand.Summarize]
|
||||
|
||||
if conversation_commands == [ConversationCommand.Default] or is_automated_task:
|
||||
conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task)
|
||||
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
|
||||
async for result in send_event(
|
||||
"status", f"**🗃️ Chose Data Sources to Search:** {conversation_commands_str}"
|
||||
):
|
||||
yield result
|
||||
|
||||
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task)
|
||||
async for result in send_event("status", f"**🧑🏾💻 Decided Response Mode:** {mode.value}"):
|
||||
yield result
|
||||
if mode not in conversation_commands:
|
||||
conversation_commands.append(mode)
|
||||
|
||||
for cmd in conversation_commands:
|
||||
await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
|
||||
q = q.replace(f"/{cmd.value}", "").strip()
|
||||
|
||||
file_filters = conversation.file_filters if conversation else []
|
||||
# Skip trying to summarize if
|
||||
if (
|
||||
# summarization intent was inferred
|
||||
ConversationCommand.Summarize in conversation_commands
|
||||
# and not triggered via slash command
|
||||
and not used_slash_summarize
|
||||
# but we can't actually summarize
|
||||
and len(file_filters) != 1
|
||||
):
|
||||
conversation_commands.remove(ConversationCommand.Summarize)
|
||||
elif ConversationCommand.Summarize in conversation_commands:
|
||||
response_log = ""
|
||||
if len(file_filters) == 0:
|
||||
response_log = (
|
||||
"No files selected for summarization. Please add files using the section on the left."
|
||||
)
|
||||
async for result in send_event("complete_llm_response", response_log):
|
||||
yield result
|
||||
async for event in send_event("end_llm_response", ""):
|
||||
yield event
|
||||
elif len(file_filters) > 1:
|
||||
response_log = "Only one file can be selected for summarization."
|
||||
async for result in send_event("complete_llm_response", response_log):
|
||||
yield result
|
||||
async for event in send_event("end_llm_response", ""):
|
||||
yield event
|
||||
else:
|
||||
try:
|
||||
file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
|
||||
if len(file_object) == 0:
|
||||
response_log = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again."
|
||||
async for result in send_event("complete_llm_response", response_log):
|
||||
yield result
|
||||
async for event in send_event("end_llm_response", ""):
|
||||
yield event
|
||||
return
|
||||
contextual_data = " ".join([file.raw_text for file in file_object])
|
||||
if not q:
|
||||
q = "Create a general summary of the file"
|
||||
async for result in send_event(
|
||||
"status", f"**🧑🏾💻 Constructing Summary Using:** {file_object[0].file_name}"
|
||||
):
|
||||
yield result
|
||||
|
||||
response = await extract_relevant_summary(q, contextual_data)
|
||||
response_log = str(response)
|
||||
async for result in send_event("complete_llm_response", response_log):
|
||||
yield result
|
||||
async for event in send_event("end_llm_response", ""):
|
||||
yield event
|
||||
except Exception as e:
|
||||
response_log = "Error summarizing file."
|
||||
logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)
|
||||
async for result in send_event("complete_llm_response", response_log):
|
||||
yield result
|
||||
async for event in send_event("end_llm_response", ""):
|
||||
yield event
|
||||
await sync_to_async(save_to_conversation_log)(
|
||||
q,
|
||||
response_log,
|
||||
user,
|
||||
meta_log,
|
||||
user_message_time,
|
||||
intent_type="summarize",
|
||||
client_application=request.user.client_app,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
telemetry_type="api",
|
||||
api="chat",
|
||||
metadata={"conversation_command": conversation_commands[0].value},
|
||||
)
|
||||
return
|
||||
|
||||
custom_filters = []
|
||||
if conversation_commands == [ConversationCommand.Help]:
|
||||
if not q:
|
||||
conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
|
||||
if conversation_config == None:
|
||||
conversation_config = await ConversationAdapters.aget_default_conversation_config()
|
||||
model_type = conversation_config.model_type
|
||||
formatted_help = help_message.format(
|
||||
model=model_type, version=state.khoj_version, device=get_device()
|
||||
)
|
||||
async for result in send_event("complete_llm_response", formatted_help):
|
||||
yield result
|
||||
async for event in send_event("end_llm_response", ""):
|
||||
yield event
|
||||
return
|
||||
custom_filters.append("site:khoj.dev")
|
||||
conversation_commands.append(ConversationCommand.Online)
|
||||
|
||||
if ConversationCommand.Automation in conversation_commands:
|
||||
try:
|
||||
automation, crontime, query_to_run, subject = await create_automation(
|
||||
q, timezone, user, request.url, meta_log
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error scheduling task {q} for {user.email}: {e}")
|
||||
error_message = f"Unable to create automation. Ensure the automation doesn't already exist."
|
||||
async for result in send_event("complete_llm_response", error_message):
|
||||
yield result
|
||||
async for event in send_event("end_llm_response", ""):
|
||||
yield event
|
||||
return
|
||||
|
||||
llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject)
|
||||
await sync_to_async(save_to_conversation_log)(
|
||||
q,
|
||||
llm_response,
|
||||
user,
|
||||
meta_log,
|
||||
user_message_time,
|
||||
intent_type="automation",
|
||||
client_application=request.user.client_app,
|
||||
conversation_id=conversation_id,
|
||||
inferred_queries=[query_to_run],
|
||||
automation_id=automation.id,
|
||||
)
|
||||
common = CommonQueryParamsClass(
|
||||
client=request.user.client_app,
|
||||
user_agent=request.headers.get("user-agent"),
|
||||
host=request.headers.get("host"),
|
||||
)
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
telemetry_type="api",
|
||||
api="chat",
|
||||
**common.__dict__,
|
||||
)
|
||||
async for result in send_event("complete_llm_response", llm_response):
|
||||
yield result
|
||||
async for event in send_event("end_llm_response", ""):
|
||||
yield event
|
||||
return
|
||||
|
||||
compiled_references, inferred_queries, defiltered_query = [], [], None
|
||||
async for result in extract_references_and_questions(
|
||||
request,
|
||||
meta_log,
|
||||
q,
|
||||
7,
|
||||
0.18,
|
||||
conversation_id,
|
||||
conversation_commands,
|
||||
location,
|
||||
partial(send_event, "status"),
|
||||
):
|
||||
if isinstance(result, dict) and "status" in result:
|
||||
yield result["status"]
|
||||
else:
|
||||
compiled_references.extend(result[0])
|
||||
inferred_queries.extend(result[1])
|
||||
defiltered_query = result[2]
|
||||
|
||||
if not is_none_or_empty(compiled_references):
|
||||
headings = "\n- " + "\n- ".join(
|
||||
set([c.get("compiled", c).split("\n")[0] for c in compiled_references])
|
||||
)
|
||||
async for result in send_event("status", f"**📜 Found Relevant Notes**: {headings}"):
|
||||
yield result
|
||||
|
||||
online_results: Dict = dict()
|
||||
|
||||
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(
|
||||
user
|
||||
):
|
||||
async for result in send_event("complete_llm_response", f"{no_entries_found.format()}"):
|
||||
yield result
|
||||
async for event in send_event("end_llm_response", ""):
|
||||
yield event
|
||||
return
|
||||
|
||||
if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
|
||||
conversation_commands.remove(ConversationCommand.Notes)
|
||||
|
||||
if ConversationCommand.Online in conversation_commands:
|
||||
try:
|
||||
async for result in search_online(
|
||||
defiltered_query, meta_log, location, partial(send_event, "status"), custom_filters
|
||||
):
|
||||
if isinstance(result, dict) and "status" in result:
|
||||
yield result["status"]
|
||||
else:
|
||||
online_results = result
|
||||
except ValueError as e:
|
||||
error_message = f"Error searching online: {e}. Attempting to respond without online results"
|
||||
logger.warning(error_message)
|
||||
async for result in send_event("complete_llm_response", error_message):
|
||||
yield result
|
||||
async for event in send_event("end_llm_response", ""):
|
||||
yield event
|
||||
return
|
||||
|
||||
if ConversationCommand.Webpage in conversation_commands:
|
||||
try:
|
||||
async for result in read_webpages(
|
||||
defiltered_query, meta_log, location, partial(send_event, "status")
|
||||
):
|
||||
if isinstance(result, dict) and "status" in result:
|
||||
yield result["status"]
|
||||
else:
|
||||
direct_web_pages = result
|
||||
webpages = []
|
||||
for query in direct_web_pages:
|
||||
if online_results.get(query):
|
||||
online_results[query]["webpages"] = direct_web_pages[query]["webpages"]
|
||||
else:
|
||||
online_results[query] = {"webpages": direct_web_pages[query]["webpages"]}
|
||||
|
||||
for webpage in direct_web_pages[query]["webpages"]:
|
||||
webpages.append(webpage["link"])
|
||||
async for result in send_event("status", f"**📚 Read web pages**: {webpages}"):
|
||||
yield result
|
||||
except ValueError as e:
|
||||
logger.warning(
|
||||
f"Error directly reading webpages: {e}. Attempting to respond without online results",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
if ConversationCommand.Image in conversation_commands:
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
telemetry_type="api",
|
||||
api="chat",
|
||||
metadata={"conversation_command": conversation_commands[0].value},
|
||||
)
|
||||
async for result in text_to_image(
|
||||
q,
|
||||
user,
|
||||
meta_log,
|
||||
location_data=location,
|
||||
references=compiled_references,
|
||||
online_results=online_results,
|
||||
send_status_func=partial(send_event, "status"),
|
||||
):
|
||||
if isinstance(result, dict) and "status" in result:
|
||||
yield result["status"]
|
||||
else:
|
||||
image, status_code, improved_image_prompt, intent_type = result
|
||||
|
||||
if image is None or status_code != 200:
|
||||
content_obj = {
|
||||
"image": image,
|
||||
"intentType": intent_type,
|
||||
"detail": improved_image_prompt,
|
||||
"content-type": "application/json",
|
||||
}
|
||||
async for result in send_event("complete_llm_response", json.dumps(content_obj)):
|
||||
yield result
|
||||
async for event in send_event("end_llm_response", ""):
|
||||
yield event
|
||||
return
|
||||
|
||||
await sync_to_async(save_to_conversation_log)(
|
||||
q,
|
||||
image,
|
||||
user,
|
||||
meta_log,
|
||||
user_message_time,
|
||||
intent_type=intent_type,
|
||||
inferred_queries=[improved_image_prompt],
|
||||
client_application=request.user.client_app,
|
||||
conversation_id=conversation_id,
|
||||
compiled_references=compiled_references,
|
||||
online_results=online_results,
|
||||
)
|
||||
content_obj = {
|
||||
"image": image,
|
||||
"intentType": intent_type,
|
||||
"inferredQueries": [improved_image_prompt],
|
||||
"context": compiled_references,
|
||||
"content-type": "application/json",
|
||||
"online_results": online_results,
|
||||
}
|
||||
async for result in send_event("complete_llm_response", json.dumps(content_obj)):
|
||||
yield result
|
||||
async for event in send_event("end_llm_response", ""):
|
||||
yield event
|
||||
return
|
||||
|
||||
async for result in send_event("status", f"**💭 Generating a well-informed response**"):
|
||||
yield result
|
||||
llm_response, chat_metadata = await agenerate_chat_response(
|
||||
defiltered_query,
|
||||
meta_log,
|
||||
conversation,
|
||||
compiled_references,
|
||||
online_results,
|
||||
inferred_queries,
|
||||
conversation_commands,
|
||||
user,
|
||||
request.user.client_app,
|
||||
conversation_id,
|
||||
location,
|
||||
user_name,
|
||||
)
|
||||
|
||||
chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None
|
||||
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
telemetry_type="api",
|
||||
api="chat",
|
||||
metadata=chat_metadata,
|
||||
)
|
||||
iterator = AsyncIteratorWrapper(llm_response)
|
||||
|
||||
async for result in send_event("start_llm_response", ""):
|
||||
yield result
|
||||
|
||||
async for item in iterator:
|
||||
if item is None:
|
||||
break
|
||||
if connection_alive:
|
||||
try:
|
||||
async for result in send_event("message", f"{item}"):
|
||||
yield result
|
||||
except Exception as e:
|
||||
connection_alive = False
|
||||
logger.info(
|
||||
f"User {user} disconnected SSE. Emitting rest of responses to clear thread: {e}"
|
||||
)
|
||||
async for result in send_event("end_llm_response", ""):
|
||||
yield result
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error in SSE endpoint: {e}", exc_info=True)
|
||||
break
|
||||
|
||||
return EventSourceResponse(event_generator(q))
|
||||
|
||||
|
||||
@api_chat.get("", response_class=Response)
|
||||
|
|
|
@ -755,7 +755,7 @@ async def text_to_image(
|
|||
references: List[Dict[str, Any]],
|
||||
online_results: Dict[str, Any],
|
||||
send_status_func: Optional[Callable] = None,
|
||||
) -> Tuple[Optional[str], int, Optional[str], str]:
|
||||
):
|
||||
status_code = 200
|
||||
image = None
|
||||
response = None
|
||||
|
@ -767,7 +767,8 @@ async def text_to_image(
|
|||
# If the user has not configured a text to image model, return an unsupported on server error
|
||||
status_code = 501
|
||||
message = "Failed to generate image. Setup image generation on the server."
|
||||
return image_url or image, status_code, message, intent_type.value
|
||||
yield image_url or image, status_code, message, intent_type.value
|
||||
return
|
||||
|
||||
text2image_model = text_to_image_config.model_name
|
||||
chat_history = ""
|
||||
|
@ -781,7 +782,8 @@ async def text_to_image(
|
|||
|
||||
with timer("Improve the original user query", logger):
|
||||
if send_status_func:
|
||||
await send_status_func("**✍🏽 Enhancing the Painting Prompt**")
|
||||
async for event in send_status_func("**✍🏽 Enhancing the Painting Prompt**"):
|
||||
yield {"status": event}
|
||||
improved_image_prompt = await generate_better_image_prompt(
|
||||
message,
|
||||
chat_history,
|
||||
|
@ -792,7 +794,8 @@ async def text_to_image(
|
|||
)
|
||||
|
||||
if send_status_func:
|
||||
await send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}")
|
||||
async for event in send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}"):
|
||||
yield {"status": event}
|
||||
|
||||
if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
|
||||
with timer("Generate image with OpenAI", logger):
|
||||
|
@ -817,12 +820,14 @@ async def text_to_image(
|
|||
logger.error(f"Image Generation blocked by OpenAI: {e}")
|
||||
status_code = e.status_code # type: ignore
|
||||
message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore
|
||||
return image_url or image, status_code, message, intent_type.value
|
||||
yield image_url or image, status_code, message, intent_type.value
|
||||
return
|
||||
else:
|
||||
logger.error(f"Image Generation failed with {e}", exc_info=True)
|
||||
message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore
|
||||
status_code = e.status_code # type: ignore
|
||||
return image_url or image, status_code, message, intent_type.value
|
||||
yield image_url or image, status_code, message, intent_type.value
|
||||
return
|
||||
|
||||
elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.STABILITYAI:
|
||||
with timer("Generate image with Stability AI", logger):
|
||||
|
@ -844,7 +849,8 @@ async def text_to_image(
|
|||
logger.error(f"Image Generation failed with {e}", exc_info=True)
|
||||
message = f"Image generation failed with Stability AI error: {e}"
|
||||
status_code = e.status_code # type: ignore
|
||||
return image_url or image, status_code, message, intent_type.value
|
||||
yield image_url or image, status_code, message, intent_type.value
|
||||
return
|
||||
|
||||
with timer("Convert image to webp", logger):
|
||||
# Convert png to webp for faster loading
|
||||
|
@ -864,7 +870,7 @@ async def text_to_image(
|
|||
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
|
||||
image = base64.b64encode(webp_image_bytes).decode("utf-8")
|
||||
|
||||
return image_url or image, status_code, improved_image_prompt, intent_type.value
|
||||
yield image_url or image, status_code, improved_image_prompt, intent_type.value
|
||||
|
||||
|
||||
class ApiUserRateLimiter:
|
||||
|
|
Loading…
Add table
Reference in a new issue