mirror of
https://github.com/khoj-ai/khoj.git
synced 2025-02-20 06:55:08 +00:00
Keep the GET chat API endpoint for a bit before deprecating it
This is to avoid breaking non-updated Khoj clients
This commit is contained in:
parent
241b9009ba
commit
3f51af9a96
1 changed files with 476 additions and 0 deletions
|
@ -3,6 +3,7 @@ import base64
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
|
import warnings
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
@ -1002,3 +1003,478 @@ async def chat(
|
||||||
response_iterator = event_generator(q, image=image)
|
response_iterator = event_generator(q, image=image)
|
||||||
response_data = await read_chat_stream(response_iterator)
|
response_data = await read_chat_stream(response_iterator)
|
||||||
return Response(content=json.dumps(response_data), media_type="application/json", status_code=200)
|
return Response(content=json.dumps(response_data), media_type="application/json", status_code=200)
|
||||||
|
|
||||||
|
|
||||||
|
# Deprecated API. Remove by end of September 2024
|
||||||
|
@api_chat.get("")
|
||||||
|
@requires(["authenticated"])
|
||||||
|
async def get_chat(
|
||||||
|
request: Request,
|
||||||
|
common: CommonQueryParams,
|
||||||
|
q: str,
|
||||||
|
n: int = 7,
|
||||||
|
d: float = None,
|
||||||
|
stream: Optional[bool] = False,
|
||||||
|
title: Optional[str] = None,
|
||||||
|
conversation_id: Optional[int] = None,
|
||||||
|
city: Optional[str] = None,
|
||||||
|
region: Optional[str] = None,
|
||||||
|
country: Optional[str] = None,
|
||||||
|
timezone: Optional[str] = None,
|
||||||
|
image: Optional[str] = None,
|
||||||
|
rate_limiter_per_minute=Depends(
|
||||||
|
ApiUserRateLimiter(requests=60, subscribed_requests=60, window=60, slug="chat_minute")
|
||||||
|
),
|
||||||
|
rate_limiter_per_day=Depends(
|
||||||
|
ApiUserRateLimiter(requests=600, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
|
||||||
|
),
|
||||||
|
):
|
||||||
|
# Issue a deprecation warning
|
||||||
|
warnings.warn(
|
||||||
|
"The 'get_chat' API endpoint is deprecated. It will be removed by the end of September 2024.",
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def event_generator(q: str, image: str):
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
ttft = None
|
||||||
|
chat_metadata: dict = {}
|
||||||
|
connection_alive = True
|
||||||
|
user: KhojUser = request.user.object
|
||||||
|
subscribed: bool = has_required_scope(request, ["premium"])
|
||||||
|
event_delimiter = "␃🔚␗"
|
||||||
|
q = unquote(q)
|
||||||
|
nonlocal conversation_id
|
||||||
|
|
||||||
|
uploaded_image_url = None
|
||||||
|
if image:
|
||||||
|
decoded_string = unquote(image)
|
||||||
|
base64_data = decoded_string.split(",", 1)[1]
|
||||||
|
image_bytes = base64.b64decode(base64_data)
|
||||||
|
webp_image_bytes = convert_image_to_webp(image_bytes)
|
||||||
|
try:
|
||||||
|
uploaded_image_url = upload_image_to_bucket(webp_image_bytes, request.user.object.id)
|
||||||
|
except:
|
||||||
|
uploaded_image_url = None
|
||||||
|
|
||||||
|
async def send_event(event_type: ChatEvent, data: str | dict):
|
||||||
|
nonlocal connection_alive, ttft
|
||||||
|
if not connection_alive or await request.is_disconnected():
|
||||||
|
connection_alive = False
|
||||||
|
logger.warn(f"User {user} disconnected from {common.client} client")
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
if event_type == ChatEvent.END_LLM_RESPONSE:
|
||||||
|
collect_telemetry()
|
||||||
|
if event_type == ChatEvent.START_LLM_RESPONSE:
|
||||||
|
ttft = time.perf_counter() - start_time
|
||||||
|
if event_type == ChatEvent.MESSAGE:
|
||||||
|
yield data
|
||||||
|
elif event_type == ChatEvent.REFERENCES or stream:
|
||||||
|
yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False)
|
||||||
|
except asyncio.CancelledError as e:
|
||||||
|
connection_alive = False
|
||||||
|
logger.warn(f"User {user} disconnected from {common.client} client: {e}")
|
||||||
|
return
|
||||||
|
except Exception as e:
|
||||||
|
connection_alive = False
|
||||||
|
logger.error(f"Failed to stream chat API response to {user} on {common.client}: {e}", exc_info=True)
|
||||||
|
return
|
||||||
|
finally:
|
||||||
|
yield event_delimiter
|
||||||
|
|
||||||
|
async def send_llm_response(response: str):
|
||||||
|
async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
|
||||||
|
yield result
|
||||||
|
async for result in send_event(ChatEvent.MESSAGE, response):
|
||||||
|
yield result
|
||||||
|
async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
|
||||||
|
yield result
|
||||||
|
|
||||||
|
def collect_telemetry():
|
||||||
|
# Gather chat response telemetry
|
||||||
|
nonlocal chat_metadata
|
||||||
|
latency = time.perf_counter() - start_time
|
||||||
|
cmd_set = set([cmd.value for cmd in conversation_commands])
|
||||||
|
chat_metadata = chat_metadata or {}
|
||||||
|
chat_metadata["conversation_command"] = cmd_set
|
||||||
|
chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None
|
||||||
|
chat_metadata["latency"] = f"{latency:.3f}"
|
||||||
|
chat_metadata["ttft_latency"] = f"{ttft:.3f}"
|
||||||
|
|
||||||
|
logger.info(f"Chat response time to first token: {ttft:.3f} seconds")
|
||||||
|
logger.info(f"Chat response total time: {latency:.3f} seconds")
|
||||||
|
update_telemetry_state(
|
||||||
|
request=request,
|
||||||
|
telemetry_type="api",
|
||||||
|
api="chat",
|
||||||
|
client=request.user.client_app,
|
||||||
|
user_agent=request.headers.get("user-agent"),
|
||||||
|
host=request.headers.get("host"),
|
||||||
|
metadata=chat_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
conversation_commands = [get_conversation_command(query=q, any_references=True)]
|
||||||
|
|
||||||
|
conversation = await ConversationAdapters.aget_conversation_by_user(
|
||||||
|
user, client_application=request.user.client_app, conversation_id=conversation_id, title=title
|
||||||
|
)
|
||||||
|
if not conversation:
|
||||||
|
async for result in send_llm_response(f"Conversation {conversation_id} not found"):
|
||||||
|
yield result
|
||||||
|
return
|
||||||
|
conversation_id = conversation.id
|
||||||
|
|
||||||
|
await is_ready_to_chat(user)
|
||||||
|
|
||||||
|
user_name = await aget_user_name(user)
|
||||||
|
location = None
|
||||||
|
if city or region or country:
|
||||||
|
location = LocationData(city=city, region=region, country=country)
|
||||||
|
|
||||||
|
if is_query_empty(q):
|
||||||
|
async for result in send_llm_response("Please ask your query to get started."):
|
||||||
|
yield result
|
||||||
|
return
|
||||||
|
|
||||||
|
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
|
||||||
|
meta_log = conversation.conversation_log
|
||||||
|
is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask]
|
||||||
|
|
||||||
|
if conversation_commands == [ConversationCommand.Default] or is_automated_task:
|
||||||
|
conversation_commands = await aget_relevant_information_sources(
|
||||||
|
q, meta_log, is_automated_task, subscribed=subscribed, uploaded_image_url=uploaded_image_url
|
||||||
|
)
|
||||||
|
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
|
||||||
|
async for result in send_event(
|
||||||
|
ChatEvent.STATUS, f"**Chose Data Sources to Search:** {conversation_commands_str}"
|
||||||
|
):
|
||||||
|
yield result
|
||||||
|
|
||||||
|
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, uploaded_image_url)
|
||||||
|
async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
|
||||||
|
yield result
|
||||||
|
if mode not in conversation_commands:
|
||||||
|
conversation_commands.append(mode)
|
||||||
|
|
||||||
|
for cmd in conversation_commands:
|
||||||
|
await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
|
||||||
|
q = q.replace(f"/{cmd.value}", "").strip()
|
||||||
|
|
||||||
|
used_slash_summarize = conversation_commands == [ConversationCommand.Summarize]
|
||||||
|
file_filters = conversation.file_filters if conversation else []
|
||||||
|
# Skip trying to summarize if
|
||||||
|
if (
|
||||||
|
# summarization intent was inferred
|
||||||
|
ConversationCommand.Summarize in conversation_commands
|
||||||
|
# and not triggered via slash command
|
||||||
|
and not used_slash_summarize
|
||||||
|
# but we can't actually summarize
|
||||||
|
and len(file_filters) != 1
|
||||||
|
):
|
||||||
|
conversation_commands.remove(ConversationCommand.Summarize)
|
||||||
|
elif ConversationCommand.Summarize in conversation_commands:
|
||||||
|
response_log = ""
|
||||||
|
if len(file_filters) == 0:
|
||||||
|
response_log = "No files selected for summarization. Please add files using the section on the left."
|
||||||
|
async for result in send_llm_response(response_log):
|
||||||
|
yield result
|
||||||
|
elif len(file_filters) > 1:
|
||||||
|
response_log = "Only one file can be selected for summarization."
|
||||||
|
async for result in send_llm_response(response_log):
|
||||||
|
yield result
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
|
||||||
|
if len(file_object) == 0:
|
||||||
|
response_log = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again."
|
||||||
|
async for result in send_llm_response(response_log):
|
||||||
|
yield result
|
||||||
|
return
|
||||||
|
contextual_data = " ".join([file.raw_text for file in file_object])
|
||||||
|
if not q:
|
||||||
|
q = "Create a general summary of the file"
|
||||||
|
async for result in send_event(
|
||||||
|
ChatEvent.STATUS, f"**Constructing Summary Using:** {file_object[0].file_name}"
|
||||||
|
):
|
||||||
|
yield result
|
||||||
|
|
||||||
|
response = await extract_relevant_summary(
|
||||||
|
q, contextual_data, subscribed=subscribed, uploaded_image_url=uploaded_image_url
|
||||||
|
)
|
||||||
|
response_log = str(response)
|
||||||
|
async for result in send_llm_response(response_log):
|
||||||
|
yield result
|
||||||
|
except Exception as e:
|
||||||
|
response_log = "Error summarizing file."
|
||||||
|
logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)
|
||||||
|
async for result in send_llm_response(response_log):
|
||||||
|
yield result
|
||||||
|
await sync_to_async(save_to_conversation_log)(
|
||||||
|
q,
|
||||||
|
response_log,
|
||||||
|
user,
|
||||||
|
meta_log,
|
||||||
|
user_message_time,
|
||||||
|
intent_type="summarize",
|
||||||
|
client_application=request.user.client_app,
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
uploaded_image_url=uploaded_image_url,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
custom_filters = []
|
||||||
|
if conversation_commands == [ConversationCommand.Help]:
|
||||||
|
if not q:
|
||||||
|
conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
|
||||||
|
if conversation_config == None:
|
||||||
|
conversation_config = await ConversationAdapters.aget_default_conversation_config()
|
||||||
|
model_type = conversation_config.model_type
|
||||||
|
formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
|
||||||
|
async for result in send_llm_response(formatted_help):
|
||||||
|
yield result
|
||||||
|
return
|
||||||
|
# Adding specification to search online specifically on khoj.dev pages.
|
||||||
|
custom_filters.append("site:khoj.dev")
|
||||||
|
conversation_commands.append(ConversationCommand.Online)
|
||||||
|
|
||||||
|
if ConversationCommand.Automation in conversation_commands:
|
||||||
|
try:
|
||||||
|
automation, crontime, query_to_run, subject = await create_automation(
|
||||||
|
q, timezone, user, request.url, meta_log
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error scheduling task {q} for {user.email}: {e}")
|
||||||
|
error_message = f"Unable to create automation. Ensure the automation doesn't already exist."
|
||||||
|
async for result in send_llm_response(error_message):
|
||||||
|
yield result
|
||||||
|
return
|
||||||
|
|
||||||
|
llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject)
|
||||||
|
await sync_to_async(save_to_conversation_log)(
|
||||||
|
q,
|
||||||
|
llm_response,
|
||||||
|
user,
|
||||||
|
meta_log,
|
||||||
|
user_message_time,
|
||||||
|
intent_type="automation",
|
||||||
|
client_application=request.user.client_app,
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
inferred_queries=[query_to_run],
|
||||||
|
automation_id=automation.id,
|
||||||
|
uploaded_image_url=uploaded_image_url,
|
||||||
|
)
|
||||||
|
async for result in send_llm_response(llm_response):
|
||||||
|
yield result
|
||||||
|
return
|
||||||
|
|
||||||
|
# Gather Context
|
||||||
|
## Extract Document References
|
||||||
|
compiled_references, inferred_queries, defiltered_query = [], [], None
|
||||||
|
async for result in extract_references_and_questions(
|
||||||
|
request,
|
||||||
|
meta_log,
|
||||||
|
q,
|
||||||
|
(n or 7),
|
||||||
|
d,
|
||||||
|
conversation_id,
|
||||||
|
conversation_commands,
|
||||||
|
location,
|
||||||
|
partial(send_event, ChatEvent.STATUS),
|
||||||
|
uploaded_image_url=uploaded_image_url,
|
||||||
|
):
|
||||||
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
|
yield result[ChatEvent.STATUS]
|
||||||
|
else:
|
||||||
|
compiled_references.extend(result[0])
|
||||||
|
inferred_queries.extend(result[1])
|
||||||
|
defiltered_query = result[2]
|
||||||
|
|
||||||
|
if not is_none_or_empty(compiled_references):
|
||||||
|
headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references]))
|
||||||
|
# Strip only leading # from headings
|
||||||
|
headings = headings.replace("#", "")
|
||||||
|
async for result in send_event(ChatEvent.STATUS, f"**Found Relevant Notes**: {headings}"):
|
||||||
|
yield result
|
||||||
|
|
||||||
|
online_results: Dict = dict()
|
||||||
|
|
||||||
|
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
|
||||||
|
async for result in send_llm_response(f"{no_entries_found.format()}"):
|
||||||
|
yield result
|
||||||
|
return
|
||||||
|
|
||||||
|
if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
|
||||||
|
conversation_commands.remove(ConversationCommand.Notes)
|
||||||
|
|
||||||
|
## Gather Online References
|
||||||
|
if ConversationCommand.Online in conversation_commands:
|
||||||
|
try:
|
||||||
|
async for result in search_online(
|
||||||
|
defiltered_query,
|
||||||
|
meta_log,
|
||||||
|
location,
|
||||||
|
user,
|
||||||
|
subscribed,
|
||||||
|
partial(send_event, ChatEvent.STATUS),
|
||||||
|
custom_filters,
|
||||||
|
uploaded_image_url=uploaded_image_url,
|
||||||
|
):
|
||||||
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
|
yield result[ChatEvent.STATUS]
|
||||||
|
else:
|
||||||
|
online_results = result
|
||||||
|
except ValueError as e:
|
||||||
|
error_message = f"Error searching online: {e}. Attempting to respond without online results"
|
||||||
|
logger.warning(error_message)
|
||||||
|
async for result in send_llm_response(error_message):
|
||||||
|
yield result
|
||||||
|
return
|
||||||
|
|
||||||
|
## Gather Webpage References
|
||||||
|
if ConversationCommand.Webpage in conversation_commands:
|
||||||
|
try:
|
||||||
|
async for result in read_webpages(
|
||||||
|
defiltered_query,
|
||||||
|
meta_log,
|
||||||
|
location,
|
||||||
|
user,
|
||||||
|
subscribed,
|
||||||
|
partial(send_event, ChatEvent.STATUS),
|
||||||
|
uploaded_image_url=uploaded_image_url,
|
||||||
|
):
|
||||||
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
|
yield result[ChatEvent.STATUS]
|
||||||
|
else:
|
||||||
|
direct_web_pages = result
|
||||||
|
webpages = []
|
||||||
|
for query in direct_web_pages:
|
||||||
|
if online_results.get(query):
|
||||||
|
online_results[query]["webpages"] = direct_web_pages[query]["webpages"]
|
||||||
|
else:
|
||||||
|
online_results[query] = {"webpages": direct_web_pages[query]["webpages"]}
|
||||||
|
|
||||||
|
for webpage in direct_web_pages[query]["webpages"]:
|
||||||
|
webpages.append(webpage["link"])
|
||||||
|
async for result in send_event(ChatEvent.STATUS, f"**Read web pages**: {webpages}"):
|
||||||
|
yield result
|
||||||
|
except ValueError as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Error directly reading webpages: {e}. Attempting to respond without online results",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
## Send Gathered References
|
||||||
|
async for result in send_event(
|
||||||
|
ChatEvent.REFERENCES,
|
||||||
|
{
|
||||||
|
"inferredQueries": inferred_queries,
|
||||||
|
"context": compiled_references,
|
||||||
|
"onlineContext": online_results,
|
||||||
|
},
|
||||||
|
):
|
||||||
|
yield result
|
||||||
|
|
||||||
|
# Generate Output
|
||||||
|
## Generate Image Output
|
||||||
|
if ConversationCommand.Image in conversation_commands:
|
||||||
|
async for result in text_to_image(
|
||||||
|
q,
|
||||||
|
user,
|
||||||
|
meta_log,
|
||||||
|
location_data=location,
|
||||||
|
references=compiled_references,
|
||||||
|
online_results=online_results,
|
||||||
|
subscribed=subscribed,
|
||||||
|
send_status_func=partial(send_event, ChatEvent.STATUS),
|
||||||
|
uploaded_image_url=uploaded_image_url,
|
||||||
|
):
|
||||||
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
|
yield result[ChatEvent.STATUS]
|
||||||
|
else:
|
||||||
|
image, status_code, improved_image_prompt, intent_type = result
|
||||||
|
|
||||||
|
if image is None or status_code != 200:
|
||||||
|
content_obj = {
|
||||||
|
"content-type": "application/json",
|
||||||
|
"intentType": intent_type,
|
||||||
|
"detail": improved_image_prompt,
|
||||||
|
"image": image,
|
||||||
|
}
|
||||||
|
async for result in send_llm_response(json.dumps(content_obj)):
|
||||||
|
yield result
|
||||||
|
return
|
||||||
|
|
||||||
|
await sync_to_async(save_to_conversation_log)(
|
||||||
|
q,
|
||||||
|
image,
|
||||||
|
user,
|
||||||
|
meta_log,
|
||||||
|
user_message_time,
|
||||||
|
intent_type=intent_type,
|
||||||
|
inferred_queries=[improved_image_prompt],
|
||||||
|
client_application=request.user.client_app,
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
compiled_references=compiled_references,
|
||||||
|
online_results=online_results,
|
||||||
|
uploaded_image_url=uploaded_image_url,
|
||||||
|
)
|
||||||
|
content_obj = {
|
||||||
|
"intentType": intent_type,
|
||||||
|
"inferredQueries": [improved_image_prompt],
|
||||||
|
"image": image,
|
||||||
|
}
|
||||||
|
async for result in send_llm_response(json.dumps(content_obj)):
|
||||||
|
yield result
|
||||||
|
return
|
||||||
|
|
||||||
|
## Generate Text Output
|
||||||
|
async for result in send_event(ChatEvent.STATUS, f"**Generating a well-informed response**"):
|
||||||
|
yield result
|
||||||
|
llm_response, chat_metadata = await agenerate_chat_response(
|
||||||
|
defiltered_query,
|
||||||
|
meta_log,
|
||||||
|
conversation,
|
||||||
|
compiled_references,
|
||||||
|
online_results,
|
||||||
|
inferred_queries,
|
||||||
|
conversation_commands,
|
||||||
|
user,
|
||||||
|
request.user.client_app,
|
||||||
|
conversation_id,
|
||||||
|
location,
|
||||||
|
user_name,
|
||||||
|
uploaded_image_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Send Response
|
||||||
|
async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
|
||||||
|
yield result
|
||||||
|
|
||||||
|
continue_stream = True
|
||||||
|
iterator = AsyncIteratorWrapper(llm_response)
|
||||||
|
async for item in iterator:
|
||||||
|
if item is None:
|
||||||
|
async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
|
||||||
|
yield result
|
||||||
|
logger.debug("Finished streaming response")
|
||||||
|
return
|
||||||
|
if not connection_alive or not continue_stream:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
async for result in send_event(ChatEvent.MESSAGE, f"{item}"):
|
||||||
|
yield result
|
||||||
|
except Exception as e:
|
||||||
|
continue_stream = False
|
||||||
|
logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}")
|
||||||
|
|
||||||
|
## Stream Text Response
|
||||||
|
if stream:
|
||||||
|
return StreamingResponse(event_generator(q, image=image), media_type="text/plain")
|
||||||
|
## Non-Streaming Text Response
|
||||||
|
else:
|
||||||
|
response_iterator = event_generator(q, image=image)
|
||||||
|
response_data = await read_chat_stream(response_iterator)
|
||||||
|
return Response(content=json.dumps(response_data), media_type="application/json", status_code=200)
|
||||||
|
|
Loading…
Add table
Reference in a new issue