mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-24 07:55:07 +01:00
Add support for chatting via the web socket connection
- Convert to a model of calling the search API directly with a function call (rather than using the API method) - Gracefully handle websocket connection disconnects - Ensure that the rest of the response is still saved, as it is currently, if the user disconnects from the client - Set up unchangeable context (like location, username, etc.) at the beginning of the session, when the connection is established
This commit is contained in:
parent
36af9776e6
commit
a346f79b39
2 changed files with 250 additions and 11 deletions
|
@ -61,6 +61,36 @@ async def search(
|
|||
dedupe: Optional[bool] = True,
|
||||
):
|
||||
user = request.user.object
|
||||
|
||||
results = await execute_search(
|
||||
user=user,
|
||||
q=q,
|
||||
n=n,
|
||||
t=t,
|
||||
r=r,
|
||||
max_distance=max_distance,
|
||||
dedupe=dedupe,
|
||||
)
|
||||
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
telemetry_type="api",
|
||||
api="search",
|
||||
**common.__dict__,
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def execute_search(
|
||||
user: KhojUser,
|
||||
q: str,
|
||||
n: Optional[int] = 5,
|
||||
t: Optional[SearchType] = SearchType.All,
|
||||
r: Optional[bool] = False,
|
||||
max_distance: Optional[Union[float, None]] = None,
|
||||
dedupe: Optional[bool] = True,
|
||||
):
|
||||
start_time = time.time()
|
||||
|
||||
# Run validation checks
|
||||
|
@ -155,13 +185,6 @@ async def search(
|
|||
if user:
|
||||
state.query_cache[user.uuid][query_cache_key] = results
|
||||
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
telemetry_type="api",
|
||||
api="search",
|
||||
**common.__dict__,
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
logger.debug(f"🔍 Search took: {end_time - start_time:.3f} seconds")
|
||||
|
||||
|
@ -349,14 +372,14 @@ async def extract_references_and_questions(
|
|||
for query in inferred_queries:
|
||||
n_items = min(n, 3) if using_offline_chat else n
|
||||
result_list.extend(
|
||||
await search(
|
||||
await execute_search(
|
||||
user,
|
||||
f"{query} {filters_in_query}",
|
||||
request=request,
|
||||
n=n_items,
|
||||
t=SearchType.All,
|
||||
r=True,
|
||||
max_distance=d,
|
||||
dedupe=False,
|
||||
common=common,
|
||||
)
|
||||
)
|
||||
result_list = text_search.deduplicated_search_responses(result_list)
|
||||
|
|
|
@ -5,10 +5,12 @@ from typing import Dict, Optional
|
|||
from urllib.parse import unquote
|
||||
|
||||
from asgiref.sync import sync_to_async
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from fastapi import APIRouter, Depends, Request, WebSocket
|
||||
from fastapi.requests import Request
|
||||
from fastapi.responses import Response, StreamingResponse
|
||||
from starlette.authentication import requires
|
||||
from starlette.websockets import WebSocketDisconnect
|
||||
from websockets import ConnectionClosedOK
|
||||
|
||||
from khoj.database.adapters import ConversationAdapters, EntryAdapters, aget_user_name
|
||||
from khoj.database.models import KhojUser
|
||||
|
@ -229,6 +231,220 @@ async def set_conversation_title(
|
|||
)
|
||||
|
||||
|
||||
@api_chat.websocket("/ws")
async def websocket_endpoint(
    websocket: WebSocket,
    conversation_id: int,
    city: Optional[str] = None,
    region: Optional[str] = None,
    country: Optional[str] = None,
):
    """Chat with Khoj over a web socket connection.

    Session-wide context (user, conversation, user name, location and rate
    limiters) is set up once, before the connection is accepted. After that,
    every text frame received from the client is treated as a chat query.

    For each query the endpoint sends JSON "status" packets describing
    progress, followed by the response wrapped between the literal text frames
    "start_llm_response" and "end_llm_response".

    If the client disconnects while a response is being produced, sending is
    silently skipped but the response is still consumed to the end.

    Args:
        websocket: The web socket connection to the client.
        conversation_id: Id of the conversation to load and to save to.
        city: Optional city of the user, used to build the location context.
        region: Optional region of the user, used to build the location context.
        country: Optional country of the user, used to build the location context.
    """
    # NOTE(review): the HTTP chat route is guarded by `requires(["authenticated"])`,
    # no equivalent guard is visible on this route. `websocket.user.object` below
    # assumes an authenticated user - confirm authentication is enforced upstream.
    connection_alive = True

    async def send_text_if_connected(text: str):
        """Send a text frame to the client, unless it is known to have disconnected."""
        nonlocal connection_alive
        if not connection_alive:
            return
        try:
            await websocket.send_text(text)
        except ConnectionClosedOK:
            # Remember the disconnect so later sends are skipped and the receive loop exits
            connection_alive = False
            logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")

    async def send_status_update(message: str):
        """Send a progress update to the client as a JSON status packet."""
        status_packet = {
            "type": "status",
            "message": message,
            "content-type": "application/json",
        }
        await send_text_if_connected(json.dumps(status_packet))

    async def send_complete_llm_response(llm_response: str):
        """Send a complete (non-streamed) response wrapped in the start and end markers."""
        await send_text_if_connected("start_llm_response")
        await send_text_if_connected(llm_response)
        await send_text_if_connected("end_llm_response")

    # Set up the context that stays fixed for the whole session
    user: KhojUser = websocket.user.object
    conversation = await ConversationAdapters.aget_conversation_by_user(
        user, client_application=websocket.user.client_app, conversation_id=conversation_id
    )

    # The slugs indicate a per-minute and a per-day limit (window is presumably in seconds)
    per_minute_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute")
    daily_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")

    await is_ready_to_chat(user)
    user_name = await aget_user_name(user)

    location = None
    if city or region or country:
        location = LocationData(city=city, region=region, country=country)

    await websocket.accept()
    while connection_alive:
        try:
            q = await websocket.receive_text()
        except WebSocketDisconnect:
            logger.debug(f"User {user} disconnected web socket")
            break

        # Rate limits are checked for every query, not once per connection
        await sync_to_async(per_minute_limiter)(websocket)
        await sync_to_async(daily_limiter)(websocket)

        conversation_commands = [get_conversation_command(query=q, any_references=True)]

        await send_status_update(f"**Processing query**: {q}")

        if conversation_commands == [ConversationCommand.Help]:
            conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
            if conversation_config is None:
                conversation_config = await ConversationAdapters.aget_default_conversation_config()
            model_type = conversation_config.model_type
            formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
            await send_complete_llm_response(formatted_help)
            continue

        meta_log = conversation.conversation_log

        # No explicit command in the query: pick the information sources and output mode to use
        if conversation_commands == [ConversationCommand.Default]:
            conversation_commands = await aget_relevant_information_sources(q, meta_log)
            mode = await aget_relevant_output_modes(q, meta_log)
            if mode not in conversation_commands:
                conversation_commands.append(mode)

        for cmd in conversation_commands:
            await conversation_command_rate_limiter.update_and_check_if_valid(websocket, cmd)
            # Strip the slash command from the query text
            q = q.replace(f"/{cmd.value}", "").strip()

        await send_status_update(
            f"**Using conversation commands:** {', '.join(cmd.value for cmd in conversation_commands)}"
        )

        compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
            websocket, None, meta_log, q, 7, 0.18, conversation_commands, location
        )

        if compiled_references:
            # The first line of each reference is used as its heading
            headings = {c.split("\n")[0] for c in compiled_references}
            await send_status_update(f"**Searching references**: {headings}")

        online_results: Dict = dict()

        if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
            await send_complete_llm_response(f"{no_entries_found.format()}")
            continue

        if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
            conversation_commands.remove(ConversationCommand.Notes)

        if ConversationCommand.Online in conversation_commands:
            await send_status_update("Searching the web for relevant information 🌐")
            try:
                online_results = await search_online(defiltered_query, meta_log, location)
            except ValueError as e:
                # Log the cause instead of discarding it; the client still gets the setup hint
                logger.warning(f"Error searching online: {e}")
                await send_complete_llm_response(
                    "Please set your SERPER_DEV_API_KEY to get started with online searches 🌐"
                )
                continue
            # Separate the queries so several searches stay readable in the status message
            online_searches = ", ".join(f"{query}" for query in online_results)
            await send_status_update(f"**Online searches**: {online_searches}")

        if ConversationCommand.Image in conversation_commands:
            update_telemetry_state(
                request=websocket,
                telemetry_type="api",
                api="chat",
                metadata={"conversation_command": conversation_commands[0].value},
            )
            intent_type = "text-to-image"
            image, status_code, improved_image_prompt, image_url = await text_to_image(
                q, user, meta_log, location_data=location, references=compiled_references, online_results=online_results
            )
            if image is None or status_code != 200:
                # Image generation failed: report the details to the client, save nothing
                content_obj = {
                    "image": image,
                    "intentType": intent_type,
                    "detail": improved_image_prompt,
                    "content-type": "application/json",
                }
                await send_complete_llm_response(json.dumps(content_obj))
                continue

            if image_url:
                # Prefer the image url over the raw image when one is available
                intent_type = "text-to-image2"
                image = image_url
            await sync_to_async(save_to_conversation_log)(
                q,
                image,
                user,
                meta_log,
                intent_type=intent_type,
                inferred_queries=[improved_image_prompt],
                client_application=websocket.user.client_app,
                conversation_id=conversation_id,
                compiled_references=compiled_references,
                online_results=online_results,
            )
            content_obj = {"image": image, "intentType": intent_type, "inferredQueries": [improved_image_prompt], "context": compiled_references, "content-type": "application/json", "online_results": online_results}  # type: ignore

            await send_complete_llm_response(json.dumps(content_obj))
            continue

        llm_response, chat_metadata = await agenerate_chat_response(
            defiltered_query,
            meta_log,
            conversation,
            compiled_references,
            online_results,
            inferred_queries,
            conversation_commands,
            user,
            websocket.user.client_app,
            conversation_id,
            location,
            user_name,
        )

        update_telemetry_state(
            request=websocket,
            telemetry_type="api",
            api="chat",
            metadata=chat_metadata,
        )
        iterator = AsyncIteratorWrapper(llm_response)

        await send_text_if_connected("start_llm_response")

        # Keep consuming the response even after the client disconnects; only the sends are skipped
        async for item in iterator:
            if item is None:
                # NOTE(review): None presumably marks the end of the response stream - confirm
                break
            await send_text_if_connected(f"{item}")

        await send_text_if_connected("end_llm_response")
|
||||
|
||||
|
||||
@api_chat.get("", response_class=Response)
|
||||
@requires(["authenticated"])
|
||||
async def chat(
|
||||
|
|
Loading…
Reference in a new issue