diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py
index 0915b180..73a8816c 100644
--- a/src/khoj/routers/api_chat.py
+++ b/src/khoj/routers/api_chat.py
@@ -3,6 +3,7 @@ import base64
 import json
 import logging
 import time
+import warnings
 from datetime import datetime
 from functools import partial
 from typing import Dict, Optional
@@ -1002,3 +1003,478 @@ async def chat(
         response_iterator = event_generator(q, image=image)
         response_data = await read_chat_stream(response_iterator)
         return Response(content=json.dumps(response_data), media_type="application/json", status_code=200)
+
+
+# Deprecated API. Remove by end of September 2024
+@api_chat.get("")
+@requires(["authenticated"])
+async def get_chat(
+    request: Request,
+    common: CommonQueryParams,
+    q: str,
+    n: int = 7,
+    d: float = None,
+    stream: Optional[bool] = False,
+    title: Optional[str] = None,
+    conversation_id: Optional[int] = None,
+    city: Optional[str] = None,
+    region: Optional[str] = None,
+    country: Optional[str] = None,
+    timezone: Optional[str] = None,
+    image: Optional[str] = None,
+    rate_limiter_per_minute=Depends(
+        ApiUserRateLimiter(requests=60, subscribed_requests=60, window=60, slug="chat_minute")
+    ),
+    rate_limiter_per_day=Depends(
+        ApiUserRateLimiter(requests=600, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
+    ),
+):
+    # Issue a deprecation warning
+    warnings.warn(
+        "The 'get_chat' API endpoint is deprecated. It will be removed by the end of September 2024.",
+        DeprecationWarning,
+        stacklevel=2,
+    )
+
+    async def event_generator(q: str, image: str):
+        start_time = time.perf_counter()
+        ttft = None
+        chat_metadata: dict = {}
+        connection_alive = True
+        user: KhojUser = request.user.object
+        subscribed: bool = has_required_scope(request, ["premium"])
+        event_delimiter = "␃🔚␗"
+        q = unquote(q)
+        nonlocal conversation_id
+
+        uploaded_image_url = None
+        if image:
+            decoded_string = unquote(image)
+            base64_data = decoded_string.split(",", 1)[1]
+            image_bytes = base64.b64decode(base64_data)
+            webp_image_bytes = convert_image_to_webp(image_bytes)
+            try:
+                uploaded_image_url = upload_image_to_bucket(webp_image_bytes, request.user.object.id)
+            except:
+                uploaded_image_url = None
+
+        async def send_event(event_type: ChatEvent, data: str | dict):
+            nonlocal connection_alive, ttft
+            if not connection_alive or await request.is_disconnected():
+                connection_alive = False
+                logger.warn(f"User {user} disconnected from {common.client} client")
+                return
+            try:
+                if event_type == ChatEvent.END_LLM_RESPONSE:
+                    collect_telemetry()
+                if event_type == ChatEvent.START_LLM_RESPONSE:
+                    ttft = time.perf_counter() - start_time
+                if event_type == ChatEvent.MESSAGE:
+                    yield data
+                elif event_type == ChatEvent.REFERENCES or stream:
+                    yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False)
+            except asyncio.CancelledError as e:
+                connection_alive = False
+                logger.warn(f"User {user} disconnected from {common.client} client: {e}")
+                return
+            except Exception as e:
+                connection_alive = False
+                logger.error(f"Failed to stream chat API response to {user} on {common.client}: {e}", exc_info=True)
+                return
+            finally:
+                yield event_delimiter
+
+        async def send_llm_response(response: str):
+            async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
+                yield result
+            async for result in send_event(ChatEvent.MESSAGE, response):
+                yield result
+            async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
+                yield result
+
+        def collect_telemetry():
+            # Gather chat response telemetry
+            nonlocal chat_metadata
+            latency = time.perf_counter() - start_time
+            cmd_set = set([cmd.value for cmd in conversation_commands])
+            chat_metadata = chat_metadata or {}
+            chat_metadata["conversation_command"] = cmd_set
+            chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None
+            chat_metadata["latency"] = f"{latency:.3f}"
+            chat_metadata["ttft_latency"] = f"{ttft:.3f}"
+
+            logger.info(f"Chat response time to first token: {ttft:.3f} seconds")
+            logger.info(f"Chat response total time: {latency:.3f} seconds")
+            update_telemetry_state(
+                request=request,
+                telemetry_type="api",
+                api="chat",
+                client=request.user.client_app,
+                user_agent=request.headers.get("user-agent"),
+                host=request.headers.get("host"),
+                metadata=chat_metadata,
+            )
+
+        conversation_commands = [get_conversation_command(query=q, any_references=True)]
+
+        conversation = await ConversationAdapters.aget_conversation_by_user(
+            user, client_application=request.user.client_app, conversation_id=conversation_id, title=title
+        )
+        if not conversation:
+            async for result in send_llm_response(f"Conversation {conversation_id} not found"):
+                yield result
+            return
+        conversation_id = conversation.id
+
+        await is_ready_to_chat(user)
+
+        user_name = await aget_user_name(user)
+        location = None
+        if city or region or country:
+            location = LocationData(city=city, region=region, country=country)
+
+        if is_query_empty(q):
+            async for result in send_llm_response("Please ask your query to get started."):
+                yield result
+            return
+
+        user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+
+        meta_log = conversation.conversation_log
+        is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask]
+
+        if conversation_commands == [ConversationCommand.Default] or is_automated_task:
+            conversation_commands = await aget_relevant_information_sources(
+                q, meta_log, is_automated_task, subscribed=subscribed, uploaded_image_url=uploaded_image_url
+            )
+            conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
+            async for result in send_event(
+                ChatEvent.STATUS, f"**Chose Data Sources to Search:** {conversation_commands_str}"
+            ):
+                yield result
+
+            mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, uploaded_image_url)
+            async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
+                yield result
+            if mode not in conversation_commands:
+                conversation_commands.append(mode)
+
+        for cmd in conversation_commands:
+            await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
+            q = q.replace(f"/{cmd.value}", "").strip()
+
+        used_slash_summarize = conversation_commands == [ConversationCommand.Summarize]
+        file_filters = conversation.file_filters if conversation else []
+        # Skip trying to summarize if
+        if (
+            # summarization intent was inferred
+            ConversationCommand.Summarize in conversation_commands
+            # and not triggered via slash command
+            and not used_slash_summarize
+            # but we can't actually summarize
+            and len(file_filters) != 1
+        ):
+            conversation_commands.remove(ConversationCommand.Summarize)
+        elif ConversationCommand.Summarize in conversation_commands:
+            response_log = ""
+            if len(file_filters) == 0:
+                response_log = "No files selected for summarization. Please add files using the section on the left."
+                async for result in send_llm_response(response_log):
+                    yield result
+            elif len(file_filters) > 1:
+                response_log = "Only one file can be selected for summarization."
+                async for result in send_llm_response(response_log):
+                    yield result
+            else:
+                try:
+                    file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
+                    if len(file_object) == 0:
+                        response_log = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again."
+                        async for result in send_llm_response(response_log):
+                            yield result
+                        return
+                    contextual_data = " ".join([file.raw_text for file in file_object])
+                    if not q:
+                        q = "Create a general summary of the file"
+                    async for result in send_event(
+                        ChatEvent.STATUS, f"**Constructing Summary Using:** {file_object[0].file_name}"
+                    ):
+                        yield result
+
+                    response = await extract_relevant_summary(
+                        q, contextual_data, subscribed=subscribed, uploaded_image_url=uploaded_image_url
+                    )
+                    response_log = str(response)
+                    async for result in send_llm_response(response_log):
+                        yield result
+                except Exception as e:
+                    response_log = "Error summarizing file."
+                    logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)
+                    async for result in send_llm_response(response_log):
+                        yield result
+            await sync_to_async(save_to_conversation_log)(
+                q,
+                response_log,
+                user,
+                meta_log,
+                user_message_time,
+                intent_type="summarize",
+                client_application=request.user.client_app,
+                conversation_id=conversation_id,
+                uploaded_image_url=uploaded_image_url,
+            )
+            return
+
+        custom_filters = []
+        if conversation_commands == [ConversationCommand.Help]:
+            if not q:
+                conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
+                if conversation_config == None:
+                    conversation_config = await ConversationAdapters.aget_default_conversation_config()
+                model_type = conversation_config.model_type
+                formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
+                async for result in send_llm_response(formatted_help):
+                    yield result
+                return
+            # Adding specification to search online specifically on khoj.dev pages.
+            custom_filters.append("site:khoj.dev")
+            conversation_commands.append(ConversationCommand.Online)
+
+        if ConversationCommand.Automation in conversation_commands:
+            try:
+                automation, crontime, query_to_run, subject = await create_automation(
+                    q, timezone, user, request.url, meta_log
+                )
+            except Exception as e:
+                logger.error(f"Error scheduling task {q} for {user.email}: {e}")
+                error_message = f"Unable to create automation. Ensure the automation doesn't already exist."
+                async for result in send_llm_response(error_message):
+                    yield result
+                return
+
+            llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject)
+            await sync_to_async(save_to_conversation_log)(
+                q,
+                llm_response,
+                user,
+                meta_log,
+                user_message_time,
+                intent_type="automation",
+                client_application=request.user.client_app,
+                conversation_id=conversation_id,
+                inferred_queries=[query_to_run],
+                automation_id=automation.id,
+                uploaded_image_url=uploaded_image_url,
+            )
+            async for result in send_llm_response(llm_response):
+                yield result
+            return
+
+        # Gather Context
+        ## Extract Document References
+        compiled_references, inferred_queries, defiltered_query = [], [], None
+        async for result in extract_references_and_questions(
+            request,
+            meta_log,
+            q,
+            (n or 7),
+            d,
+            conversation_id,
+            conversation_commands,
+            location,
+            partial(send_event, ChatEvent.STATUS),
+            uploaded_image_url=uploaded_image_url,
+        ):
+            if isinstance(result, dict) and ChatEvent.STATUS in result:
+                yield result[ChatEvent.STATUS]
+            else:
+                compiled_references.extend(result[0])
+                inferred_queries.extend(result[1])
+                defiltered_query = result[2]
+
+        if not is_none_or_empty(compiled_references):
+            headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references]))
+            # Strip only leading # from headings
+            headings = headings.replace("#", "")
+            async for result in send_event(ChatEvent.STATUS, f"**Found Relevant Notes**: {headings}"):
+                yield result
+
+        online_results: Dict = dict()
+
+        if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
+            async for result in send_llm_response(f"{no_entries_found.format()}"):
+                yield result
+            return
+
+        if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
+            conversation_commands.remove(ConversationCommand.Notes)
+
+        ## Gather Online References
+        if ConversationCommand.Online in conversation_commands:
+            try:
+                async for result in search_online(
+                    defiltered_query,
+                    meta_log,
+                    location,
+                    user,
+                    subscribed,
+                    partial(send_event, ChatEvent.STATUS),
+                    custom_filters,
+                    uploaded_image_url=uploaded_image_url,
+                ):
+                    if isinstance(result, dict) and ChatEvent.STATUS in result:
+                        yield result[ChatEvent.STATUS]
+                    else:
+                        online_results = result
+            except ValueError as e:
+                error_message = f"Error searching online: {e}. Attempting to respond without online results"
+                logger.warning(error_message)
+                async for result in send_llm_response(error_message):
+                    yield result
+                return
+
+        ## Gather Webpage References
+        if ConversationCommand.Webpage in conversation_commands:
+            try:
+                async for result in read_webpages(
+                    defiltered_query,
+                    meta_log,
+                    location,
+                    user,
+                    subscribed,
+                    partial(send_event, ChatEvent.STATUS),
+                    uploaded_image_url=uploaded_image_url,
+                ):
+                    if isinstance(result, dict) and ChatEvent.STATUS in result:
+                        yield result[ChatEvent.STATUS]
+                    else:
+                        direct_web_pages = result
+                webpages = []
+                for query in direct_web_pages:
+                    if online_results.get(query):
+                        online_results[query]["webpages"] = direct_web_pages[query]["webpages"]
+                    else:
+                        online_results[query] = {"webpages": direct_web_pages[query]["webpages"]}
+
+                    for webpage in direct_web_pages[query]["webpages"]:
+                        webpages.append(webpage["link"])
+                async for result in send_event(ChatEvent.STATUS, f"**Read web pages**: {webpages}"):
+                    yield result
+            except ValueError as e:
+                logger.warning(
+                    f"Error directly reading webpages: {e}. Attempting to respond without online results",
+                    exc_info=True,
+                )
+
+        ## Send Gathered References
+        async for result in send_event(
+            ChatEvent.REFERENCES,
+            {
+                "inferredQueries": inferred_queries,
+                "context": compiled_references,
+                "onlineContext": online_results,
+            },
+        ):
+            yield result
+
+        # Generate Output
+        ## Generate Image Output
+        if ConversationCommand.Image in conversation_commands:
+            async for result in text_to_image(
+                q,
+                user,
+                meta_log,
+                location_data=location,
+                references=compiled_references,
+                online_results=online_results,
+                subscribed=subscribed,
+                send_status_func=partial(send_event, ChatEvent.STATUS),
+                uploaded_image_url=uploaded_image_url,
+            ):
+                if isinstance(result, dict) and ChatEvent.STATUS in result:
+                    yield result[ChatEvent.STATUS]
+                else:
+                    image, status_code, improved_image_prompt, intent_type = result
+
+            if image is None or status_code != 200:
+                content_obj = {
+                    "content-type": "application/json",
+                    "intentType": intent_type,
+                    "detail": improved_image_prompt,
+                    "image": image,
+                }
+                async for result in send_llm_response(json.dumps(content_obj)):
+                    yield result
+                return
+
+            await sync_to_async(save_to_conversation_log)(
+                q,
+                image,
+                user,
+                meta_log,
+                user_message_time,
+                intent_type=intent_type,
+                inferred_queries=[improved_image_prompt],
+                client_application=request.user.client_app,
+                conversation_id=conversation_id,
+                compiled_references=compiled_references,
+                online_results=online_results,
+                uploaded_image_url=uploaded_image_url,
+            )
+            content_obj = {
+                "intentType": intent_type,
+                "inferredQueries": [improved_image_prompt],
+                "image": image,
+            }
+            async for result in send_llm_response(json.dumps(content_obj)):
+                yield result
+            return
+
+        ## Generate Text Output
+        async for result in send_event(ChatEvent.STATUS, f"**Generating a well-informed response**"):
+            yield result
+        llm_response, chat_metadata = await agenerate_chat_response(
+            defiltered_query,
+            meta_log,
+            conversation,
+            compiled_references,
+            online_results,
+            inferred_queries,
+            conversation_commands,
+            user,
+            request.user.client_app,
+            conversation_id,
+            location,
+            user_name,
+            uploaded_image_url,
+        )
+
+        # Send Response
+        async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
+            yield result
+
+        continue_stream = True
+        iterator = AsyncIteratorWrapper(llm_response)
+        async for item in iterator:
+            if item is None:
+                async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
+                    yield result
+                logger.debug("Finished streaming response")
+                return
+            if not connection_alive or not continue_stream:
+                continue
+            try:
+                async for result in send_event(ChatEvent.MESSAGE, f"{item}"):
+                    yield result
+            except Exception as e:
+                continue_stream = False
+                logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}")
+
+    ## Stream Text Response
+    if stream:
+        return StreamingResponse(event_generator(q, image=image), media_type="text/plain")
+    ## Non-Streaming Text Response
+    else:
+        response_iterator = event_generator(q, image=image)
+        response_data = await read_chat_stream(response_iterator)
+        return Response(content=json.dumps(response_data), media_type="application/json", status_code=200)
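
For reviewers, here is a minimal sketch of how a client might call this deprecated GET endpoint while it remains available. It assumes the `api_chat` router is mounted at `/api/chat`, a Khoj server listening on the default `localhost:42110`, and an API token in a `KHOJ_API_KEY` environment variable; none of these are defined in the diff itself, so adjust them to your deployment. New integrations should use the POST chat endpoint above instead.

```python
# Hypothetical usage sketch for the deprecated GET chat endpoint.
# Assumed: server at localhost:42110, router mounted at /api/chat,
# and a bearer token in KHOJ_API_KEY (not part of this diff).
import os

import requests

KHOJ_URL = "http://localhost:42110"  # assumed default Khoj server address
API_KEY = os.environ["KHOJ_API_KEY"]  # assumed auth token

response = requests.get(
    f"{KHOJ_URL}/api/chat",
    params={"q": "Summarize my notes on gardening", "stream": "false"},
    headers={"Authorization": f"Bearer {API_KEY}"},
    timeout=60,
)
response.raise_for_status()
# Non-streaming requests return a JSON payload (see the final Response(...)
# branch in the diff); streaming requests return text/plain chunks separated
# by the "␃🔚␗" event delimiter.
print(response.json())
```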