Keep the GET chat API endpoint for a bit before deprecating it

This is to avoid breaking non-updated Khoj clients
Debanjum Singh Solanky 2024-09-11 16:14:19 -07:00
parent 241b9009ba
commit 3f51af9a96
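
For context, below is a minimal sketch of how a not-yet-updated client might still call the deprecated GET endpoint. It assumes the api_chat router is mounted at /api/chat on a Khoj server at its default local address, and the API key is a placeholder; both are deployment-specific:

import requests

# Hypothetical example of a legacy client hitting the deprecated GET chat endpoint.
# Newer clients should migrate off this route before it is removed.
response = requests.get(
    "http://localhost:42110/api/chat",
    params={"q": "Summarize my notes from last week", "stream": "false"},
    headers={"Authorization": "Bearer <api-key>"},  # placeholder credential
)
print(response.json())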


@@ -3,6 +3,7 @@ import base64
import json
import logging
import time
import warnings
from datetime import datetime
from functools import partial
from typing import Dict, Optional
@@ -1002,3 +1003,478 @@ async def chat(
response_iterator = event_generator(q, image=image)
response_data = await read_chat_stream(response_iterator)
return Response(content=json.dumps(response_data), media_type="application/json", status_code=200)
# Deprecated API. Remove by end of September 2024
@api_chat.get("")
@requires(["authenticated"])
async def get_chat(
request: Request,
common: CommonQueryParams,
q: str,
n: int = 7,
d: Optional[float] = None,
stream: Optional[bool] = False,
title: Optional[str] = None,
conversation_id: Optional[int] = None,
city: Optional[str] = None,
region: Optional[str] = None,
country: Optional[str] = None,
timezone: Optional[str] = None,
image: Optional[str] = None,
rate_limiter_per_minute=Depends(
ApiUserRateLimiter(requests=60, subscribed_requests=60, window=60, slug="chat_minute")
),
rate_limiter_per_day=Depends(
ApiUserRateLimiter(requests=600, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
),
):
# Issue a deprecation warning
warnings.warn(
"The 'get_chat' API endpoint is deprecated. It will be removed by the end of September 2024.",
DeprecationWarning,
stacklevel=2,
)
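# Stream chat events to the client as they are produced, each terminated by the event delimiter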
async def event_generator(q: str, image: Optional[str]):
start_time = time.perf_counter()
ttft = None
chat_metadata: dict = {}
connection_alive = True
user: KhojUser = request.user.object
subscribed: bool = has_required_scope(request, ["premium"])
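# Delimiter appended after every streamed event so clients can split the event stream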
event_delimiter = "␃🔚␗"
q = unquote(q)
nonlocal conversation_id
uploaded_image_url = None
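# Decode the base64 data URI image, convert it to webp and upload it for the conversation log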
if image:
decoded_string = unquote(image)
base64_data = decoded_string.split(",", 1)[1]
image_bytes = base64.b64decode(base64_data)
webp_image_bytes = convert_image_to_webp(image_bytes)
try:
uploaded_image_url = upload_image_to_bucket(webp_image_bytes, request.user.object.id)
except Exception:
# Proceed without the image if the upload fails
uploaded_image_url = None
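# Send a typed event to the client; marks the connection dead on disconnect and always appends the delimiter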
async def send_event(event_type: ChatEvent, data: str | dict):
nonlocal connection_alive, ttft
if not connection_alive or await request.is_disconnected():
connection_alive = False
logger.warn(f"User {user} disconnected from {common.client} client")
return
try:
if event_type == ChatEvent.END_LLM_RESPONSE:
collect_telemetry()
if event_type == ChatEvent.START_LLM_RESPONSE:
ttft = time.perf_counter() - start_time
if event_type == ChatEvent.MESSAGE:
yield data
elif event_type == ChatEvent.REFERENCES or stream:
yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False)
except asyncio.CancelledError as e:
connection_alive = False
logger.warn(f"User {user} disconnected from {common.client} client: {e}")
return
except Exception as e:
connection_alive = False
logger.error(f"Failed to stream chat API response to {user} on {common.client}: {e}", exc_info=True)
return
finally:
yield event_delimiter
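# Emit a complete response wrapped in start/end LLM response events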
async def send_llm_response(response: str):
async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
yield result
async for result in send_event(ChatEvent.MESSAGE, response):
yield result
async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
yield result
def collect_telemetry():
# Gather chat response telemetry
nonlocal chat_metadata
latency = time.perf_counter() - start_time
cmd_set = {cmd.value for cmd in conversation_commands}
chat_metadata = chat_metadata or {}
chat_metadata["conversation_command"] = cmd_set
chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None
chat_metadata["latency"] = f"{latency:.3f}"
chat_metadata["ttft_latency"] = f"{ttft:.3f}"
logger.info(f"Chat response time to first token: {ttft:.3f} seconds")
logger.info(f"Chat response total time: {latency:.3f} seconds")
update_telemetry_state(
request=request,
telemetry_type="api",
api="chat",
client=request.user.client_app,
user_agent=request.headers.get("user-agent"),
host=request.headers.get("host"),
metadata=chat_metadata,
)
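# Infer the conversation command from the query and load the target conversation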
conversation_commands = [get_conversation_command(query=q, any_references=True)]
conversation = await ConversationAdapters.aget_conversation_by_user(
user, client_application=request.user.client_app, conversation_id=conversation_id, title=title
)
if not conversation:
async for result in send_llm_response(f"Conversation {conversation_id} not found"):
yield result
return
conversation_id = conversation.id
await is_ready_to_chat(user)
user_name = await aget_user_name(user)
location = None
if city or region or country:
location = LocationData(city=city, region=region, country=country)
if is_query_empty(q):
async for result in send_llm_response("Please ask your query to get started."):
yield result
return
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
meta_log = conversation.conversation_log
is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask]
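# When no explicit command was given, infer which data sources to search and which output mode to use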
if conversation_commands == [ConversationCommand.Default] or is_automated_task:
conversation_commands = await aget_relevant_information_sources(
q, meta_log, is_automated_task, subscribed=subscribed, uploaded_image_url=uploaded_image_url
)
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
async for result in send_event(
ChatEvent.STATUS, f"**Chose Data Sources to Search:** {conversation_commands_str}"
):
yield result
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, uploaded_image_url)
async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
yield result
if mode not in conversation_commands:
conversation_commands.append(mode)
for cmd in conversation_commands:
await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
q = q.replace(f"/{cmd.value}", "").strip()
used_slash_summarize = conversation_commands == [ConversationCommand.Summarize]
file_filters = conversation.file_filters if conversation else []
# Skip trying to summarize if
if (
# summarization intent was inferred
ConversationCommand.Summarize in conversation_commands
# and not triggered via slash command
and not used_slash_summarize
# but we can't actually summarize
and len(file_filters) != 1
):
conversation_commands.remove(ConversationCommand.Summarize)
elif ConversationCommand.Summarize in conversation_commands:
response_log = ""
if len(file_filters) == 0:
response_log = "No files selected for summarization. Please add files using the section on the left."
async for result in send_llm_response(response_log):
yield result
elif len(file_filters) > 1:
response_log = "Only one file can be selected for summarization."
async for result in send_llm_response(response_log):
yield result
else:
try:
file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
if len(file_object) == 0:
response_log = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again."
async for result in send_llm_response(response_log):
yield result
return
contextual_data = " ".join([file.raw_text for file in file_object])
if not q:
q = "Create a general summary of the file"
async for result in send_event(
ChatEvent.STATUS, f"**Constructing Summary Using:** {file_object[0].file_name}"
):
yield result
response = await extract_relevant_summary(
q, contextual_data, subscribed=subscribed, uploaded_image_url=uploaded_image_url
)
response_log = str(response)
async for result in send_llm_response(response_log):
yield result
except Exception as e:
response_log = "Error summarizing file."
logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)
async for result in send_llm_response(response_log):
yield result
await sync_to_async(save_to_conversation_log)(
q,
response_log,
user,
meta_log,
user_message_time,
intent_type="summarize",
client_application=request.user.client_app,
conversation_id=conversation_id,
uploaded_image_url=uploaded_image_url,
)
return
custom_filters = []
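# Handle /help: show help text if no query was given, else search the khoj.dev docs online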
if conversation_commands == [ConversationCommand.Help]:
if not q:
conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
if conversation_config is None:
conversation_config = await ConversationAdapters.aget_default_conversation_config()
model_type = conversation_config.model_type
formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
async for result in send_llm_response(formatted_help):
yield result
return
# Adding specification to search online specifically on khoj.dev pages.
custom_filters.append("site:khoj.dev")
conversation_commands.append(ConversationCommand.Online)
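# Handle /automation: create a scheduled automation from the query and confirm it to the user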
if ConversationCommand.Automation in conversation_commands:
try:
automation, crontime, query_to_run, subject = await create_automation(
q, timezone, user, request.url, meta_log
)
except Exception as e:
logger.error(f"Error scheduling task {q} for {user.email}: {e}")
error_message = f"Unable to create automation. Ensure the automation doesn't already exist."
async for result in send_llm_response(error_message):
yield result
return
llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject)
await sync_to_async(save_to_conversation_log)(
q,
llm_response,
user,
meta_log,
user_message_time,
intent_type="automation",
client_application=request.user.client_app,
conversation_id=conversation_id,
inferred_queries=[query_to_run],
automation_id=automation.id,
uploaded_image_url=uploaded_image_url,
)
async for result in send_llm_response(llm_response):
yield result
return
# Gather Context
## Extract Document References
compiled_references, inferred_queries, defiltered_query = [], [], None
async for result in extract_references_and_questions(
request,
meta_log,
q,
(n or 7),
d,
conversation_id,
conversation_commands,
location,
partial(send_event, ChatEvent.STATUS),
uploaded_image_url=uploaded_image_url,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
compiled_references.extend(result[0])
inferred_queries.extend(result[1])
defiltered_query = result[2]
if not is_none_or_empty(compiled_references):
headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references]))
# Strip '#' characters so the headings render as plain text
headings = headings.replace("#", "")
async for result in send_event(ChatEvent.STATUS, f"**Found Relevant Notes**: {headings}"):
yield result
online_results: Dict = dict()
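# If only notes were requested but the user has no indexed entries, respond with a hint and stop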
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
async for result in send_llm_response(f"{no_entries_found.format()}"):
yield result
return
if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
conversation_commands.remove(ConversationCommand.Notes)
## Gather Online References
if ConversationCommand.Online in conversation_commands:
try:
async for result in search_online(
defiltered_query,
meta_log,
location,
user,
subscribed,
partial(send_event, ChatEvent.STATUS),
custom_filters,
uploaded_image_url=uploaded_image_url,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
online_results = result
except ValueError as e:
error_message = f"Error searching online: {e}. Attempting to respond without online results"
logger.warning(error_message)
async for result in send_llm_response(error_message):
yield result
return
## Gather Webpage References
if ConversationCommand.Webpage in conversation_commands:
try:
async for result in read_webpages(
defiltered_query,
meta_log,
location,
user,
subscribed,
partial(send_event, ChatEvent.STATUS),
uploaded_image_url=uploaded_image_url,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
direct_web_pages = result
webpages = []
for query in direct_web_pages:
if online_results.get(query):
online_results[query]["webpages"] = direct_web_pages[query]["webpages"]
else:
online_results[query] = {"webpages": direct_web_pages[query]["webpages"]}
for webpage in direct_web_pages[query]["webpages"]:
webpages.append(webpage["link"])
async for result in send_event(ChatEvent.STATUS, f"**Read web pages**: {webpages}"):
yield result
except ValueError as e:
logger.warning(
f"Error directly reading webpages: {e}. Attempting to respond without online results",
exc_info=True,
)
## Send Gathered References
async for result in send_event(
ChatEvent.REFERENCES,
{
"inferredQueries": inferred_queries,
"context": compiled_references,
"onlineContext": online_results,
},
):
yield result
# Generate Output
## Generate Image Output
if ConversationCommand.Image in conversation_commands:
async for result in text_to_image(
q,
user,
meta_log,
location_data=location,
references=compiled_references,
online_results=online_results,
subscribed=subscribed,
send_status_func=partial(send_event, ChatEvent.STATUS),
uploaded_image_url=uploaded_image_url,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
image, status_code, improved_image_prompt, intent_type = result
if image is None or status_code != 200:
content_obj = {
"content-type": "application/json",
"intentType": intent_type,
"detail": improved_image_prompt,
"image": image,
}
async for result in send_llm_response(json.dumps(content_obj)):
yield result
return
await sync_to_async(save_to_conversation_log)(
q,
image,
user,
meta_log,
user_message_time,
intent_type=intent_type,
inferred_queries=[improved_image_prompt],
client_application=request.user.client_app,
conversation_id=conversation_id,
compiled_references=compiled_references,
online_results=online_results,
uploaded_image_url=uploaded_image_url,
)
content_obj = {
"intentType": intent_type,
"inferredQueries": [improved_image_prompt],
"image": image,
}
async for result in send_llm_response(json.dumps(content_obj)):
yield result
return
## Generate Text Output
async for result in send_event(ChatEvent.STATUS, f"**Generating a well-informed response**"):
yield result
llm_response, chat_metadata = await agenerate_chat_response(
defiltered_query,
meta_log,
conversation,
compiled_references,
online_results,
inferred_queries,
conversation_commands,
user,
request.user.client_app,
conversation_id,
location,
user_name,
uploaded_image_url,
)
# Send Response
async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
yield result
continue_stream = True
iterator = AsyncIteratorWrapper(llm_response)
async for item in iterator:
if item is None:
async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
yield result
logger.debug("Finished streaming response")
return
if not connection_alive or not continue_stream:
continue
try:
async for result in send_event(ChatEvent.MESSAGE, f"{item}"):
yield result
except Exception as e:
continue_stream = False
logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}")
## Stream Text Response
if stream:
return StreamingResponse(event_generator(q, image=image), media_type="text/plain")
## Non-Streaming Text Response
else:
response_iterator = event_generator(q, image=image)
response_data = await read_chat_stream(response_iterator)
return Response(content=json.dumps(response_data), media_type="application/json", status_code=200)