mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 15:38:55 +01:00
Make streaming optional for the /chat endpoint (#287)
* Update the /chat endpoint to conditionally support streaming
  - If streaming is enabled, return the threadgenerator as it does currently
  - If streaming is disabled, return a JSON response with the response and the compiled references separated out
  - Correspondingly, update the chat.html UI, as well as Obsidian, to use the streamed API
  - Rename chat/init to chat/history
* Update khoj.el to use the /history endpoint
  - Update corresponding unit tests to use stream=true
* Remove & from call to /chat for Obsidian
* Abstract functions out into a helpers.py file and clean up some of the error-catching
This commit is contained in:
parent
0a86220d42
commit
4c135ea316
7 changed files with 177 additions and 100 deletions
|
@ -688,7 +688,7 @@ Render results in BUFFER-NAME using QUERY, CONTENT-TYPE."
|
|||
|
||||
(defun khoj--load-chat-history (buffer-name)
|
||||
"Load Khoj Chat conversation history into BUFFER-NAME."
|
||||
(let ((json-response (cdr (assoc 'response (khoj--query-chat-api "")))))
|
||||
(let ((json-response (cdr (assoc 'response (khoj--get-chat-history-api)))))
|
||||
(with-current-buffer (get-buffer-create buffer-name)
|
||||
(erase-buffer)
|
||||
(insert "* Khoj Chat\n")
|
||||
|
@ -766,7 +766,21 @@ Render results in BUFFER-NAME using QUERY, CONTENT-TYPE."
|
|||
"Send QUERY to Khoj Chat API."
|
||||
(let* ((url-request-method "GET")
|
||||
(encoded-query (url-hexify-string query))
|
||||
(query-url (format "%s/api/chat?q=%s&n=%s&client=emacs" khoj-server-url khoj-results-count encoded-query)))
|
||||
(query-url (format "%s/api/chat?q=%s&n=%s&client=emacs" khoj-server-url encoded-query khoj-results-count)))
|
||||
(with-temp-buffer
|
||||
(condition-case ex
|
||||
(progn
|
||||
(url-insert-file-contents query-url)
|
||||
(json-parse-buffer :object-type 'alist))
|
||||
('file-error (cond ((string-match "Internal server error" (nth 2 ex))
|
||||
(message "Chat processor not configured. Configure OpenAI API key and restart it. Exception: [%s]" ex))
|
||||
(t (message "Chat exception: [%s]" ex))))))))
|
||||
|
||||
|
||||
(defun khoj--get-chat-history-api ()
|
||||
"Send QUERY to Khoj Chat History API."
|
||||
(let* ((url-request-method "GET")
|
||||
(query-url (format "%s/api/chat/history?client=emacs" khoj-server-url)))
|
||||
(with-temp-buffer
|
||||
(condition-case ex
|
||||
(progn
|
||||
|
|
|
@ -140,7 +140,7 @@ export class KhojChatModal extends Modal {
|
|||
|
||||
async getChatHistory(): Promise<void> {
|
||||
// Get chat history from Khoj backend
|
||||
let chatUrl = `${this.setting.khojUrl}/api/chat/init?client=obsidian`;
|
||||
let chatUrl = `${this.setting.khojUrl}/api/chat/history?client=obsidian`;
|
||||
let response = await request(chatUrl);
|
||||
let chatLogs = JSON.parse(response).response;
|
||||
chatLogs.forEach((chatLog: any) => {
|
||||
|
@ -157,7 +157,7 @@ export class KhojChatModal extends Modal {
|
|||
|
||||
// Get chat response from Khoj backend
|
||||
let encodedQuery = encodeURIComponent(query);
|
||||
let chatUrl = `${this.setting.khojUrl}/api/chat?q=${encodedQuery}&n=${this.setting.resultsCount}&client=obsidian`;
|
||||
let chatUrl = `${this.setting.khojUrl}/api/chat?q=${encodedQuery}&n=${this.setting.resultsCount}&client=obsidian&stream=true`;
|
||||
let responseElement = this.createKhojResponseDiv();
|
||||
|
||||
// Temporary status message to indicate that Khoj is thinking
|
||||
|
|
|
@ -63,7 +63,7 @@
|
|||
document.getElementById("chat-input").value = "";
|
||||
|
||||
// Generate backend API URL to execute query
|
||||
let url = `/api/chat?q=${encodeURIComponent(query)}&n=${results_count}&client=web`;
|
||||
let url = `/api/chat?q=${encodeURIComponent(query)}&n=${results_count}&client=web&stream=true`;
|
||||
|
||||
let chat_body = document.getElementById("chat-body");
|
||||
let new_response = document.createElement("div");
|
||||
|
@ -130,7 +130,7 @@
|
|||
}
|
||||
|
||||
window.onload = function () {
|
||||
fetch('/api/chat/init?client=web')
|
||||
fetch('/api/chat/history?client=web')
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
if (data.detail) {
|
||||
|
|
|
@ -4,9 +4,8 @@ import math
|
|||
import time
|
||||
import yaml
|
||||
import logging
|
||||
from datetime import datetime
|
||||
import json
|
||||
from typing import List, Optional, Union
|
||||
from functools import partial
|
||||
|
||||
# External Packages
|
||||
from fastapi import APIRouter, HTTPException, Header, Request
|
||||
|
@ -14,8 +13,6 @@ from sentence_transformers import util
|
|||
|
||||
# Internal Packages
|
||||
from khoj.configure import configure_processor, configure_search
|
||||
from khoj.processor.conversation.gpt import converse, extract_questions
|
||||
from khoj.processor.conversation.utils import message_to_log, reciprocal_conversation_to_chatml
|
||||
from khoj.search_type import image_search, text_search
|
||||
from khoj.search_filter.date_filter import DateFilter
|
||||
from khoj.search_filter.file_filter import FileFilter
|
||||
|
@ -35,7 +32,11 @@ from khoj.utils.rawconfig import (
|
|||
from khoj.utils.state import SearchType
|
||||
from khoj.utils import state, constants
|
||||
from khoj.utils.yaml import save_config_to_file_updated_state
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.responses import StreamingResponse, Response
|
||||
from khoj.routers.helpers import perform_chat_checks, generate_chat_response
|
||||
from khoj.processor.conversation.gpt import extract_questions
|
||||
from fastapi.requests import Request
|
||||
|
||||
|
||||
# Initialize Router
|
||||
api = APIRouter()
|
||||
|
@ -408,22 +409,15 @@ def update(
|
|||
return {"status": "ok", "message": "khoj reloaded"}
|
||||
|
||||
|
||||
@api.get("/chat/init")
|
||||
def chat_init(
|
||||
@api.get("/chat/history")
|
||||
def chat_history(
|
||||
request: Request,
|
||||
client: Optional[str] = None,
|
||||
user_agent: Optional[str] = Header(None),
|
||||
referer: Optional[str] = Header(None),
|
||||
host: Optional[str] = Header(None),
|
||||
):
|
||||
if (
|
||||
state.processor_config is None
|
||||
or state.processor_config.conversation is None
|
||||
or state.processor_config.conversation.openai_api_key is None
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Set your OpenAI API key via Khoj settings and restart it to use Khoj Chat."
|
||||
)
|
||||
perform_chat_checks()
|
||||
|
||||
# Load Conversation History
|
||||
meta_log = state.processor_config.conversation.meta_log
|
||||
|
@ -444,53 +438,71 @@ def chat_init(
|
|||
return {"status": "ok", "response": meta_log.get("chat", [])}
|
||||
|
||||
|
||||
@api.get("/chat", response_class=StreamingResponse)
|
||||
@api.get("/chat", response_class=Response)
|
||||
async def chat(
|
||||
request: Request,
|
||||
q: Optional[str] = None,
|
||||
q: str,
|
||||
n: Optional[int] = 5,
|
||||
client: Optional[str] = None,
|
||||
stream: Optional[bool] = False,
|
||||
user_agent: Optional[str] = Header(None),
|
||||
referer: Optional[str] = Header(None),
|
||||
host: Optional[str] = Header(None),
|
||||
) -> StreamingResponse:
|
||||
def _save_to_conversation_log(
|
||||
q: str,
|
||||
gpt_response: str,
|
||||
user_message_time: str,
|
||||
compiled_references: List[str],
|
||||
inferred_queries: List[str],
|
||||
meta_log,
|
||||
):
|
||||
state.processor_config.conversation.chat_session += reciprocal_conversation_to_chatml([q, gpt_response])
|
||||
state.processor_config.conversation.meta_log["chat"] = message_to_log(
|
||||
q,
|
||||
gpt_response,
|
||||
user_message_metadata={"created": user_message_time},
|
||||
khoj_message_metadata={"context": compiled_references, "intent": {"inferred-queries": inferred_queries}},
|
||||
conversation_log=meta_log.get("chat", []),
|
||||
)
|
||||
) -> Response:
|
||||
perform_chat_checks()
|
||||
compiled_references, inferred_queries = await extract_references_and_questions(request, q, (n or 5))
|
||||
|
||||
if (
|
||||
state.processor_config is None
|
||||
or state.processor_config.conversation is None
|
||||
or state.processor_config.conversation.openai_api_key is None
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Set your OpenAI API key via Khoj settings and restart it to use Khoj Chat."
|
||||
)
|
||||
# Get the (streamed) chat response from GPT.
|
||||
gpt_response = generate_chat_response(
|
||||
q,
|
||||
meta_log=state.processor_config.conversation.meta_log,
|
||||
compiled_references=compiled_references,
|
||||
inferred_queries=inferred_queries,
|
||||
)
|
||||
if gpt_response is None:
|
||||
return Response(content=gpt_response, media_type="text/plain", status_code=500)
|
||||
|
||||
if stream:
|
||||
return StreamingResponse(gpt_response, media_type="text/event-stream", status_code=200)
|
||||
|
||||
# Get the full response from the generator if the stream is not requested.
|
||||
aggregated_gpt_response = ""
|
||||
while True:
|
||||
try:
|
||||
aggregated_gpt_response += next(gpt_response)
|
||||
except StopIteration:
|
||||
break
|
||||
|
||||
actual_response = aggregated_gpt_response.split("### compiled references:")[0]
|
||||
|
||||
response_obj = {"response": actual_response, "context": compiled_references}
|
||||
|
||||
user_state = {
|
||||
"client_host": request.client.host if request.client else None,
|
||||
"user_agent": user_agent or "unknown",
|
||||
"referer": referer or "unknown",
|
||||
"host": host or "unknown",
|
||||
}
|
||||
|
||||
state.telemetry += [
|
||||
log_telemetry(
|
||||
telemetry_type="api", api="chat", client=client, app_config=state.config.app, properties=user_state
|
||||
)
|
||||
]
|
||||
|
||||
return Response(content=json.dumps(response_obj), media_type="application/json", status_code=200)
|
||||
|
||||
|
||||
async def extract_references_and_questions(
|
||||
request: Request,
|
||||
q: str,
|
||||
n: int,
|
||||
):
|
||||
# Load Conversation History
|
||||
meta_log = state.processor_config.conversation.meta_log
|
||||
|
||||
# If user query is empty, return nothing
|
||||
if not q:
|
||||
return StreamingResponse(None)
|
||||
|
||||
# Initialize Variables
|
||||
api_key = state.processor_config.conversation.openai_api_key
|
||||
chat_model = state.processor_config.conversation.chat_model
|
||||
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
conversation_type = "general" if q.startswith("@general") else "notes"
|
||||
compiled_references = []
|
||||
inferred_queries = []
|
||||
|
@ -509,39 +521,4 @@ async def chat(
|
|||
)
|
||||
compiled_references = [item.additional["compiled"] for item in result_list]
|
||||
|
||||
# Switch to general conversation type if no relevant notes found for the given query
|
||||
conversation_type = "notes" if compiled_references else "general"
|
||||
logger.debug(f"Conversation Type: {conversation_type}")
|
||||
|
||||
user_state = {
|
||||
"client_host": request.client.host if request.client else None,
|
||||
"user_agent": user_agent or "unknown",
|
||||
"referer": referer or "unknown",
|
||||
"host": host or "unknown",
|
||||
}
|
||||
|
||||
state.telemetry += [
|
||||
log_telemetry(
|
||||
telemetry_type="api", api="chat", client=client, app_config=state.config.app, properties=user_state
|
||||
)
|
||||
]
|
||||
|
||||
try:
|
||||
with timer("Generating chat response took", logger):
|
||||
partial_completion = partial(
|
||||
_save_to_conversation_log,
|
||||
q,
|
||||
user_message_time=user_message_time,
|
||||
compiled_references=compiled_references,
|
||||
inferred_queries=inferred_queries,
|
||||
meta_log=meta_log,
|
||||
)
|
||||
|
||||
gpt_response = converse(
|
||||
compiled_references, q, meta_log, model=chat_model, api_key=api_key, completion_func=partial_completion
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
gpt_response = str(e)
|
||||
|
||||
return StreamingResponse(gpt_response, media_type="text/event-stream", status_code=200)
|
||||
return compiled_references, inferred_queries
|
||||
|
|
82
src/khoj/routers/helpers.py
Normal file
82
src/khoj/routers/helpers.py
Normal file
|
@ -0,0 +1,82 @@
|
|||
import logging
from datetime import datetime
from functools import partial
from typing import List, Optional

from fastapi import HTTPException

from khoj.utils import state
from khoj.utils.helpers import timer
from khoj.processor.conversation.gpt import converse
from khoj.processor.conversation.utils import message_to_log, reciprocal_conversation_to_chatml
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def perform_chat_checks():
    """Verify that Khoj Chat is configured before a chat request is served.

    Returns None when the processor config, its conversation section and the
    OpenAI API key are all set.

    Raises:
        HTTPException: with status code 500 when any of the three is None.
    """
    processor_config = state.processor_config
    conversation = processor_config.conversation if processor_config is not None else None

    # Guard clause: everything needed for chat is present, nothing to report
    if conversation is not None and conversation.openai_api_key is not None:
        return

    raise HTTPException(
        status_code=500, detail="Set your OpenAI API key via Khoj settings and restart it to use Khoj Chat."
    )
|
||||
|
||||
|
||||
def generate_chat_response(
    q: str,
    meta_log: dict,
    compiled_references: Optional[List[str]] = None,
    inferred_queries: Optional[List[str]] = None,
):
    """Ask the chat model to respond to the user query and return the result of `converse`.

    Args:
        q: The user's chat message.
        meta_log: Conversation log; its "chat" entry (if any) is passed on as prior history.
        compiled_references: Note excerpts to ground the response in. Defaults to no references.
        inferred_queries: Search queries inferred from the user message. Defaults to none.

    Returns:
        Whatever `converse` returns, unmodified.
        NOTE(review): callers iterate over it, so presumably a generator of response chunks - confirm.

    Raises:
        HTTPException: with status code 500 and the original error message as detail
            when setting up or calling `converse` raises.
    """
    # Use fresh lists per call. These objects are stored in the conversation log by the
    # completion callback, so a shared mutable default would be aliased across requests.
    compiled_references = [] if compiled_references is None else compiled_references
    inferred_queries = [] if inferred_queries is None else inferred_queries

    def _save_to_conversation_log(
        q: str,
        gpt_response: str,
        user_message_time: str,
        compiled_references: List[str],
        inferred_queries: List[str],
        meta_log,
    ):
        # Record the exchange in both the in-memory chat session and the conversation log
        state.processor_config.conversation.chat_session += reciprocal_conversation_to_chatml([q, gpt_response])
        state.processor_config.conversation.meta_log["chat"] = message_to_log(
            q,
            gpt_response,
            user_message_metadata={"created": user_message_time},
            khoj_message_metadata={"context": compiled_references, "intent": {"inferred-queries": inferred_queries}},
            conversation_log=meta_log.get("chat", []),
        )

    # Initialize Variables
    api_key = state.processor_config.conversation.openai_api_key
    chat_model = state.processor_config.conversation.chat_model
    user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    # Conversation is note-based only if relevant notes were found for the given query
    conversation_type = "notes" if compiled_references else "general"
    logger.debug(f"Conversation Type: {conversation_type}")

    try:
        with timer("Generating chat response took", logger):
            # Everything except the response text is bound up front; `converse` receives the
            # callback as completion_func. Presumably it supplies the final response - confirm.
            partial_completion = partial(
                _save_to_conversation_log,
                q,
                user_message_time=user_message_time,
                compiled_references=compiled_references,
                inferred_queries=inferred_queries,
                meta_log=meta_log,
            )

            gpt_response = converse(
                compiled_references, q, meta_log, model=chat_model, api_key=api_key, completion_func=partial_completion
            )

    except Exception as e:
        # Keep the traceback in the logs and chain the cause so the root error is not lost
        logger.error(e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e

    return gpt_response
|
|
@ -174,6 +174,8 @@ def test_generate_search_query_with_date_and_context_from_chat_history():
|
|||
expected_responses = [
|
||||
("dt>='2000-04-01'", "dt<'2000-05-01'"),
|
||||
("dt>='2000-04-01'", "dt<='2000-04-30'"),
|
||||
('dt>="2000-04-01"', 'dt<"2000-05-01"'),
|
||||
('dt>="2000-04-01"', 'dt<="2000-04-30"'),
|
||||
]
|
||||
assert len(response) == 1
|
||||
assert "Masai Mara" in response[0]
|
||||
|
|
|
@ -40,7 +40,7 @@ def populate_chat_history(message_list):
|
|||
@pytest.mark.chatquality
|
||||
def test_chat_with_no_chat_history_or_retrieved_content(chat_client):
|
||||
# Act
|
||||
response = chat_client.get(f'/api/chat?q="Hello, my name is Testatron. Who are you?"')
|
||||
response = chat_client.get(f'/api/chat?q="Hello, my name is Testatron. Who are you?"&stream=true')
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
|
@ -62,7 +62,7 @@ def test_answer_from_chat_history(chat_client):
|
|||
populate_chat_history(message_list)
|
||||
|
||||
# Act
|
||||
response = chat_client.get(f'/api/chat?q="What is my name?"')
|
||||
response = chat_client.get(f'/api/chat?q="What is my name?"&stream=true')
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
|
@ -157,7 +157,7 @@ def test_no_answer_in_chat_history_or_retrieved_content(chat_client):
|
|||
populate_chat_history(message_list)
|
||||
|
||||
# Act
|
||||
response = chat_client.get(f'/api/chat?q="Where was I born?"')
|
||||
response = chat_client.get(f'/api/chat?q="Where was I born?"&stream=true')
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
|
@ -175,7 +175,7 @@ def test_no_answer_in_chat_history_or_retrieved_content(chat_client):
|
|||
def test_answer_requires_current_date_awareness(chat_client):
|
||||
"Chat actor should be able to answer questions relative to current date using provided notes"
|
||||
# Act
|
||||
response = chat_client.get(f'/api/chat?q="Where did I have lunch today?"')
|
||||
response = chat_client.get(f'/api/chat?q="Where did I have lunch today?"&stream=true')
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
|
@ -192,7 +192,7 @@ def test_answer_requires_current_date_awareness(chat_client):
|
|||
def test_answer_requires_date_aware_aggregation_across_provided_notes(chat_client):
|
||||
"Chat director should be able to answer questions that require date aware aggregation across multiple notes"
|
||||
# Act
|
||||
response = chat_client.get(f'/api/chat?q="How much did I spend on dining this year?"')
|
||||
response = chat_client.get(f'/api/chat?q="How much did I spend on dining this year?"&stream=true')
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
|
@ -212,7 +212,9 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c
|
|||
populate_chat_history(message_list)
|
||||
|
||||
# Act
|
||||
response = chat_client.get(f'/api/chat?q=""Write a haiku about unit testing. Do not say anything else."')
|
||||
response = chat_client.get(
|
||||
f'/api/chat?q=""Write a haiku about unit testing. Do not say anything else."&stream=true'
|
||||
)
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
|
@ -229,7 +231,7 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c
|
|||
@pytest.mark.chatquality
|
||||
def test_ask_for_clarification_if_not_enough_context_in_question(chat_client):
|
||||
# Act
|
||||
response = chat_client.get(f'/api/chat?q="What is the name of Namitas older son"')
|
||||
response = chat_client.get(f'/api/chat?q="What is the name of Namitas older son"&stream=true')
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
|
@ -258,7 +260,7 @@ def test_answer_in_chat_history_beyond_lookback_window(chat_client):
|
|||
populate_chat_history(message_list)
|
||||
|
||||
# Act
|
||||
response = chat_client.get(f'/api/chat?q="What is my name?"')
|
||||
response = chat_client.get(f'/api/chat?q="What is my name?"&stream=true')
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
|
@ -274,11 +276,11 @@ def test_answer_in_chat_history_beyond_lookback_window(chat_client):
|
|||
def test_answer_requires_multiple_independent_searches(chat_client):
|
||||
"Chat director should be able to answer by doing multiple independent searches for required information"
|
||||
# Act
|
||||
response = chat_client.get(f'/api/chat?q="Is Xi older than Namita?"')
|
||||
response = chat_client.get(f'/api/chat?q="Is Xi older than Namita?"&stream=true')
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
expected_responses = ["he is older than namita", "xi is older than namita"]
|
||||
expected_responses = ["he is older than namita", "xi is older than namita", "xi li is older than namita"]
|
||||
assert response.status_code == 200
|
||||
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
|
||||
"Expected Xi is older than Namita, but got: " + response_message
|
||||
|
|
Loading…
Reference in a new issue