Add support for chatting via the web socket connection

- Convert to a model of calling the search API directly with a function call (rather than using the API method)
- Gracefully handle websocket connection disconnects
- Ensure that the rest of the response is still saved, as it is currently, if the user disconects from the client
- Setup unchangeable context at the beginning of the session when the connection is established (like location, username, etc)
This commit is contained in:
sabaimran 2024-03-20 14:33:33 +05:30
parent 36af9776e6
commit a346f79b39
2 changed files with 250 additions and 11 deletions

View file

@ -61,6 +61,36 @@ async def search(
dedupe: Optional[bool] = True, dedupe: Optional[bool] = True,
): ):
user = request.user.object user = request.user.object
results = await execute_search(
user=user,
q=q,
n=n,
t=t,
r=r,
max_distance=max_distance,
dedupe=dedupe,
)
update_telemetry_state(
request=request,
telemetry_type="api",
api="search",
**common.__dict__,
)
return results
async def execute_search(
user: KhojUser,
q: str,
n: Optional[int] = 5,
t: Optional[SearchType] = SearchType.All,
r: Optional[bool] = False,
max_distance: Optional[Union[float, None]] = None,
dedupe: Optional[bool] = True,
):
start_time = time.time() start_time = time.time()
# Run validation checks # Run validation checks
@ -155,13 +185,6 @@ async def search(
if user: if user:
state.query_cache[user.uuid][query_cache_key] = results state.query_cache[user.uuid][query_cache_key] = results
update_telemetry_state(
request=request,
telemetry_type="api",
api="search",
**common.__dict__,
)
end_time = time.time() end_time = time.time()
logger.debug(f"🔍 Search took: {end_time - start_time:.3f} seconds") logger.debug(f"🔍 Search took: {end_time - start_time:.3f} seconds")
@ -349,14 +372,14 @@ async def extract_references_and_questions(
for query in inferred_queries: for query in inferred_queries:
n_items = min(n, 3) if using_offline_chat else n n_items = min(n, 3) if using_offline_chat else n
result_list.extend( result_list.extend(
await search( await execute_search(
user,
f"{query} {filters_in_query}", f"{query} {filters_in_query}",
request=request,
n=n_items, n=n_items,
t=SearchType.All,
r=True, r=True,
max_distance=d, max_distance=d,
dedupe=False, dedupe=False,
common=common,
) )
) )
result_list = text_search.deduplicated_search_responses(result_list) result_list = text_search.deduplicated_search_responses(result_list)

View file

@ -5,10 +5,12 @@ from typing import Dict, Optional
from urllib.parse import unquote from urllib.parse import unquote
from asgiref.sync import sync_to_async from asgiref.sync import sync_to_async
from fastapi import APIRouter, Depends, Request from fastapi import APIRouter, Depends, Request, WebSocket
from fastapi.requests import Request from fastapi.requests import Request
from fastapi.responses import Response, StreamingResponse from fastapi.responses import Response, StreamingResponse
from starlette.authentication import requires from starlette.authentication import requires
from starlette.websockets import WebSocketDisconnect
from websockets import ConnectionClosedOK
from khoj.database.adapters import ConversationAdapters, EntryAdapters, aget_user_name from khoj.database.adapters import ConversationAdapters, EntryAdapters, aget_user_name
from khoj.database.models import KhojUser from khoj.database.models import KhojUser
@ -229,6 +231,220 @@ async def set_conversation_title(
) )
@api_chat.websocket("/ws")
async def websocket_endpoint(
websocket: WebSocket,
conversation_id: int,
city: Optional[str] = None,
region: Optional[str] = None,
country: Optional[str] = None,
):
connection_alive = True
async def send_status_update(message: str):
nonlocal connection_alive
if not connection_alive:
return
status_packet = {
"type": "status",
"message": message,
"content-type": "application/json",
}
try:
await websocket.send_text(json.dumps(status_packet))
except ConnectionClosedOK:
connection_alive = False
logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")
async def send_complete_llm_response(llm_response: str):
nonlocal connection_alive
if not connection_alive:
return
try:
await websocket.send_text("start_llm_response")
await websocket.send_text(llm_response)
await websocket.send_text("end_llm_response")
except ConnectionClosedOK:
connection_alive = False
logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")
user: KhojUser = websocket.user.object
conversation = await ConversationAdapters.aget_conversation_by_user(
user, client_application=websocket.user.client_app, conversation_id=conversation_id
)
hourly_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute")
daily_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
await is_ready_to_chat(user)
user_name = await aget_user_name(user)
location = None
if city or region or country:
location = LocationData(city=city, region=region, country=country)
await websocket.accept()
while connection_alive:
try:
q = await websocket.receive_text()
except WebSocketDisconnect:
logger.debug(f"User {user} disconnected web socket")
break
await sync_to_async(hourly_limiter)(websocket)
await sync_to_async(daily_limiter)(websocket)
conversation_commands = [get_conversation_command(query=q, any_references=True)]
await send_status_update(f"**Processing query**: {q}")
if conversation_commands == [ConversationCommand.Help]:
conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
if conversation_config == None:
conversation_config = await ConversationAdapters.aget_default_conversation_config()
model_type = conversation_config.model_type
formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
await send_complete_llm_response(formatted_help)
continue
meta_log = conversation.conversation_log
if conversation_commands == [ConversationCommand.Default]:
conversation_commands = await aget_relevant_information_sources(q, meta_log)
mode = await aget_relevant_output_modes(q, meta_log)
if mode not in conversation_commands:
conversation_commands.append(mode)
for cmd in conversation_commands:
await conversation_command_rate_limiter.update_and_check_if_valid(websocket, cmd)
q = q.replace(f"/{cmd.value}", "").strip()
await send_status_update(
f"**Using conversation commands:** {', '.join([cmd.value for cmd in conversation_commands])}"
)
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
websocket, None, meta_log, q, 7, 0.18, conversation_commands, location
)
if compiled_references:
headings = set([c.split("\n")[0] for c in compiled_references])
await send_status_update(f"**Searching references**: {headings}")
online_results: Dict = dict()
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
await send_complete_llm_response(f"{no_entries_found.format()}")
continue
if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
conversation_commands.remove(ConversationCommand.Notes)
if ConversationCommand.Online in conversation_commands:
try:
await send_status_update("Searching the web for relevant information 🌐")
online_results = await search_online(defiltered_query, meta_log, location)
online_searches = "".join([f"{query}" for query in online_results.keys()])
await send_status_update(f"**Online searches**: {online_searches}")
except ValueError as e:
await send_complete_llm_response(
"Please set your SERPER_DEV_API_KEY to get started with online searches 🌐"
)
continue
if ConversationCommand.Image in conversation_commands:
update_telemetry_state(
request=websocket,
telemetry_type="api",
api="chat",
metadata={"conversation_command": conversation_commands[0].value},
)
intent_type = "text-to-image"
image, status_code, improved_image_prompt, image_url = await text_to_image(
q, user, meta_log, location_data=location, references=compiled_references, online_results=online_results
)
if image is None or status_code != 200:
content_obj = {
"image": image,
"intentType": intent_type,
"detail": improved_image_prompt,
"content-type": "application/json",
}
await send_complete_llm_response(json.dumps(content_obj))
continue
if image_url:
intent_type = "text-to-image2"
image = image_url
await sync_to_async(save_to_conversation_log)(
q,
image,
user,
meta_log,
intent_type=intent_type,
inferred_queries=[improved_image_prompt],
client_application=websocket.user.client_app,
conversation_id=conversation_id,
compiled_references=compiled_references,
online_results=online_results,
)
content_obj = {"image": image, "intentType": intent_type, "inferredQueries": [improved_image_prompt], "context": compiled_references, "content-type": "application/json", "online_results": online_results} # type: ignore
await send_complete_llm_response(json.dumps(content_obj))
continue
llm_response, chat_metadata = await agenerate_chat_response(
defiltered_query,
meta_log,
conversation,
compiled_references,
online_results,
inferred_queries,
conversation_commands,
user,
websocket.user.client_app,
conversation_id,
location,
user_name,
)
update_telemetry_state(
request=websocket,
telemetry_type="api",
api="chat",
metadata=chat_metadata,
)
iterator = AsyncIteratorWrapper(llm_response)
if connection_alive:
try:
await websocket.send_text("start_llm_response")
except ConnectionClosedOK:
connection_alive = False
logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")
async for item in iterator:
if item is None:
break
if connection_alive:
try:
await websocket.send_text(f"{item}")
except ConnectionClosedOK:
connection_alive = False
logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")
if connection_alive:
try:
await websocket.send_text("end_llm_response")
except ConnectionClosedOK:
connection_alive = False
logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")
@api_chat.get("", response_class=Response) @api_chat.get("", response_class=Response)
@requires(["authenticated"]) @requires(["authenticated"])
async def chat( async def chat(