Use enum to track chat stream event types in chat api router

Debanjum Singh Solanky 2024-07-26 00:18:37 +05:30
parent ebe92ef16d
commit 778c571288
4 changed files with 56 additions and 42 deletions
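
The pattern in miniature: chat stream event types that were passed around as bare strings become members of a shared Enum, so a mistyped event name fails loudly instead of silently never matching. A minimal sketch of the before/after (ChatEvent is trimmed to one member here; the full enum is added at the end of this diff):

    from enum import Enum

    class ChatEvent(Enum):
        STATUS = "status"

    def handle(event_type: ChatEvent) -> None:
        # A typo like ChatEvent.STATSU raises AttributeError immediately,
        # where the old string "statsu" would just never match.
        if event_type == ChatEvent.STATUS:
            print("status event")

    handle(ChatEvent.STATUS)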

khoj/processor/tools/online_search.py

@@ -11,6 +11,7 @@ from bs4 import BeautifulSoup
 from markdownify import markdownify
 from khoj.routers.helpers import (
+    ChatEvent,
     extract_relevant_info,
     generate_online_subqueries,
     infer_webpage_urls,
@@ -68,7 +69,7 @@ async def search_online(
     if send_status_func:
         subqueries_str = "\n- " + "\n- ".join(list(subqueries))
         async for event in send_status_func(f"**🌐 Searching the Internet for**: {subqueries_str}"):
-            yield {"status": event}
+            yield {ChatEvent.STATUS: event}

     with timer(f"Internet searches for {list(subqueries)} took", logger):
         search_func = search_with_google if SERPER_DEV_API_KEY else search_with_jina
@@ -92,7 +93,7 @@ async def search_online(
     if send_status_func:
         webpage_links_str = "\n- " + "\n- ".join(list(webpage_links))
         async for event in send_status_func(f"**📖 Reading web pages**: {webpage_links_str}"):
-            yield {"status": event}
+            yield {ChatEvent.STATUS: event}

     tasks = [read_webpage_and_extract_content(subquery, link, content) for link, subquery, content in webpages]
     results = await asyncio.gather(*tasks)
@@ -131,14 +132,14 @@ async def read_webpages(
     logger.info(f"Inferring web pages to read")
     if send_status_func:
         async for event in send_status_func(f"**🧐 Inferring web pages to read**"):
-            yield {"status": event}
+            yield {ChatEvent.STATUS: event}
     urls = await infer_webpage_urls(query, conversation_history, location)

     logger.info(f"Reading web pages at: {urls}")
     if send_status_func:
         webpage_links_str = "\n- " + "\n- ".join(list(urls))
         async for event in send_status_func(f"**📖 Reading web pages**: {webpage_links_str}"):
-            yield {"status": event}
+            yield {ChatEvent.STATUS: event}
     tasks = [read_webpage_and_extract_content(query, url) for url in urls]
     results = await asyncio.gather(*tasks)
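
In all of these generators, intermediate status updates and the final result travel through the same async stream; the caller tells them apart by the ChatEvent.STATUS key. A simplified, self-contained sketch of that contract (search and send_status are stand-ins, not Khoj's actual implementations):

    import asyncio
    from enum import Enum

    class ChatEvent(Enum):
        STATUS = "status"

    async def send_status(message: str):
        # Stand-in for the send_status_func the chat router passes in
        yield message

    async def search(query: str, send_status_func=None):
        if send_status_func:
            async for event in send_status_func(f"Searching for: {query}"):
                yield {ChatEvent.STATUS: event}
        yield [f"result for {query}"]  # final payload, yielded without the STATUS key

    async def main():
        async for result in search("khoj", send_status):
            if isinstance(result, dict) and ChatEvent.STATUS in result:
                print("status:", result[ChatEvent.STATUS])
            else:
                print("results:", result)

    asyncio.run(main())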

khoj/routers/api.py

@@ -36,6 +36,7 @@ from khoj.processor.conversation.openai.gpt import extract_questions
 from khoj.processor.conversation.openai.whisper import transcribe_audio
 from khoj.routers.helpers import (
     ApiUserRateLimiter,
+    ChatEvent,
     CommonQueryParams,
     ConversationCommandRateLimiter,
     acreate_title_from_query,
@@ -375,7 +376,7 @@ async def extract_references_and_questions(
     if send_status_func:
         inferred_queries_str = "\n- " + "\n- ".join(inferred_queries)
         async for event in send_status_func(f"**🔍 Searching Documents for:** {inferred_queries_str}"):
-            yield {"status": event}
+            yield {ChatEvent.STATUS: event}
     for query in inferred_queries:
         n_items = min(n, 3) if using_offline_chat else n
         search_results.extend(

khoj/routers/api_chat.py

@@ -30,6 +30,7 @@ from khoj.processor.tools.online_search import read_webpages, search_online
 from khoj.routers.api import extract_references_and_questions
 from khoj.routers.helpers import (
     ApiUserRateLimiter,
+    ChatEvent,
     CommonQueryParams,
     ConversationCommandRateLimiter,
     agenerate_chat_response,
@@ -551,24 +552,24 @@ async def chat(
     event_delimiter = "␃🔚␗"
     q = unquote(q)

-    async def send_event(event_type: str, data: str | dict):
+    async def send_event(event_type: ChatEvent, data: str | dict):
         nonlocal connection_alive, ttft
         if not connection_alive or await request.is_disconnected():
             connection_alive = False
             logger.warn(f"User {user} disconnected from {common.client} client")
             return
         try:
-            if event_type == "end_llm_response":
+            if event_type == ChatEvent.END_LLM_RESPONSE:
                 collect_telemetry()
-            if event_type == "start_llm_response":
+            if event_type == ChatEvent.START_LLM_RESPONSE:
                 ttft = time.perf_counter() - start_time
-            if event_type == "message":
+            if event_type == ChatEvent.MESSAGE:
                 yield data
-            elif event_type == "references" or stream:
-                yield json.dumps({"type": event_type, "data": data}, ensure_ascii=False)
-        except asyncio.CancelledError:
+            elif event_type == ChatEvent.REFERENCES or stream:
+                yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False)
+        except asyncio.CancelledError as e:
             connection_alive = False
-            logger.warn(f"User {user} disconnected from {common.client} client")
+            logger.warn(f"User {user} disconnected from {common.client} client: {e}")
             return
         except Exception as e:
             connection_alive = False
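
Note the companion change to event_type.value inside json.dumps: Enum members are not JSON-serializable, so the wire format keeps the plain string while the Python side handles the typed member. For example:

    import json
    from enum import Enum

    class ChatEvent(Enum):
        REFERENCES = "references"

    try:
        json.dumps({"type": ChatEvent.REFERENCES})  # TypeError: not JSON serializable
    except TypeError:
        pass
    print(json.dumps({"type": ChatEvent.REFERENCES.value}))  # {"type": "references"}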
@@ -579,11 +580,11 @@ async def chat(
             yield event_delimiter

     async def send_llm_response(response: str):
-        async for result in send_event("start_llm_response", ""):
+        async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
             yield result
-        async for result in send_event("message", response):
+        async for result in send_event(ChatEvent.MESSAGE, response):
             yield result
-        async for result in send_event("end_llm_response", ""):
+        async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
             yield result

     def collect_telemetry():
@@ -632,7 +633,7 @@ async def chat(
     user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
     conversation_commands = [get_conversation_command(query=q, any_references=True)]

-    async for result in send_event("status", f"**👀 Understanding Query**: {q}"):
+    async for result in send_event(ChatEvent.STATUS, f"**👀 Understanding Query**: {q}"):
         yield result

     meta_log = conversation.conversation_log
@@ -642,12 +643,12 @@ async def chat(
         conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task)
         conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
         async for result in send_event(
-            "status", f"**🗃️ Chose Data Sources to Search:** {conversation_commands_str}"
+            ChatEvent.STATUS, f"**🗃️ Chose Data Sources to Search:** {conversation_commands_str}"
         ):
             yield result

         mode = await aget_relevant_output_modes(q, meta_log, is_automated_task)
-        async for result in send_event("status", f"**🧑🏾‍💻 Decided Response Mode:** {mode.value}"):
+        async for result in send_event(ChatEvent.STATUS, f"**🧑🏾‍💻 Decided Response Mode:** {mode.value}"):
             yield result
         if mode not in conversation_commands:
             conversation_commands.append(mode)
@@ -690,7 +691,7 @@ async def chat(
             if not q:
                 q = "Create a general summary of the file"
             async for result in send_event(
-                "status", f"**🧑🏾‍💻 Constructing Summary Using:** {file_object[0].file_name}"
+                ChatEvent.STATUS, f"**🧑🏾‍💻 Constructing Summary Using:** {file_object[0].file_name}"
             ):
                 yield result
@@ -771,10 +772,10 @@ async def chat(
             conversation_id,
             conversation_commands,
             location,
-            partial(send_event, "status"),
+            partial(send_event, ChatEvent.STATUS),
         ):
-            if isinstance(result, dict) and "status" in result:
-                yield result["status"]
+            if isinstance(result, dict) and ChatEvent.STATUS in result:
+                yield result[ChatEvent.STATUS]
             else:
                 compiled_references.extend(result[0])
                 inferred_queries.extend(result[1])
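
Enum members are hashable, so {ChatEvent.STATUS: ...} works as a dict key, but the key is the member itself, not its string value; producers and consumers therefore have to switch in lockstep, as this diff does. A quick illustration:

    from enum import Enum

    class ChatEvent(Enum):
        STATUS = "status"

    result = {ChatEvent.STATUS: "Searching..."}
    assert ChatEvent.STATUS in result  # the member is the key
    assert "status" not in result      # its string value is not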
@@ -782,7 +783,7 @@ async def chat(
         if not is_none_or_empty(compiled_references):
             headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references]))
-            async for result in send_event("status", f"**📜 Found Relevant Notes**: {headings}"):
+            async for result in send_event(ChatEvent.STATUS, f"**📜 Found Relevant Notes**: {headings}"):
                 yield result

     online_results: Dict = dict()
@@ -799,10 +800,10 @@ async def chat(
         if ConversationCommand.Online in conversation_commands:
             try:
                 async for result in search_online(
-                    defiltered_query, meta_log, location, partial(send_event, "status"), custom_filters
+                    defiltered_query, meta_log, location, partial(send_event, ChatEvent.STATUS), custom_filters
                 ):
-                    if isinstance(result, dict) and "status" in result:
-                        yield result["status"]
+                    if isinstance(result, dict) and ChatEvent.STATUS in result:
+                        yield result[ChatEvent.STATUS]
                     else:
                         online_results = result
             except ValueError as e:
@@ -815,9 +816,11 @@ async def chat(
         ## Gather Webpage References
         if ConversationCommand.Webpage in conversation_commands:
             try:
-                async for result in read_webpages(defiltered_query, meta_log, location, partial(send_event, "status")):
-                    if isinstance(result, dict) and "status" in result:
-                        yield result["status"]
+                async for result in read_webpages(
+                    defiltered_query, meta_log, location, partial(send_event, ChatEvent.STATUS)
+                ):
+                    if isinstance(result, dict) and ChatEvent.STATUS in result:
+                        yield result[ChatEvent.STATUS]
                     else:
                         direct_web_pages = result
                         webpages = []
@@ -829,7 +832,7 @@ async def chat(
                     for webpage in direct_web_pages[query]["webpages"]:
                         webpages.append(webpage["link"])

-                async for result in send_event("status", f"**📚 Read web pages**: {webpages}"):
+                async for result in send_event(ChatEvent.STATUS, f"**📚 Read web pages**: {webpages}"):
                     yield result
             except ValueError as e:
                 logger.warning(
@@ -839,7 +842,7 @@ async def chat(

         ## Send Gathered References
         async for result in send_event(
-            "references",
+            ChatEvent.REFERENCES,
             {
                 "inferredQueries": inferred_queries,
                 "context": compiled_references,
@@ -858,10 +861,10 @@ async def chat(
                 location_data=location,
                 references=compiled_references,
                 online_results=online_results,
-                send_status_func=partial(send_event, "status"),
+                send_status_func=partial(send_event, ChatEvent.STATUS),
             ):
-                if isinstance(result, dict) and "status" in result:
-                    yield result["status"]
+                if isinstance(result, dict) and ChatEvent.STATUS in result:
+                    yield result[ChatEvent.STATUS]
                 else:
                     image, status_code, improved_image_prompt, intent_type = result
@@ -899,7 +902,7 @@ async def chat(
             return

         ## Generate Text Output
-        async for result in send_event("status", f"**💭 Generating a well-informed response**"):
+        async for result in send_event(ChatEvent.STATUS, f"**💭 Generating a well-informed response**"):
             yield result
         llm_response, chat_metadata = await agenerate_chat_response(
             defiltered_query,
@@ -917,21 +920,21 @@ async def chat(
         )

         # Send Response
-        async for result in send_event("start_llm_response", ""):
+        async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
             yield result

         continue_stream = True
         iterator = AsyncIteratorWrapper(llm_response)
         async for item in iterator:
             if item is None:
-                async for result in send_event("end_llm_response", ""):
+                async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
                     yield result
                 logger.debug("Finished streaming response")
                 return
             if not connection_alive or not continue_stream:
                 continue
             try:
-                async for result in send_event("message", f"{item}"):
+                async for result in send_event(ChatEvent.MESSAGE, f"{item}"):
                     yield result
             except Exception as e:
                 continue_stream = False
@@ -949,7 +952,7 @@ async def chat(
         async for item in iterator:
             try:
                 item_json = json.loads(item)
-                if "type" in item_json and item_json["type"] == "references":
+                if "type" in item_json and item_json["type"] == ChatEvent.REFERENCES.value:
                     response_obj = item_json["data"]
             except:
                 actual_response += item
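
This non-streaming branch rebuilds the final response by treating chunks that parse as JSON as metadata events and everything else as raw message text. A rough client-side sketch of the same demultiplexing, assuming chunks are split on the ␃🔚␗ delimiter defined earlier (demux is a hypothetical helper, not part of Khoj):

    import json

    EVENT_DELIMITER = "␃🔚␗"

    def demux(raw_stream: str):
        """Split a buffered chat stream into response text and references."""
        response, references = "", {}
        for chunk in raw_stream.split(EVENT_DELIMITER):
            try:
                event = json.loads(chunk)
                if isinstance(event, dict) and event.get("type") == "references":
                    references = event["data"]
            except json.JSONDecodeError:
                response += chunk  # raw MESSAGE text is not JSON
        return response, references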

khoj/routers/helpers.py

@@ -8,6 +8,7 @@ import math
 import re
 from concurrent.futures import ThreadPoolExecutor
 from datetime import datetime, timedelta, timezone
+from enum import Enum
 from functools import partial
 from random import random
 from typing import (
@@ -782,7 +783,7 @@ async def text_to_image(
     if send_status_func:
         async for event in send_status_func("**✍🏽 Enhancing the Painting Prompt**"):
-            yield {"status": event}
+            yield {ChatEvent.STATUS: event}

     improved_image_prompt = await generate_better_image_prompt(
         message,
         chat_history,
@@ -794,7 +795,7 @@ async def text_to_image(
     if send_status_func:
         async for event in send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}"):
-            yield {"status": event}
+            yield {ChatEvent.STATUS: event}

     if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
         with timer("Generate image with OpenAI", logger):
@@ -1191,3 +1192,11 @@ def construct_automation_created_message(automation: Job, crontime: str, query_t

 Manage your automations [here](/automations).
 """.strip()
+
+
+class ChatEvent(Enum):
+    START_LLM_RESPONSE = "start_llm_response"
+    END_LLM_RESPONSE = "end_llm_response"
+    MESSAGE = "message"
+    REFERENCES = "references"
+    STATUS = "status"
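
Because each member carries its wire string as its value, the enum round-trips cleanly between Python and the JSON protocol:

    ChatEvent.STATUS.value   # "status": the string sent over the wire
    ChatEvent("references")  # ChatEvent.REFERENCES: parse an incoming type string
    list(ChatEvent)          # all five event types, for exhaustive handling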