mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-28 01:45:07 +01:00
Add support for chatting via the web socket connection
- Convert to a model of calling the search API directly with a function call (rather than using the API method)
- Gracefully handle websocket connection disconnects
- Ensure that the rest of the response is still saved, as it is currently, if the user disconnects from the client
- Set up unchangeable context at the beginning of the session when the connection is established (like location, username, etc.)
This commit is contained in:
parent
36af9776e6
commit
a346f79b39
2 changed files with 250 additions and 11 deletions
|
@ -61,6 +61,36 @@ async def search(
|
||||||
dedupe: Optional[bool] = True,
|
dedupe: Optional[bool] = True,
|
||||||
):
|
):
|
||||||
user = request.user.object
|
user = request.user.object
|
||||||
|
|
||||||
|
results = await execute_search(
|
||||||
|
user=user,
|
||||||
|
q=q,
|
||||||
|
n=n,
|
||||||
|
t=t,
|
||||||
|
r=r,
|
||||||
|
max_distance=max_distance,
|
||||||
|
dedupe=dedupe,
|
||||||
|
)
|
||||||
|
|
||||||
|
update_telemetry_state(
|
||||||
|
request=request,
|
||||||
|
telemetry_type="api",
|
||||||
|
api="search",
|
||||||
|
**common.__dict__,
|
||||||
|
)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
async def execute_search(
|
||||||
|
user: KhojUser,
|
||||||
|
q: str,
|
||||||
|
n: Optional[int] = 5,
|
||||||
|
t: Optional[SearchType] = SearchType.All,
|
||||||
|
r: Optional[bool] = False,
|
||||||
|
max_distance: Optional[Union[float, None]] = None,
|
||||||
|
dedupe: Optional[bool] = True,
|
||||||
|
):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
# Run validation checks
|
# Run validation checks
|
||||||
|
@ -155,13 +185,6 @@ async def search(
|
||||||
if user:
|
if user:
|
||||||
state.query_cache[user.uuid][query_cache_key] = results
|
state.query_cache[user.uuid][query_cache_key] = results
|
||||||
|
|
||||||
update_telemetry_state(
|
|
||||||
request=request,
|
|
||||||
telemetry_type="api",
|
|
||||||
api="search",
|
|
||||||
**common.__dict__,
|
|
||||||
)
|
|
||||||
|
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
logger.debug(f"🔍 Search took: {end_time - start_time:.3f} seconds")
|
logger.debug(f"🔍 Search took: {end_time - start_time:.3f} seconds")
|
||||||
|
|
||||||
|
@ -349,14 +372,14 @@ async def extract_references_and_questions(
|
||||||
for query in inferred_queries:
|
for query in inferred_queries:
|
||||||
n_items = min(n, 3) if using_offline_chat else n
|
n_items = min(n, 3) if using_offline_chat else n
|
||||||
result_list.extend(
|
result_list.extend(
|
||||||
await search(
|
await execute_search(
|
||||||
|
user,
|
||||||
f"{query} {filters_in_query}",
|
f"{query} {filters_in_query}",
|
||||||
request=request,
|
|
||||||
n=n_items,
|
n=n_items,
|
||||||
|
t=SearchType.All,
|
||||||
r=True,
|
r=True,
|
||||||
max_distance=d,
|
max_distance=d,
|
||||||
dedupe=False,
|
dedupe=False,
|
||||||
common=common,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
result_list = text_search.deduplicated_search_responses(result_list)
|
result_list = text_search.deduplicated_search_responses(result_list)
|
||||||
|
|
|
@ -5,10 +5,12 @@ from typing import Dict, Optional
|
||||||
from urllib.parse import unquote
|
from urllib.parse import unquote
|
||||||
|
|
||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
from fastapi import APIRouter, Depends, Request
|
from fastapi import APIRouter, Depends, Request, WebSocket
|
||||||
from fastapi.requests import Request
|
from fastapi.requests import Request
|
||||||
from fastapi.responses import Response, StreamingResponse
|
from fastapi.responses import Response, StreamingResponse
|
||||||
from starlette.authentication import requires
|
from starlette.authentication import requires
|
||||||
|
from starlette.websockets import WebSocketDisconnect
|
||||||
|
from websockets import ConnectionClosedOK
|
||||||
|
|
||||||
from khoj.database.adapters import ConversationAdapters, EntryAdapters, aget_user_name
|
from khoj.database.adapters import ConversationAdapters, EntryAdapters, aget_user_name
|
||||||
from khoj.database.models import KhojUser
|
from khoj.database.models import KhojUser
|
||||||
|
@ -229,6 +231,220 @@ async def set_conversation_title(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@api_chat.websocket("/ws")
async def websocket_endpoint(
    websocket: WebSocket,
    conversation_id: int,
    city: Optional[str] = None,
    region: Optional[str] = None,
    country: Optional[str] = None,
):
    """Chat with the user over a long-lived web socket connection.

    Session-level context (user, conversation, user name, location, rate
    limiters) is set up once, before the connection is accepted. After that,
    every text frame received from the client is treated as one chat query.

    Wire protocol, as emitted by this handler:
      * Status updates are JSON text frames of the form
        ``{"type": "status", "message": ..., "content-type": "application/json"}``.
      * A chat response is framed by the literal text frames
        ``start_llm_response`` and ``end_llm_response``; the frames in between
        carry the response (a single frame for help/error/image replies, one
        frame per streamed item for generated chat responses).

    If a send fails with ``ConnectionClosedOK`` the connection is marked dead
    and all later sends become no-ops, but the in-flight response iterator is
    still drained to completion. The receive loop then exits.

    Args:
        websocket: The web socket connection for this session.
        conversation_id: Id of the conversation to load for the user.
        city: Optional city of the user, used to build the location context.
        region: Optional region of the user, used to build the location context.
        country: Optional country of the user, used to build the location context.
    """
    user: KhojUser = websocket.user.object
    connection_alive = True

    async def send_text_if_alive(text: str) -> None:
        """Send a text frame unless the client is already known to be gone."""
        nonlocal connection_alive
        if not connection_alive:
            return
        try:
            await websocket.send_text(text)
        except ConnectionClosedOK:
            # Remember the disconnect so later sends are skipped, but let the
            # caller keep going so the current response is processed to the end.
            connection_alive = False
            logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")

    async def send_status_update(message: str):
        """Send a progress/status message to the client as a JSON packet."""
        status_packet = {
            "type": "status",
            "message": message,
            "content-type": "application/json",
        }
        await send_text_if_alive(json.dumps(status_packet))

    async def send_complete_llm_response(llm_response: str):
        """Send a full, non-streamed response wrapped in the start/end markers."""
        await send_text_if_alive("start_llm_response")
        await send_text_if_alive(llm_response)
        await send_text_if_alive("end_llm_response")

    conversation = await ConversationAdapters.aget_conversation_by_user(
        user, client_application=websocket.user.client_app, conversation_id=conversation_id
    )

    # NOTE(review): window is presumably in seconds (60 => per minute, matching
    # the "chat_minute" slug) - confirm against ApiUserRateLimiter.
    per_minute_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute")
    daily_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")

    await is_ready_to_chat(user)

    user_name = await aget_user_name(user)

    location = None
    if city or region or country:
        location = LocationData(city=city, region=region, country=country)

    await websocket.accept()
    while connection_alive:
        try:
            q = await websocket.receive_text()
        except WebSocketDisconnect:
            logger.debug(f"User {user} disconnected web socket")
            break

        # Rate limits are enforced per received query, not per connection.
        await sync_to_async(per_minute_limiter)(websocket)
        await sync_to_async(daily_limiter)(websocket)

        conversation_commands = [get_conversation_command(query=q, any_references=True)]

        await send_status_update(f"**Processing query**: {q}")

        if conversation_commands == [ConversationCommand.Help]:
            conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
            if conversation_config is None:
                conversation_config = await ConversationAdapters.aget_default_conversation_config()
            model_type = conversation_config.model_type
            formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
            await send_complete_llm_response(formatted_help)
            continue

        meta_log = conversation.conversation_log

        if conversation_commands == [ConversationCommand.Default]:
            # No explicit command given: infer the information sources and the
            # output mode from the query and the conversation so far.
            conversation_commands = await aget_relevant_information_sources(q, meta_log)
            mode = await aget_relevant_output_modes(q, meta_log)
            if mode not in conversation_commands:
                conversation_commands.append(mode)

        for cmd in conversation_commands:
            await conversation_command_rate_limiter.update_and_check_if_valid(websocket, cmd)
            # Strip the slash command itself from the query text.
            q = q.replace(f"/{cmd.value}", "").strip()

        await send_status_update(
            f"**Using conversation commands:** {', '.join([cmd.value for cmd in conversation_commands])}"
        )

        compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
            websocket, None, meta_log, q, 7, 0.18, conversation_commands, location
        )

        if compiled_references:
            # The first line of each reference is reported as its heading.
            headings = {reference.split("\n")[0] for reference in compiled_references}
            await send_status_update(f"**Searching references**: {headings}")

        online_results: Dict = dict()

        if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
            await send_complete_llm_response(f"{no_entries_found.format()}")
            continue

        if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
            conversation_commands.remove(ConversationCommand.Notes)

        if ConversationCommand.Online in conversation_commands:
            try:
                await send_status_update("Searching the web for relevant information 🌐")
                online_results = await search_online(defiltered_query, meta_log, location)
                # Separate the queries so several searches stay readable in the status line.
                online_searches = ", ".join([f"{query}" for query in online_results.keys()])
                await send_status_update(f"**Online searches**: {online_searches}")
            except ValueError:
                await send_complete_llm_response(
                    "Please set your SERPER_DEV_API_KEY to get started with online searches 🌐"
                )
                continue

        if ConversationCommand.Image in conversation_commands:
            update_telemetry_state(
                request=websocket,
                telemetry_type="api",
                api="chat",
                metadata={"conversation_command": conversation_commands[0].value},
            )
            intent_type = "text-to-image"
            image, status_code, improved_image_prompt, image_url = await text_to_image(
                q, user, meta_log, location_data=location, references=compiled_references, online_results=online_results
            )
            if image is None or status_code != 200:
                # Image generation failed: report the details without saving to the conversation log.
                content_obj = {
                    "image": image,
                    "intentType": intent_type,
                    "detail": improved_image_prompt,
                    "content-type": "application/json",
                }
                await send_complete_llm_response(json.dumps(content_obj))
                continue

            if image_url:
                # Prefer the image URL over the raw image payload when one is available.
                intent_type = "text-to-image2"
                image = image_url
            await sync_to_async(save_to_conversation_log)(
                q,
                image,
                user,
                meta_log,
                intent_type=intent_type,
                inferred_queries=[improved_image_prompt],
                client_application=websocket.user.client_app,
                conversation_id=conversation_id,
                compiled_references=compiled_references,
                online_results=online_results,
            )
            content_obj = {  # type: ignore
                "image": image,
                "intentType": intent_type,
                "inferredQueries": [improved_image_prompt],
                "context": compiled_references,
                "content-type": "application/json",
                "online_results": online_results,
            }

            await send_complete_llm_response(json.dumps(content_obj))
            continue

        llm_response, chat_metadata = await agenerate_chat_response(
            defiltered_query,
            meta_log,
            conversation,
            compiled_references,
            online_results,
            inferred_queries,
            conversation_commands,
            user,
            websocket.user.client_app,
            conversation_id,
            location,
            user_name,
        )

        update_telemetry_state(
            request=websocket,
            telemetry_type="api",
            api="chat",
            metadata=chat_metadata,
        )
        iterator = AsyncIteratorWrapper(llm_response)

        await send_text_if_alive("start_llm_response")

        # Keep draining the iterator even after a disconnect so the response
        # generation runs to completion; only the sends are skipped.
        async for item in iterator:
            if item is None:
                break
            await send_text_if_alive(f"{item}")

        await send_text_if_alive("end_llm_response")
|
||||||
|
|
||||||
|
|
||||||
@api_chat.get("", response_class=Response)
|
@api_chat.get("", response_class=Response)
|
||||||
@requires(["authenticated"])
|
@requires(["authenticated"])
|
||||||
async def chat(
|
async def chat(
|
||||||
|
|
Loading…
Reference in a new issue