Make streaming optional for the /chat endpoint (#287)

* Update the /chat endpoint to conditionally support streaming

- If streams are enabled, return the threadgenerator as it does currently
- If stream is disabled, return a JSON response with the response/compiled references separated out
- Correspondingly, update the chat.html UI to use the streamed API, as well as Obsidian
- Rename chat/init/ to chat/history

* Update khoj.el to use the /history endpoint

- Update corresponding unit tests to use stream=true

* Remove & from call to /chat for obsidian

* Abstract functions out into a helpers.py file and clean up some of the error-catching
This commit is contained in:
sabaimran 2023-07-09 10:12:09 -07:00 committed by GitHub
parent 0a86220d42
commit 4c135ea316
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 177 additions and 100 deletions

View file

@ -688,7 +688,7 @@ Render results in BUFFER-NAME using QUERY, CONTENT-TYPE."
(defun khoj--load-chat-history (buffer-name)
"Load Khoj Chat conversation history into BUFFER-NAME."
(let ((json-response (cdr (assoc 'response (khoj--query-chat-api "")))))
(let ((json-response (cdr (assoc 'response (khoj--get-chat-history-api)))))
(with-current-buffer (get-buffer-create buffer-name)
(erase-buffer)
(insert "* Khoj Chat\n")
@ -766,7 +766,21 @@ Render results in BUFFER-NAME using QUERY, CONTENT-TYPE."
"Send QUERY to Khoj Chat API."
(let* ((url-request-method "GET")
(encoded-query (url-hexify-string query))
(query-url (format "%s/api/chat?q=%s&n=%s&client=emacs" khoj-server-url khoj-results-count encoded-query)))
(query-url (format "%s/api/chat?q=%s&n=%s&client=emacs" khoj-server-url encoded-query khoj-results-count)))
(with-temp-buffer
(condition-case ex
(progn
(url-insert-file-contents query-url)
(json-parse-buffer :object-type 'alist))
('file-error (cond ((string-match "Internal server error" (nth 2 ex))
(message "Chat processor not configured. Configure OpenAI API key and restart it. Exception: [%s]" ex))
(t (message "Chat exception: [%s]" ex))))))))
(defun khoj--get-chat-history-api ()
"Send QUERY to Khoj Chat History API."
(let* ((url-request-method "GET")
(query-url (format "%s/api/chat/history?client=emacs" khoj-server-url)))
(with-temp-buffer
(condition-case ex
(progn

View file

@ -140,7 +140,7 @@ export class KhojChatModal extends Modal {
async getChatHistory(): Promise<void> {
// Get chat history from Khoj backend
let chatUrl = `${this.setting.khojUrl}/api/chat/init?client=obsidian`;
let chatUrl = `${this.setting.khojUrl}/api/chat/history?client=obsidian`;
let response = await request(chatUrl);
let chatLogs = JSON.parse(response).response;
chatLogs.forEach((chatLog: any) => {
@ -157,7 +157,7 @@ export class KhojChatModal extends Modal {
// Get chat response from Khoj backend
let encodedQuery = encodeURIComponent(query);
let chatUrl = `${this.setting.khojUrl}/api/chat?q=${encodedQuery}&n=${this.setting.resultsCount}&client=obsidian`;
let chatUrl = `${this.setting.khojUrl}/api/chat?q=${encodedQuery}&n=${this.setting.resultsCount}&client=obsidian&stream=true`;
let responseElement = this.createKhojResponseDiv();
// Temporary status message to indicate that Khoj is thinking

View file

@ -63,7 +63,7 @@
document.getElementById("chat-input").value = "";
// Generate backend API URL to execute query
let url = `/api/chat?q=${encodeURIComponent(query)}&n=${results_count}&client=web`;
let url = `/api/chat?q=${encodeURIComponent(query)}&n=${results_count}&client=web&stream=true`;
let chat_body = document.getElementById("chat-body");
let new_response = document.createElement("div");
@ -130,7 +130,7 @@
}
window.onload = function () {
fetch('/api/chat/init?client=web')
fetch('/api/chat/history?client=web')
.then(response => response.json())
.then(data => {
if (data.detail) {

View file

@ -4,9 +4,8 @@ import math
import time
import yaml
import logging
from datetime import datetime
import json
from typing import List, Optional, Union
from functools import partial
# External Packages
from fastapi import APIRouter, HTTPException, Header, Request
@ -14,8 +13,6 @@ from sentence_transformers import util
# Internal Packages
from khoj.configure import configure_processor, configure_search
from khoj.processor.conversation.gpt import converse, extract_questions
from khoj.processor.conversation.utils import message_to_log, reciprocal_conversation_to_chatml
from khoj.search_type import image_search, text_search
from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.file_filter import FileFilter
@ -35,7 +32,11 @@ from khoj.utils.rawconfig import (
from khoj.utils.state import SearchType
from khoj.utils import state, constants
from khoj.utils.yaml import save_config_to_file_updated_state
from fastapi.responses import StreamingResponse
from fastapi.responses import StreamingResponse, Response
from khoj.routers.helpers import perform_chat_checks, generate_chat_response
from khoj.processor.conversation.gpt import extract_questions
from fastapi.requests import Request
# Initialize Router
api = APIRouter()
@ -408,22 +409,15 @@ def update(
return {"status": "ok", "message": "khoj reloaded"}
@api.get("/chat/init")
def chat_init(
@api.get("/chat/history")
def chat_history(
request: Request,
client: Optional[str] = None,
user_agent: Optional[str] = Header(None),
referer: Optional[str] = Header(None),
host: Optional[str] = Header(None),
):
if (
state.processor_config is None
or state.processor_config.conversation is None
or state.processor_config.conversation.openai_api_key is None
):
raise HTTPException(
status_code=500, detail="Set your OpenAI API key via Khoj settings and restart it to use Khoj Chat."
)
perform_chat_checks()
# Load Conversation History
meta_log = state.processor_config.conversation.meta_log
@ -444,53 +438,71 @@ def chat_init(
return {"status": "ok", "response": meta_log.get("chat", [])}
@api.get("/chat", response_class=StreamingResponse)
@api.get("/chat", response_class=Response)
async def chat(
request: Request,
q: Optional[str] = None,
q: str,
n: Optional[int] = 5,
client: Optional[str] = None,
stream: Optional[bool] = False,
user_agent: Optional[str] = Header(None),
referer: Optional[str] = Header(None),
host: Optional[str] = Header(None),
) -> StreamingResponse:
def _save_to_conversation_log(
q: str,
gpt_response: str,
user_message_time: str,
compiled_references: List[str],
inferred_queries: List[str],
meta_log,
):
state.processor_config.conversation.chat_session += reciprocal_conversation_to_chatml([q, gpt_response])
state.processor_config.conversation.meta_log["chat"] = message_to_log(
q,
gpt_response,
user_message_metadata={"created": user_message_time},
khoj_message_metadata={"context": compiled_references, "intent": {"inferred-queries": inferred_queries}},
conversation_log=meta_log.get("chat", []),
)
) -> Response:
perform_chat_checks()
compiled_references, inferred_queries = await extract_references_and_questions(request, q, (n or 5))
if (
state.processor_config is None
or state.processor_config.conversation is None
or state.processor_config.conversation.openai_api_key is None
):
raise HTTPException(
status_code=500, detail="Set your OpenAI API key via Khoj settings and restart it to use Khoj Chat."
)
# Get the (streamed) chat response from GPT.
gpt_response = generate_chat_response(
q,
meta_log=state.processor_config.conversation.meta_log,
compiled_references=compiled_references,
inferred_queries=inferred_queries,
)
if gpt_response is None:
return Response(content=gpt_response, media_type="text/plain", status_code=500)
if stream:
return StreamingResponse(gpt_response, media_type="text/event-stream", status_code=200)
# Get the full response from the generator if the stream is not requested.
aggregated_gpt_response = ""
while True:
try:
aggregated_gpt_response += next(gpt_response)
except StopIteration:
break
actual_response = aggregated_gpt_response.split("### compiled references:")[0]
response_obj = {"response": actual_response, "context": compiled_references}
user_state = {
"client_host": request.client.host if request.client else None,
"user_agent": user_agent or "unknown",
"referer": referer or "unknown",
"host": host or "unknown",
}
state.telemetry += [
log_telemetry(
telemetry_type="api", api="chat", client=client, app_config=state.config.app, properties=user_state
)
]
return Response(content=json.dumps(response_obj), media_type="application/json", status_code=200)
async def extract_references_and_questions(
request: Request,
q: str,
n: int,
):
# Load Conversation History
meta_log = state.processor_config.conversation.meta_log
# If user query is empty, return nothing
if not q:
return StreamingResponse(None)
# Initialize Variables
api_key = state.processor_config.conversation.openai_api_key
chat_model = state.processor_config.conversation.chat_model
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
conversation_type = "general" if q.startswith("@general") else "notes"
compiled_references = []
inferred_queries = []
@ -509,39 +521,4 @@ async def chat(
)
compiled_references = [item.additional["compiled"] for item in result_list]
# Switch to general conversation type if no relevant notes found for the given query
conversation_type = "notes" if compiled_references else "general"
logger.debug(f"Conversation Type: {conversation_type}")
user_state = {
"client_host": request.client.host if request.client else None,
"user_agent": user_agent or "unknown",
"referer": referer or "unknown",
"host": host or "unknown",
}
state.telemetry += [
log_telemetry(
telemetry_type="api", api="chat", client=client, app_config=state.config.app, properties=user_state
)
]
try:
with timer("Generating chat response took", logger):
partial_completion = partial(
_save_to_conversation_log,
q,
user_message_time=user_message_time,
compiled_references=compiled_references,
inferred_queries=inferred_queries,
meta_log=meta_log,
)
gpt_response = converse(
compiled_references, q, meta_log, model=chat_model, api_key=api_key, completion_func=partial_completion
)
except Exception as e:
gpt_response = str(e)
return StreamingResponse(gpt_response, media_type="text/event-stream", status_code=200)
return compiled_references, inferred_queries

View file

@ -0,0 +1,82 @@
from fastapi import HTTPException
import logging
from datetime import datetime
from functools import partial
from typing import List
from khoj.utils import state
from khoj.utils.helpers import timer
from khoj.processor.conversation.gpt import converse
from khoj.processor.conversation.utils import message_to_log, reciprocal_conversation_to_chatml
logger = logging.getLogger(__name__)
def perform_chat_checks():
if (
state.processor_config is None
or state.processor_config.conversation is None
or state.processor_config.conversation.openai_api_key is None
):
raise HTTPException(
status_code=500, detail="Set your OpenAI API key via Khoj settings and restart it to use Khoj Chat."
)
def generate_chat_response(
q: str,
meta_log: dict,
compiled_references: List[str] = [],
inferred_queries: List[str] = [],
):
def _save_to_conversation_log(
q: str,
gpt_response: str,
user_message_time: str,
compiled_references: List[str],
inferred_queries: List[str],
meta_log,
):
state.processor_config.conversation.chat_session += reciprocal_conversation_to_chatml([q, gpt_response])
state.processor_config.conversation.meta_log["chat"] = message_to_log(
q,
gpt_response,
user_message_metadata={"created": user_message_time},
khoj_message_metadata={"context": compiled_references, "intent": {"inferred-queries": inferred_queries}},
conversation_log=meta_log.get("chat", []),
)
# Load Conversation History
meta_log = state.processor_config.conversation.meta_log
# Initialize Variables
api_key = state.processor_config.conversation.openai_api_key
chat_model = state.processor_config.conversation.chat_model
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
conversation_type = "general" if q.startswith("@general") else "notes"
# Switch to general conversation type if no relevant notes found for the given query
conversation_type = "notes" if compiled_references else "general"
logger.debug(f"Conversation Type: {conversation_type}")
try:
with timer("Generating chat response took", logger):
partial_completion = partial(
_save_to_conversation_log,
q,
user_message_time=user_message_time,
compiled_references=compiled_references,
inferred_queries=inferred_queries,
meta_log=meta_log,
)
gpt_response = converse(
compiled_references, q, meta_log, model=chat_model, api_key=api_key, completion_func=partial_completion
)
except Exception as e:
logger.error(e)
raise HTTPException(status_code=500, detail=str(e))
return gpt_response

View file

@ -174,6 +174,8 @@ def test_generate_search_query_with_date_and_context_from_chat_history():
expected_responses = [
("dt>='2000-04-01'", "dt<'2000-05-01'"),
("dt>='2000-04-01'", "dt<='2000-04-30'"),
('dt>="2000-04-01"', 'dt<"2000-05-01"'),
('dt>="2000-04-01"', 'dt<="2000-04-30"'),
]
assert len(response) == 1
assert "Masai Mara" in response[0]

View file

@ -40,7 +40,7 @@ def populate_chat_history(message_list):
@pytest.mark.chatquality
def test_chat_with_no_chat_history_or_retrieved_content(chat_client):
# Act
response = chat_client.get(f'/api/chat?q="Hello, my name is Testatron. Who are you?"')
response = chat_client.get(f'/api/chat?q="Hello, my name is Testatron. Who are you?"&stream=true')
response_message = response.content.decode("utf-8")
# Assert
@ -62,7 +62,7 @@ def test_answer_from_chat_history(chat_client):
populate_chat_history(message_list)
# Act
response = chat_client.get(f'/api/chat?q="What is my name?"')
response = chat_client.get(f'/api/chat?q="What is my name?"&stream=true')
response_message = response.content.decode("utf-8")
# Assert
@ -157,7 +157,7 @@ def test_no_answer_in_chat_history_or_retrieved_content(chat_client):
populate_chat_history(message_list)
# Act
response = chat_client.get(f'/api/chat?q="Where was I born?"')
response = chat_client.get(f'/api/chat?q="Where was I born?"&stream=true')
response_message = response.content.decode("utf-8")
# Assert
@ -175,7 +175,7 @@ def test_no_answer_in_chat_history_or_retrieved_content(chat_client):
def test_answer_requires_current_date_awareness(chat_client):
"Chat actor should be able to answer questions relative to current date using provided notes"
# Act
response = chat_client.get(f'/api/chat?q="Where did I have lunch today?"')
response = chat_client.get(f'/api/chat?q="Where did I have lunch today?"&stream=true')
response_message = response.content.decode("utf-8")
# Assert
@ -192,7 +192,7 @@ def test_answer_requires_current_date_awareness(chat_client):
def test_answer_requires_date_aware_aggregation_across_provided_notes(chat_client):
"Chat director should be able to answer questions that require date aware aggregation across multiple notes"
# Act
response = chat_client.get(f'/api/chat?q="How much did I spend on dining this year?"')
response = chat_client.get(f'/api/chat?q="How much did I spend on dining this year?"&stream=true')
response_message = response.content.decode("utf-8")
# Assert
@ -212,7 +212,9 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c
populate_chat_history(message_list)
# Act
response = chat_client.get(f'/api/chat?q=""Write a haiku about unit testing. Do not say anything else."')
response = chat_client.get(
f'/api/chat?q=""Write a haiku about unit testing. Do not say anything else."&stream=true'
)
response_message = response.content.decode("utf-8")
# Assert
@ -229,7 +231,7 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c
@pytest.mark.chatquality
def test_ask_for_clarification_if_not_enough_context_in_question(chat_client):
# Act
response = chat_client.get(f'/api/chat?q="What is the name of Namitas older son"')
response = chat_client.get(f'/api/chat?q="What is the name of Namitas older son"&stream=true')
response_message = response.content.decode("utf-8")
# Assert
@ -258,7 +260,7 @@ def test_answer_in_chat_history_beyond_lookback_window(chat_client):
populate_chat_history(message_list)
# Act
response = chat_client.get(f'/api/chat?q="What is my name?"')
response = chat_client.get(f'/api/chat?q="What is my name?"&stream=true')
response_message = response.content.decode("utf-8")
# Assert
@ -274,11 +276,11 @@ def test_answer_in_chat_history_beyond_lookback_window(chat_client):
def test_answer_requires_multiple_independent_searches(chat_client):
"Chat director should be able to answer by doing multiple independent searches for required information"
# Act
response = chat_client.get(f'/api/chat?q="Is Xi older than Namita?"')
response = chat_client.get(f'/api/chat?q="Is Xi older than Namita?"&stream=true')
response_message = response.content.decode("utf-8")
# Assert
expected_responses = ["he is older than namita", "xi is older than namita"]
expected_responses = ["he is older than namita", "xi is older than namita", "xi li is older than namita"]
assert response.status_code == 200
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
"Expected Xi is older than Namita, but got: " + response_message