Repository: https://github.com/khoj-ai/khoj.git

Use enum to track chat stream event types in chat api router

parent ebe92ef16d
commit 778c571288

4 changed files with 56 additions and 42 deletions
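In short: everywhere the chat API router previously tagged stream events with bare strings like "status" or "message", it now passes members of a ChatEvent enum, so a mistyped event name fails loudly at the call site instead of silently producing an unrecognized event. A minimal, self-contained sketch of the before/after pattern (the ChatEvent definition is copied from the last hunk of this diff; the two helper functions are illustrative, not from the commit):

    from enum import Enum

    class ChatEvent(Enum):
        START_LLM_RESPONSE = "start_llm_response"
        END_LLM_RESPONSE = "end_llm_response"
        MESSAGE = "message"
        REFERENCES = "references"
        STATUS = "status"

    # Before: a typo like "statu" is just another string, so the
    # comparison silently evaluates to False.
    def is_status_old(event_type: str) -> bool:
        return event_type == "status"

    # After: a typo like ChatEvent.STATU raises AttributeError
    # immediately, and type checkers can verify call sites.
    def is_status_new(event_type: ChatEvent) -> bool:
        return event_type == ChatEvent.STATUS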

khoj/processor/tools/online_search.py

@@ -11,6 +11,7 @@ from bs4 import BeautifulSoup
 from markdownify import markdownify
 
 from khoj.routers.helpers import (
+    ChatEvent,
     extract_relevant_info,
     generate_online_subqueries,
     infer_webpage_urls,

@@ -68,7 +69,7 @@ async def search_online(
     if send_status_func:
         subqueries_str = "\n- " + "\n- ".join(list(subqueries))
         async for event in send_status_func(f"**🌐 Searching the Internet for**: {subqueries_str}"):
-            yield {"status": event}
+            yield {ChatEvent.STATUS: event}
 
     with timer(f"Internet searches for {list(subqueries)} took", logger):
         search_func = search_with_google if SERPER_DEV_API_KEY else search_with_jina
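Note the dict key change in the yield above: enum members are hashable, so ChatEvent.STATUS can key the status payload directly, and consumers test membership against the enum member rather than a string. Reusing the ChatEvent sketch above:

    event = {ChatEvent.STATUS: "Searching the Internet"}
    assert ChatEvent.STATUS in event
    assert event[ChatEvent.STATUS] == "Searching the Internet"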

@@ -92,7 +93,7 @@ async def search_online(
     if send_status_func:
         webpage_links_str = "\n- " + "\n- ".join(list(webpage_links))
         async for event in send_status_func(f"**📖 Reading web pages**: {webpage_links_str}"):
-            yield {"status": event}
+            yield {ChatEvent.STATUS: event}
     tasks = [read_webpage_and_extract_content(subquery, link, content) for link, subquery, content in webpages]
     results = await asyncio.gather(*tasks)
 

@@ -131,14 +132,14 @@ async def read_webpages(
     logger.info(f"Inferring web pages to read")
     if send_status_func:
         async for event in send_status_func(f"**🧐 Inferring web pages to read**"):
-            yield {"status": event}
+            yield {ChatEvent.STATUS: event}
     urls = await infer_webpage_urls(query, conversation_history, location)
 
     logger.info(f"Reading web pages at: {urls}")
     if send_status_func:
         webpage_links_str = "\n- " + "\n- ".join(list(urls))
         async for event in send_status_func(f"**📖 Reading web pages**: {webpage_links_str}"):
-            yield {"status": event}
+            yield {ChatEvent.STATUS: event}
     tasks = [read_webpage_and_extract_content(query, url) for url in urls]
     results = await asyncio.gather(*tasks)
 

khoj/routers/api.py

@@ -36,6 +36,7 @@ from khoj.processor.conversation.openai.gpt import extract_questions
 from khoj.processor.conversation.openai.whisper import transcribe_audio
 from khoj.routers.helpers import (
     ApiUserRateLimiter,
+    ChatEvent,
     CommonQueryParams,
     ConversationCommandRateLimiter,
     acreate_title_from_query,

@@ -375,7 +376,7 @@ async def extract_references_and_questions(
         if send_status_func:
             inferred_queries_str = "\n- " + "\n- ".join(inferred_queries)
             async for event in send_status_func(f"**🔍 Searching Documents for:** {inferred_queries_str}"):
-                yield {"status": event}
+                yield {ChatEvent.STATUS: event}
         for query in inferred_queries:
             n_items = min(n, 3) if using_offline_chat else n
             search_results.extend(

khoj/routers/api_chat.py

@@ -30,6 +30,7 @@ from khoj.processor.tools.online_search import read_webpages, search_online
 from khoj.routers.api import extract_references_and_questions
 from khoj.routers.helpers import (
     ApiUserRateLimiter,
+    ChatEvent,
     CommonQueryParams,
     ConversationCommandRateLimiter,
     agenerate_chat_response,

@@ -551,24 +552,24 @@ async def chat(
     event_delimiter = "␃🔚␗"
     q = unquote(q)
 
-    async def send_event(event_type: str, data: str | dict):
+    async def send_event(event_type: ChatEvent, data: str | dict):
         nonlocal connection_alive, ttft
         if not connection_alive or await request.is_disconnected():
             connection_alive = False
             logger.warn(f"User {user} disconnected from {common.client} client")
             return
         try:
-            if event_type == "end_llm_response":
+            if event_type == ChatEvent.END_LLM_RESPONSE:
                 collect_telemetry()
-            if event_type == "start_llm_response":
+            if event_type == ChatEvent.START_LLM_RESPONSE:
                 ttft = time.perf_counter() - start_time
-            if event_type == "message":
+            if event_type == ChatEvent.MESSAGE:
                 yield data
-            elif event_type == "references" or stream:
-                yield json.dumps({"type": event_type, "data": data}, ensure_ascii=False)
-        except asyncio.CancelledError:
+            elif event_type == ChatEvent.REFERENCES or stream:
+                yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False)
+            except asyncio.CancelledError as e:
             connection_alive = False
-            logger.warn(f"User {user} disconnected from {common.client} client")
+            logger.warn(f"User {user} disconnected from {common.client} client: {e}")
             return
         except Exception as e:
             connection_alive = False
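The .value unwrapping added to the json.dumps call above is load-bearing: json.dumps raises TypeError on a raw Enum member, and serializing .value keeps the wire format identical to the old string-based events. A quick sketch, reusing the ChatEvent definition from the first sketch:

    import json

    try:
        json.dumps({"type": ChatEvent.REFERENCES, "data": []})
    except TypeError:
        pass  # raw Enum members are not JSON serializable

    # Serializing .value emits the same payload the old string events did.
    assert json.dumps({"type": ChatEvent.REFERENCES.value}) == '{"type": "references"}'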

@@ -579,11 +580,11 @@ async def chat(
         yield event_delimiter
 
     async def send_llm_response(response: str):
-        async for result in send_event("start_llm_response", ""):
+        async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
             yield result
-        async for result in send_event("message", response):
+        async for result in send_event(ChatEvent.MESSAGE, response):
             yield result
-        async for result in send_event("end_llm_response", ""):
+        async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
             yield result
 
     def collect_telemetry():

@@ -632,7 +633,7 @@ async def chat(
         user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
         conversation_commands = [get_conversation_command(query=q, any_references=True)]
 
-        async for result in send_event("status", f"**👀 Understanding Query**: {q}"):
+        async for result in send_event(ChatEvent.STATUS, f"**👀 Understanding Query**: {q}"):
             yield result
 
         meta_log = conversation.conversation_log

@@ -642,12 +643,12 @@ async def chat(
             conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task)
             conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
             async for result in send_event(
-                "status", f"**🗃️ Chose Data Sources to Search:** {conversation_commands_str}"
+                ChatEvent.STATUS, f"**🗃️ Chose Data Sources to Search:** {conversation_commands_str}"
             ):
                 yield result
 
             mode = await aget_relevant_output_modes(q, meta_log, is_automated_task)
-            async for result in send_event("status", f"**🧑🏾‍💻 Decided Response Mode:** {mode.value}"):
+            async for result in send_event(ChatEvent.STATUS, f"**🧑🏾‍💻 Decided Response Mode:** {mode.value}"):
                 yield result
             if mode not in conversation_commands:
                 conversation_commands.append(mode)

@@ -690,7 +691,7 @@ async def chat(
             if not q:
                 q = "Create a general summary of the file"
             async for result in send_event(
-                "status", f"**🧑🏾‍💻 Constructing Summary Using:** {file_object[0].file_name}"
+                ChatEvent.STATUS, f"**🧑🏾‍💻 Constructing Summary Using:** {file_object[0].file_name}"
             ):
                 yield result
 

@@ -771,10 +772,10 @@ async def chat(
             conversation_id,
             conversation_commands,
             location,
-            partial(send_event, "status"),
+            partial(send_event, ChatEvent.STATUS),
         ):
-            if isinstance(result, dict) and "status" in result:
-                yield result["status"]
+            if isinstance(result, dict) and ChatEvent.STATUS in result:
+                yield result[ChatEvent.STATUS]
             else:
                 compiled_references.extend(result[0])
                 inferred_queries.extend(result[1])
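A note on the partial(send_event, ChatEvent.STATUS) plumbing above: functools.partial pre-binds the event type, so downstream helpers like extract_references_and_questions receive a one-argument status callback and never handle event types themselves. A simplified sketch of that callback's shape, reusing ChatEvent from the first sketch (the real send_event also guards connection state and serializes for streaming):

    from functools import partial

    async def send_event(event_type: ChatEvent, data: str):
        yield {event_type: data}

    send_status_func = partial(send_event, ChatEvent.STATUS)
    # Helpers then do: async for event in send_status_func("Searching..."): ...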

@@ -782,7 +783,7 @@ async def chat(
 
         if not is_none_or_empty(compiled_references):
             headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references]))
-            async for result in send_event("status", f"**📜 Found Relevant Notes**: {headings}"):
+            async for result in send_event(ChatEvent.STATUS, f"**📜 Found Relevant Notes**: {headings}"):
                 yield result
 
         online_results: Dict = dict()

@@ -799,10 +800,10 @@ async def chat(
         if ConversationCommand.Online in conversation_commands:
             try:
                 async for result in search_online(
-                    defiltered_query, meta_log, location, partial(send_event, "status"), custom_filters
+                    defiltered_query, meta_log, location, partial(send_event, ChatEvent.STATUS), custom_filters
                 ):
-                    if isinstance(result, dict) and "status" in result:
-                        yield result["status"]
+                    if isinstance(result, dict) and ChatEvent.STATUS in result:
+                        yield result[ChatEvent.STATUS]
                     else:
                         online_results = result
             except ValueError as e:

@@ -815,9 +816,11 @@ async def chat(
         ## Gather Webpage References
         if ConversationCommand.Webpage in conversation_commands:
             try:
-                async for result in read_webpages(defiltered_query, meta_log, location, partial(send_event, "status")):
-                    if isinstance(result, dict) and "status" in result:
-                        yield result["status"]
+                async for result in read_webpages(
+                    defiltered_query, meta_log, location, partial(send_event, ChatEvent.STATUS)
+                ):
+                    if isinstance(result, dict) and ChatEvent.STATUS in result:
+                        yield result[ChatEvent.STATUS]
                     else:
                         direct_web_pages = result
                         webpages = []

@@ -829,7 +832,7 @@ async def chat(
 
                             for webpage in direct_web_pages[query]["webpages"]:
                                 webpages.append(webpage["link"])
-                    async for result in send_event("status", f"**📚 Read web pages**: {webpages}"):
+                    async for result in send_event(ChatEvent.STATUS, f"**📚 Read web pages**: {webpages}"):
                         yield result
             except ValueError as e:
                 logger.warning(

@@ -839,7 +842,7 @@ async def chat(
 
         ## Send Gathered References
         async for result in send_event(
-            "references",
+            ChatEvent.REFERENCES,
             {
                 "inferredQueries": inferred_queries,
                 "context": compiled_references,

@@ -858,10 +861,10 @@ async def chat(
                 location_data=location,
                 references=compiled_references,
                 online_results=online_results,
-                send_status_func=partial(send_event, "status"),
+                send_status_func=partial(send_event, ChatEvent.STATUS),
             ):
-                if isinstance(result, dict) and "status" in result:
-                    yield result["status"]
+                if isinstance(result, dict) and ChatEvent.STATUS in result:
+                    yield result[ChatEvent.STATUS]
                 else:
                     image, status_code, improved_image_prompt, intent_type = result
 

@@ -899,7 +902,7 @@ async def chat(
             return
 
         ## Generate Text Output
-        async for result in send_event("status", f"**💭 Generating a well-informed response**"):
+        async for result in send_event(ChatEvent.STATUS, f"**💭 Generating a well-informed response**"):
             yield result
         llm_response, chat_metadata = await agenerate_chat_response(
             defiltered_query,

@@ -917,21 +920,21 @@ async def chat(
         )
 
         # Send Response
-        async for result in send_event("start_llm_response", ""):
+        async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
             yield result
 
         continue_stream = True
         iterator = AsyncIteratorWrapper(llm_response)
         async for item in iterator:
             if item is None:
-                async for result in send_event("end_llm_response", ""):
+                async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
                     yield result
                 logger.debug("Finished streaming response")
                 return
             if not connection_alive or not continue_stream:
                 continue
             try:
-                async for result in send_event("message", f"{item}"):
+                async for result in send_event(ChatEvent.MESSAGE, f"{item}"):
                     yield result
             except Exception as e:
                 continue_stream = False

@@ -949,7 +952,7 @@ async def chat(
         async for item in iterator:
             try:
                 item_json = json.loads(item)
-                if "type" in item_json and item_json["type"] == "references":
+                if "type" in item_json and item_json["type"] == ChatEvent.REFERENCES.value:
                     response_obj = item_json["data"]
             except:
                 actual_response += item

khoj/routers/helpers.py

@@ -8,6 +8,7 @@ import math
 import re
 from concurrent.futures import ThreadPoolExecutor
 from datetime import datetime, timedelta, timezone
+from enum import Enum
 from functools import partial
 from random import random
 from typing import (

@@ -782,7 +783,7 @@ async def text_to_image(
 
     if send_status_func:
         async for event in send_status_func("**✍🏽 Enhancing the Painting Prompt**"):
-            yield {"status": event}
+            yield {ChatEvent.STATUS: event}
     improved_image_prompt = await generate_better_image_prompt(
         message,
         chat_history,

@@ -794,7 +795,7 @@ async def text_to_image(
 
     if send_status_func:
         async for event in send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}"):
-            yield {"status": event}
+            yield {ChatEvent.STATUS: event}
 
     if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
         with timer("Generate image with OpenAI", logger):

@@ -1191,3 +1192,11 @@ def construct_automation_created_message(automation: Job, crontime: str, query_t
 
 Manage your automations [here](/automations).
 """.strip()
+
+
+class ChatEvent(Enum):
+    START_LLM_RESPONSE = "start_llm_response"
+    END_LLM_RESPONSE = "end_llm_response"
+    MESSAGE = "message"
+    REFERENCES = "references"
+    STATUS = "status"
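Tying it together, a standalone sketch of how the enum round-trips through the stream (this mirrors, in simplified form, the send_event hunk in the chat router above):

    import json
    from enum import Enum

    class ChatEvent(Enum):
        START_LLM_RESPONSE = "start_llm_response"
        END_LLM_RESPONSE = "end_llm_response"
        MESSAGE = "message"
        REFERENCES = "references"
        STATUS = "status"

    def encode(event_type: ChatEvent, data) -> str:
        # .value unwraps the enum so the emitted JSON matches the
        # old string-based wire format exactly.
        return json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False)

    print(encode(ChatEvent.STATUS, "**👀 Understanding Query**: hello"))
    # {"type": "status", "data": "**👀 Understanding Query**: hello"}

One design note: had the enum been declared as class ChatEvent(str, Enum), its members would be JSON-serializable directly; with a plain Enum, every JSON boundary needs the explicit .value, as the router hunks above do.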