From 4c135ea3164e847ebe413d1397b15877d08949d5 Mon Sep 17 00:00:00 2001 From: sabaimran <65192171+sabaimran@users.noreply.github.com> Date: Sun, 9 Jul 2023 10:12:09 -0700 Subject: [PATCH] Make streaming optional for the /chat endpoint (#287) * Update the /chat endpoint to conditionally support streaming - If streams are enabled, return the threadgenerator as it does currently - If stream is disabled, return a JSON response with the response/compiled references separated out - Correspondingly, update the chat.html UI to use the streamed API, as well as Obsidian - Rename chat/init/ to chat/history * Update khoj.el to use the /history endpoint - Update corresponding unit tests to use stream=true * Remove & from call to /chat for obsidian * Abstract functions out into a helpers.py file and clean up some of the error-catching --- src/interface/emacs/khoj.el | 18 ++- src/interface/obsidian/src/chat_modal.ts | 4 +- src/khoj/interface/web/chat.html | 4 +- src/khoj/routers/api.py | 145 ++++++++++------------- src/khoj/routers/helpers.py | 82 +++++++++++++ tests/test_chat_actors.py | 2 + tests/test_chat_director.py | 22 ++-- 7 files changed, 177 insertions(+), 100 deletions(-) create mode 100644 src/khoj/routers/helpers.py diff --git a/src/interface/emacs/khoj.el b/src/interface/emacs/khoj.el index 79e50055..54857927 100644 --- a/src/interface/emacs/khoj.el +++ b/src/interface/emacs/khoj.el @@ -688,7 +688,7 @@ Render results in BUFFER-NAME using QUERY, CONTENT-TYPE." (defun khoj--load-chat-history (buffer-name) "Load Khoj Chat conversation history into BUFFER-NAME." - (let ((json-response (cdr (assoc 'response (khoj--query-chat-api ""))))) + (let ((json-response (cdr (assoc 'response (khoj--get-chat-history-api))))) (with-current-buffer (get-buffer-create buffer-name) (erase-buffer) (insert "* Khoj Chat\n") @@ -766,7 +766,21 @@ Render results in BUFFER-NAME using QUERY, CONTENT-TYPE." "Send QUERY to Khoj Chat API." (let* ((url-request-method "GET") (encoded-query (url-hexify-string query)) - (query-url (format "%s/api/chat?q=%s&n=%s&client=emacs" khoj-server-url khoj-results-count encoded-query))) + (query-url (format "%s/api/chat?q=%s&n=%s&client=emacs" khoj-server-url encoded-query khoj-results-count))) + (with-temp-buffer + (condition-case ex + (progn + (url-insert-file-contents query-url) + (json-parse-buffer :object-type 'alist)) + ('file-error (cond ((string-match "Internal server error" (nth 2 ex)) + (message "Chat processor not configured. Configure OpenAI API key and restart it. Exception: [%s]" ex)) + (t (message "Chat exception: [%s]" ex)))))))) + + +(defun khoj--get-chat-history-api () + "Send QUERY to Khoj Chat History API." + (let* ((url-request-method "GET") + (query-url (format "%s/api/chat/history?client=emacs" khoj-server-url))) (with-temp-buffer (condition-case ex (progn diff --git a/src/interface/obsidian/src/chat_modal.ts b/src/interface/obsidian/src/chat_modal.ts index 6c7c3e11..66381071 100644 --- a/src/interface/obsidian/src/chat_modal.ts +++ b/src/interface/obsidian/src/chat_modal.ts @@ -140,7 +140,7 @@ export class KhojChatModal extends Modal { async getChatHistory(): Promise { // Get chat history from Khoj backend - let chatUrl = `${this.setting.khojUrl}/api/chat/init?client=obsidian`; + let chatUrl = `${this.setting.khojUrl}/api/chat/history?client=obsidian`; let response = await request(chatUrl); let chatLogs = JSON.parse(response).response; chatLogs.forEach((chatLog: any) => { @@ -157,7 +157,7 @@ export class KhojChatModal extends Modal { // Get chat response from Khoj backend let encodedQuery = encodeURIComponent(query); - let chatUrl = `${this.setting.khojUrl}/api/chat?q=${encodedQuery}&n=${this.setting.resultsCount}&client=obsidian`; + let chatUrl = `${this.setting.khojUrl}/api/chat?q=${encodedQuery}&n=${this.setting.resultsCount}&client=obsidian&stream=true`; let responseElement = this.createKhojResponseDiv(); // Temporary status message to indicate that Khoj is thinking diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html index b4e5555a..d8872b9b 100644 --- a/src/khoj/interface/web/chat.html +++ b/src/khoj/interface/web/chat.html @@ -63,7 +63,7 @@ document.getElementById("chat-input").value = ""; // Generate backend API URL to execute query - let url = `/api/chat?q=${encodeURIComponent(query)}&n=${results_count}&client=web`; + let url = `/api/chat?q=${encodeURIComponent(query)}&n=${results_count}&client=web&stream=true`; let chat_body = document.getElementById("chat-body"); let new_response = document.createElement("div"); @@ -130,7 +130,7 @@ } window.onload = function () { - fetch('/api/chat/init?client=web') + fetch('/api/chat/history?client=web') .then(response => response.json()) .then(data => { if (data.detail) { diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 1f5bba0c..39593b44 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -4,9 +4,8 @@ import math import time import yaml import logging -from datetime import datetime +import json from typing import List, Optional, Union -from functools import partial # External Packages from fastapi import APIRouter, HTTPException, Header, Request @@ -14,8 +13,6 @@ from sentence_transformers import util # Internal Packages from khoj.configure import configure_processor, configure_search -from khoj.processor.conversation.gpt import converse, extract_questions -from khoj.processor.conversation.utils import message_to_log, reciprocal_conversation_to_chatml from khoj.search_type import image_search, text_search from khoj.search_filter.date_filter import DateFilter from khoj.search_filter.file_filter import FileFilter @@ -35,7 +32,11 @@ from khoj.utils.rawconfig import ( from khoj.utils.state import SearchType from khoj.utils import state, constants from khoj.utils.yaml import save_config_to_file_updated_state -from fastapi.responses import StreamingResponse +from fastapi.responses import StreamingResponse, Response +from khoj.routers.helpers import perform_chat_checks, generate_chat_response +from khoj.processor.conversation.gpt import extract_questions +from fastapi.requests import Request + # Initialize Router api = APIRouter() @@ -408,22 +409,15 @@ def update( return {"status": "ok", "message": "khoj reloaded"} -@api.get("/chat/init") -def chat_init( +@api.get("/chat/history") +def chat_history( request: Request, client: Optional[str] = None, user_agent: Optional[str] = Header(None), referer: Optional[str] = Header(None), host: Optional[str] = Header(None), ): - if ( - state.processor_config is None - or state.processor_config.conversation is None - or state.processor_config.conversation.openai_api_key is None - ): - raise HTTPException( - status_code=500, detail="Set your OpenAI API key via Khoj settings and restart it to use Khoj Chat." - ) + perform_chat_checks() # Load Conversation History meta_log = state.processor_config.conversation.meta_log @@ -444,53 +438,71 @@ def chat_init( return {"status": "ok", "response": meta_log.get("chat", [])} -@api.get("/chat", response_class=StreamingResponse) +@api.get("/chat", response_class=Response) async def chat( request: Request, - q: Optional[str] = None, + q: str, n: Optional[int] = 5, client: Optional[str] = None, + stream: Optional[bool] = False, user_agent: Optional[str] = Header(None), referer: Optional[str] = Header(None), host: Optional[str] = Header(None), -) -> StreamingResponse: - def _save_to_conversation_log( - q: str, - gpt_response: str, - user_message_time: str, - compiled_references: List[str], - inferred_queries: List[str], - meta_log, - ): - state.processor_config.conversation.chat_session += reciprocal_conversation_to_chatml([q, gpt_response]) - state.processor_config.conversation.meta_log["chat"] = message_to_log( - q, - gpt_response, - user_message_metadata={"created": user_message_time}, - khoj_message_metadata={"context": compiled_references, "intent": {"inferred-queries": inferred_queries}}, - conversation_log=meta_log.get("chat", []), - ) +) -> Response: + perform_chat_checks() + compiled_references, inferred_queries = await extract_references_and_questions(request, q, (n or 5)) - if ( - state.processor_config is None - or state.processor_config.conversation is None - or state.processor_config.conversation.openai_api_key is None - ): - raise HTTPException( - status_code=500, detail="Set your OpenAI API key via Khoj settings and restart it to use Khoj Chat." - ) + # Get the (streamed) chat response from GPT. + gpt_response = generate_chat_response( + q, + meta_log=state.processor_config.conversation.meta_log, + compiled_references=compiled_references, + inferred_queries=inferred_queries, + ) + if gpt_response is None: + return Response(content=gpt_response, media_type="text/plain", status_code=500) + if stream: + return StreamingResponse(gpt_response, media_type="text/event-stream", status_code=200) + + # Get the full response from the generator if the stream is not requested. + aggregated_gpt_response = "" + while True: + try: + aggregated_gpt_response += next(gpt_response) + except StopIteration: + break + + actual_response = aggregated_gpt_response.split("### compiled references:")[0] + + response_obj = {"response": actual_response, "context": compiled_references} + + user_state = { + "client_host": request.client.host if request.client else None, + "user_agent": user_agent or "unknown", + "referer": referer or "unknown", + "host": host or "unknown", + } + + state.telemetry += [ + log_telemetry( + telemetry_type="api", api="chat", client=client, app_config=state.config.app, properties=user_state + ) + ] + + return Response(content=json.dumps(response_obj), media_type="application/json", status_code=200) + + +async def extract_references_and_questions( + request: Request, + q: str, + n: int, +): # Load Conversation History meta_log = state.processor_config.conversation.meta_log - # If user query is empty, return nothing - if not q: - return StreamingResponse(None) - # Initialize Variables api_key = state.processor_config.conversation.openai_api_key - chat_model = state.processor_config.conversation.chat_model - user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") conversation_type = "general" if q.startswith("@general") else "notes" compiled_references = [] inferred_queries = [] @@ -509,39 +521,4 @@ async def chat( ) compiled_references = [item.additional["compiled"] for item in result_list] - # Switch to general conversation type if no relevant notes found for the given query - conversation_type = "notes" if compiled_references else "general" - logger.debug(f"Conversation Type: {conversation_type}") - - user_state = { - "client_host": request.client.host if request.client else None, - "user_agent": user_agent or "unknown", - "referer": referer or "unknown", - "host": host or "unknown", - } - - state.telemetry += [ - log_telemetry( - telemetry_type="api", api="chat", client=client, app_config=state.config.app, properties=user_state - ) - ] - - try: - with timer("Generating chat response took", logger): - partial_completion = partial( - _save_to_conversation_log, - q, - user_message_time=user_message_time, - compiled_references=compiled_references, - inferred_queries=inferred_queries, - meta_log=meta_log, - ) - - gpt_response = converse( - compiled_references, q, meta_log, model=chat_model, api_key=api_key, completion_func=partial_completion - ) - - except Exception as e: - gpt_response = str(e) - - return StreamingResponse(gpt_response, media_type="text/event-stream", status_code=200) + return compiled_references, inferred_queries diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py new file mode 100644 index 00000000..0b414cbb --- /dev/null +++ b/src/khoj/routers/helpers.py @@ -0,0 +1,82 @@ +from fastapi import HTTPException +import logging +from datetime import datetime +from functools import partial +from typing import List + +from khoj.utils import state +from khoj.utils.helpers import timer +from khoj.processor.conversation.gpt import converse +from khoj.processor.conversation.utils import message_to_log, reciprocal_conversation_to_chatml + + +logger = logging.getLogger(__name__) + + +def perform_chat_checks(): + if ( + state.processor_config is None + or state.processor_config.conversation is None + or state.processor_config.conversation.openai_api_key is None + ): + raise HTTPException( + status_code=500, detail="Set your OpenAI API key via Khoj settings and restart it to use Khoj Chat." + ) + + +def generate_chat_response( + q: str, + meta_log: dict, + compiled_references: List[str] = [], + inferred_queries: List[str] = [], +): + def _save_to_conversation_log( + q: str, + gpt_response: str, + user_message_time: str, + compiled_references: List[str], + inferred_queries: List[str], + meta_log, + ): + state.processor_config.conversation.chat_session += reciprocal_conversation_to_chatml([q, gpt_response]) + state.processor_config.conversation.meta_log["chat"] = message_to_log( + q, + gpt_response, + user_message_metadata={"created": user_message_time}, + khoj_message_metadata={"context": compiled_references, "intent": {"inferred-queries": inferred_queries}}, + conversation_log=meta_log.get("chat", []), + ) + + # Load Conversation History + meta_log = state.processor_config.conversation.meta_log + + # Initialize Variables + api_key = state.processor_config.conversation.openai_api_key + chat_model = state.processor_config.conversation.chat_model + user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + conversation_type = "general" if q.startswith("@general") else "notes" + + # Switch to general conversation type if no relevant notes found for the given query + conversation_type = "notes" if compiled_references else "general" + logger.debug(f"Conversation Type: {conversation_type}") + + try: + with timer("Generating chat response took", logger): + partial_completion = partial( + _save_to_conversation_log, + q, + user_message_time=user_message_time, + compiled_references=compiled_references, + inferred_queries=inferred_queries, + meta_log=meta_log, + ) + + gpt_response = converse( + compiled_references, q, meta_log, model=chat_model, api_key=api_key, completion_func=partial_completion + ) + + except Exception as e: + logger.error(e) + raise HTTPException(status_code=500, detail=str(e)) + + return gpt_response diff --git a/tests/test_chat_actors.py b/tests/test_chat_actors.py index e645fc13..a1f91188 100644 --- a/tests/test_chat_actors.py +++ b/tests/test_chat_actors.py @@ -174,6 +174,8 @@ def test_generate_search_query_with_date_and_context_from_chat_history(): expected_responses = [ ("dt>='2000-04-01'", "dt<'2000-05-01'"), ("dt>='2000-04-01'", "dt<='2000-04-30'"), + ('dt>="2000-04-01"', 'dt<"2000-05-01"'), + ('dt>="2000-04-01"', 'dt<="2000-04-30"'), ] assert len(response) == 1 assert "Masai Mara" in response[0] diff --git a/tests/test_chat_director.py b/tests/test_chat_director.py index 4dbf808d..15977f71 100644 --- a/tests/test_chat_director.py +++ b/tests/test_chat_director.py @@ -40,7 +40,7 @@ def populate_chat_history(message_list): @pytest.mark.chatquality def test_chat_with_no_chat_history_or_retrieved_content(chat_client): # Act - response = chat_client.get(f'/api/chat?q="Hello, my name is Testatron. Who are you?"') + response = chat_client.get(f'/api/chat?q="Hello, my name is Testatron. Who are you?"&stream=true') response_message = response.content.decode("utf-8") # Assert @@ -62,7 +62,7 @@ def test_answer_from_chat_history(chat_client): populate_chat_history(message_list) # Act - response = chat_client.get(f'/api/chat?q="What is my name?"') + response = chat_client.get(f'/api/chat?q="What is my name?"&stream=true') response_message = response.content.decode("utf-8") # Assert @@ -157,7 +157,7 @@ def test_no_answer_in_chat_history_or_retrieved_content(chat_client): populate_chat_history(message_list) # Act - response = chat_client.get(f'/api/chat?q="Where was I born?"') + response = chat_client.get(f'/api/chat?q="Where was I born?"&stream=true') response_message = response.content.decode("utf-8") # Assert @@ -175,7 +175,7 @@ def test_no_answer_in_chat_history_or_retrieved_content(chat_client): def test_answer_requires_current_date_awareness(chat_client): "Chat actor should be able to answer questions relative to current date using provided notes" # Act - response = chat_client.get(f'/api/chat?q="Where did I have lunch today?"') + response = chat_client.get(f'/api/chat?q="Where did I have lunch today?"&stream=true') response_message = response.content.decode("utf-8") # Assert @@ -192,7 +192,7 @@ def test_answer_requires_current_date_awareness(chat_client): def test_answer_requires_date_aware_aggregation_across_provided_notes(chat_client): "Chat director should be able to answer questions that require date aware aggregation across multiple notes" # Act - response = chat_client.get(f'/api/chat?q="How much did I spend on dining this year?"') + response = chat_client.get(f'/api/chat?q="How much did I spend on dining this year?"&stream=true') response_message = response.content.decode("utf-8") # Assert @@ -212,7 +212,9 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c populate_chat_history(message_list) # Act - response = chat_client.get(f'/api/chat?q=""Write a haiku about unit testing. Do not say anything else."') + response = chat_client.get( + f'/api/chat?q=""Write a haiku about unit testing. Do not say anything else."&stream=true' + ) response_message = response.content.decode("utf-8") # Assert @@ -229,7 +231,7 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c @pytest.mark.chatquality def test_ask_for_clarification_if_not_enough_context_in_question(chat_client): # Act - response = chat_client.get(f'/api/chat?q="What is the name of Namitas older son"') + response = chat_client.get(f'/api/chat?q="What is the name of Namitas older son"&stream=true') response_message = response.content.decode("utf-8") # Assert @@ -258,7 +260,7 @@ def test_answer_in_chat_history_beyond_lookback_window(chat_client): populate_chat_history(message_list) # Act - response = chat_client.get(f'/api/chat?q="What is my name?"') + response = chat_client.get(f'/api/chat?q="What is my name?"&stream=true') response_message = response.content.decode("utf-8") # Assert @@ -274,11 +276,11 @@ def test_answer_in_chat_history_beyond_lookback_window(chat_client): def test_answer_requires_multiple_independent_searches(chat_client): "Chat director should be able to answer by doing multiple independent searches for required information" # Act - response = chat_client.get(f'/api/chat?q="Is Xi older than Namita?"') + response = chat_client.get(f'/api/chat?q="Is Xi older than Namita?"&stream=true') response_message = response.content.decode("utf-8") # Assert - expected_responses = ["he is older than namita", "xi is older than namita"] + expected_responses = ["he is older than namita", "xi is older than namita", "xi li is older than namita"] assert response.status_code == 200 assert any([expected_response in response_message.lower() for expected_response in expected_responses]), ( "Expected Xi is older than Namita, but got: " + response_message