mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-27 17:35:07 +01:00
Make streaming optional for the /chat endpoint (#287)
* Update the /chat endpoint to conditionally support streaming - If streams are enabled, return the threadgenerator as it does currently - If stream is disabled, return a JSON response with the response/compiled references separated out - Correspondingly, update the chat.html UI to use the streamed API, as well as Obsidian - Rename chat/init/ to chat/history * Update khoj.el to use the /history endpoint - Update corresponding unit tests to use stream=true * Remove & from call to /chat for obsidian * Abstract functions out into a helpers.py file and clean up some of the error-catching
This commit is contained in:
parent
0a86220d42
commit
4c135ea316
7 changed files with 177 additions and 100 deletions
|
@ -688,7 +688,7 @@ Render results in BUFFER-NAME using QUERY, CONTENT-TYPE."
|
||||||
|
|
||||||
(defun khoj--load-chat-history (buffer-name)
|
(defun khoj--load-chat-history (buffer-name)
|
||||||
"Load Khoj Chat conversation history into BUFFER-NAME."
|
"Load Khoj Chat conversation history into BUFFER-NAME."
|
||||||
(let ((json-response (cdr (assoc 'response (khoj--query-chat-api "")))))
|
(let ((json-response (cdr (assoc 'response (khoj--get-chat-history-api)))))
|
||||||
(with-current-buffer (get-buffer-create buffer-name)
|
(with-current-buffer (get-buffer-create buffer-name)
|
||||||
(erase-buffer)
|
(erase-buffer)
|
||||||
(insert "* Khoj Chat\n")
|
(insert "* Khoj Chat\n")
|
||||||
|
@ -766,7 +766,21 @@ Render results in BUFFER-NAME using QUERY, CONTENT-TYPE."
|
||||||
"Send QUERY to Khoj Chat API."
|
"Send QUERY to Khoj Chat API."
|
||||||
(let* ((url-request-method "GET")
|
(let* ((url-request-method "GET")
|
||||||
(encoded-query (url-hexify-string query))
|
(encoded-query (url-hexify-string query))
|
||||||
(query-url (format "%s/api/chat?q=%s&n=%s&client=emacs" khoj-server-url khoj-results-count encoded-query)))
|
(query-url (format "%s/api/chat?q=%s&n=%s&client=emacs" khoj-server-url encoded-query khoj-results-count)))
|
||||||
|
(with-temp-buffer
|
||||||
|
(condition-case ex
|
||||||
|
(progn
|
||||||
|
(url-insert-file-contents query-url)
|
||||||
|
(json-parse-buffer :object-type 'alist))
|
||||||
|
('file-error (cond ((string-match "Internal server error" (nth 2 ex))
|
||||||
|
(message "Chat processor not configured. Configure OpenAI API key and restart it. Exception: [%s]" ex))
|
||||||
|
(t (message "Chat exception: [%s]" ex))))))))
|
||||||
|
|
||||||
|
|
||||||
|
(defun khoj--get-chat-history-api ()
|
||||||
|
"Send QUERY to Khoj Chat History API."
|
||||||
|
(let* ((url-request-method "GET")
|
||||||
|
(query-url (format "%s/api/chat/history?client=emacs" khoj-server-url)))
|
||||||
(with-temp-buffer
|
(with-temp-buffer
|
||||||
(condition-case ex
|
(condition-case ex
|
||||||
(progn
|
(progn
|
||||||
|
|
|
@ -140,7 +140,7 @@ export class KhojChatModal extends Modal {
|
||||||
|
|
||||||
async getChatHistory(): Promise<void> {
|
async getChatHistory(): Promise<void> {
|
||||||
// Get chat history from Khoj backend
|
// Get chat history from Khoj backend
|
||||||
let chatUrl = `${this.setting.khojUrl}/api/chat/init?client=obsidian`;
|
let chatUrl = `${this.setting.khojUrl}/api/chat/history?client=obsidian`;
|
||||||
let response = await request(chatUrl);
|
let response = await request(chatUrl);
|
||||||
let chatLogs = JSON.parse(response).response;
|
let chatLogs = JSON.parse(response).response;
|
||||||
chatLogs.forEach((chatLog: any) => {
|
chatLogs.forEach((chatLog: any) => {
|
||||||
|
@ -157,7 +157,7 @@ export class KhojChatModal extends Modal {
|
||||||
|
|
||||||
// Get chat response from Khoj backend
|
// Get chat response from Khoj backend
|
||||||
let encodedQuery = encodeURIComponent(query);
|
let encodedQuery = encodeURIComponent(query);
|
||||||
let chatUrl = `${this.setting.khojUrl}/api/chat?q=${encodedQuery}&n=${this.setting.resultsCount}&client=obsidian`;
|
let chatUrl = `${this.setting.khojUrl}/api/chat?q=${encodedQuery}&n=${this.setting.resultsCount}&client=obsidian&stream=true`;
|
||||||
let responseElement = this.createKhojResponseDiv();
|
let responseElement = this.createKhojResponseDiv();
|
||||||
|
|
||||||
// Temporary status message to indicate that Khoj is thinking
|
// Temporary status message to indicate that Khoj is thinking
|
||||||
|
|
|
@ -63,7 +63,7 @@
|
||||||
document.getElementById("chat-input").value = "";
|
document.getElementById("chat-input").value = "";
|
||||||
|
|
||||||
// Generate backend API URL to execute query
|
// Generate backend API URL to execute query
|
||||||
let url = `/api/chat?q=${encodeURIComponent(query)}&n=${results_count}&client=web`;
|
let url = `/api/chat?q=${encodeURIComponent(query)}&n=${results_count}&client=web&stream=true`;
|
||||||
|
|
||||||
let chat_body = document.getElementById("chat-body");
|
let chat_body = document.getElementById("chat-body");
|
||||||
let new_response = document.createElement("div");
|
let new_response = document.createElement("div");
|
||||||
|
@ -130,7 +130,7 @@
|
||||||
}
|
}
|
||||||
|
|
||||||
window.onload = function () {
|
window.onload = function () {
|
||||||
fetch('/api/chat/init?client=web')
|
fetch('/api/chat/history?client=web')
|
||||||
.then(response => response.json())
|
.then(response => response.json())
|
||||||
.then(data => {
|
.then(data => {
|
||||||
if (data.detail) {
|
if (data.detail) {
|
||||||
|
|
|
@ -4,9 +4,8 @@ import math
|
||||||
import time
|
import time
|
||||||
import yaml
|
import yaml
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
import json
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
from fastapi import APIRouter, HTTPException, Header, Request
|
from fastapi import APIRouter, HTTPException, Header, Request
|
||||||
|
@ -14,8 +13,6 @@ from sentence_transformers import util
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.configure import configure_processor, configure_search
|
from khoj.configure import configure_processor, configure_search
|
||||||
from khoj.processor.conversation.gpt import converse, extract_questions
|
|
||||||
from khoj.processor.conversation.utils import message_to_log, reciprocal_conversation_to_chatml
|
|
||||||
from khoj.search_type import image_search, text_search
|
from khoj.search_type import image_search, text_search
|
||||||
from khoj.search_filter.date_filter import DateFilter
|
from khoj.search_filter.date_filter import DateFilter
|
||||||
from khoj.search_filter.file_filter import FileFilter
|
from khoj.search_filter.file_filter import FileFilter
|
||||||
|
@ -35,7 +32,11 @@ from khoj.utils.rawconfig import (
|
||||||
from khoj.utils.state import SearchType
|
from khoj.utils.state import SearchType
|
||||||
from khoj.utils import state, constants
|
from khoj.utils import state, constants
|
||||||
from khoj.utils.yaml import save_config_to_file_updated_state
|
from khoj.utils.yaml import save_config_to_file_updated_state
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse, Response
|
||||||
|
from khoj.routers.helpers import perform_chat_checks, generate_chat_response
|
||||||
|
from khoj.processor.conversation.gpt import extract_questions
|
||||||
|
from fastapi.requests import Request
|
||||||
|
|
||||||
|
|
||||||
# Initialize Router
|
# Initialize Router
|
||||||
api = APIRouter()
|
api = APIRouter()
|
||||||
|
@ -408,22 +409,15 @@ def update(
|
||||||
return {"status": "ok", "message": "khoj reloaded"}
|
return {"status": "ok", "message": "khoj reloaded"}
|
||||||
|
|
||||||
|
|
||||||
@api.get("/chat/init")
|
@api.get("/chat/history")
|
||||||
def chat_init(
|
def chat_history(
|
||||||
request: Request,
|
request: Request,
|
||||||
client: Optional[str] = None,
|
client: Optional[str] = None,
|
||||||
user_agent: Optional[str] = Header(None),
|
user_agent: Optional[str] = Header(None),
|
||||||
referer: Optional[str] = Header(None),
|
referer: Optional[str] = Header(None),
|
||||||
host: Optional[str] = Header(None),
|
host: Optional[str] = Header(None),
|
||||||
):
|
):
|
||||||
if (
|
perform_chat_checks()
|
||||||
state.processor_config is None
|
|
||||||
or state.processor_config.conversation is None
|
|
||||||
or state.processor_config.conversation.openai_api_key is None
|
|
||||||
):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=500, detail="Set your OpenAI API key via Khoj settings and restart it to use Khoj Chat."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Load Conversation History
|
# Load Conversation History
|
||||||
meta_log = state.processor_config.conversation.meta_log
|
meta_log = state.processor_config.conversation.meta_log
|
||||||
|
@ -444,53 +438,71 @@ def chat_init(
|
||||||
return {"status": "ok", "response": meta_log.get("chat", [])}
|
return {"status": "ok", "response": meta_log.get("chat", [])}
|
||||||
|
|
||||||
|
|
||||||
@api.get("/chat", response_class=StreamingResponse)
|
@api.get("/chat", response_class=Response)
|
||||||
async def chat(
|
async def chat(
|
||||||
request: Request,
|
request: Request,
|
||||||
q: Optional[str] = None,
|
q: str,
|
||||||
n: Optional[int] = 5,
|
n: Optional[int] = 5,
|
||||||
client: Optional[str] = None,
|
client: Optional[str] = None,
|
||||||
|
stream: Optional[bool] = False,
|
||||||
user_agent: Optional[str] = Header(None),
|
user_agent: Optional[str] = Header(None),
|
||||||
referer: Optional[str] = Header(None),
|
referer: Optional[str] = Header(None),
|
||||||
host: Optional[str] = Header(None),
|
host: Optional[str] = Header(None),
|
||||||
) -> StreamingResponse:
|
) -> Response:
|
||||||
def _save_to_conversation_log(
|
perform_chat_checks()
|
||||||
q: str,
|
compiled_references, inferred_queries = await extract_references_and_questions(request, q, (n or 5))
|
||||||
gpt_response: str,
|
|
||||||
user_message_time: str,
|
|
||||||
compiled_references: List[str],
|
|
||||||
inferred_queries: List[str],
|
|
||||||
meta_log,
|
|
||||||
):
|
|
||||||
state.processor_config.conversation.chat_session += reciprocal_conversation_to_chatml([q, gpt_response])
|
|
||||||
state.processor_config.conversation.meta_log["chat"] = message_to_log(
|
|
||||||
q,
|
|
||||||
gpt_response,
|
|
||||||
user_message_metadata={"created": user_message_time},
|
|
||||||
khoj_message_metadata={"context": compiled_references, "intent": {"inferred-queries": inferred_queries}},
|
|
||||||
conversation_log=meta_log.get("chat", []),
|
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
# Get the (streamed) chat response from GPT.
|
||||||
state.processor_config is None
|
gpt_response = generate_chat_response(
|
||||||
or state.processor_config.conversation is None
|
q,
|
||||||
or state.processor_config.conversation.openai_api_key is None
|
meta_log=state.processor_config.conversation.meta_log,
|
||||||
):
|
compiled_references=compiled_references,
|
||||||
raise HTTPException(
|
inferred_queries=inferred_queries,
|
||||||
status_code=500, detail="Set your OpenAI API key via Khoj settings and restart it to use Khoj Chat."
|
)
|
||||||
)
|
if gpt_response is None:
|
||||||
|
return Response(content=gpt_response, media_type="text/plain", status_code=500)
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
return StreamingResponse(gpt_response, media_type="text/event-stream", status_code=200)
|
||||||
|
|
||||||
|
# Get the full response from the generator if the stream is not requested.
|
||||||
|
aggregated_gpt_response = ""
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
aggregated_gpt_response += next(gpt_response)
|
||||||
|
except StopIteration:
|
||||||
|
break
|
||||||
|
|
||||||
|
actual_response = aggregated_gpt_response.split("### compiled references:")[0]
|
||||||
|
|
||||||
|
response_obj = {"response": actual_response, "context": compiled_references}
|
||||||
|
|
||||||
|
user_state = {
|
||||||
|
"client_host": request.client.host if request.client else None,
|
||||||
|
"user_agent": user_agent or "unknown",
|
||||||
|
"referer": referer or "unknown",
|
||||||
|
"host": host or "unknown",
|
||||||
|
}
|
||||||
|
|
||||||
|
state.telemetry += [
|
||||||
|
log_telemetry(
|
||||||
|
telemetry_type="api", api="chat", client=client, app_config=state.config.app, properties=user_state
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
return Response(content=json.dumps(response_obj), media_type="application/json", status_code=200)
|
||||||
|
|
||||||
|
|
||||||
|
async def extract_references_and_questions(
|
||||||
|
request: Request,
|
||||||
|
q: str,
|
||||||
|
n: int,
|
||||||
|
):
|
||||||
# Load Conversation History
|
# Load Conversation History
|
||||||
meta_log = state.processor_config.conversation.meta_log
|
meta_log = state.processor_config.conversation.meta_log
|
||||||
|
|
||||||
# If user query is empty, return nothing
|
|
||||||
if not q:
|
|
||||||
return StreamingResponse(None)
|
|
||||||
|
|
||||||
# Initialize Variables
|
# Initialize Variables
|
||||||
api_key = state.processor_config.conversation.openai_api_key
|
api_key = state.processor_config.conversation.openai_api_key
|
||||||
chat_model = state.processor_config.conversation.chat_model
|
|
||||||
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
|
||||||
conversation_type = "general" if q.startswith("@general") else "notes"
|
conversation_type = "general" if q.startswith("@general") else "notes"
|
||||||
compiled_references = []
|
compiled_references = []
|
||||||
inferred_queries = []
|
inferred_queries = []
|
||||||
|
@ -509,39 +521,4 @@ async def chat(
|
||||||
)
|
)
|
||||||
compiled_references = [item.additional["compiled"] for item in result_list]
|
compiled_references = [item.additional["compiled"] for item in result_list]
|
||||||
|
|
||||||
# Switch to general conversation type if no relevant notes found for the given query
|
return compiled_references, inferred_queries
|
||||||
conversation_type = "notes" if compiled_references else "general"
|
|
||||||
logger.debug(f"Conversation Type: {conversation_type}")
|
|
||||||
|
|
||||||
user_state = {
|
|
||||||
"client_host": request.client.host if request.client else None,
|
|
||||||
"user_agent": user_agent or "unknown",
|
|
||||||
"referer": referer or "unknown",
|
|
||||||
"host": host or "unknown",
|
|
||||||
}
|
|
||||||
|
|
||||||
state.telemetry += [
|
|
||||||
log_telemetry(
|
|
||||||
telemetry_type="api", api="chat", client=client, app_config=state.config.app, properties=user_state
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
try:
|
|
||||||
with timer("Generating chat response took", logger):
|
|
||||||
partial_completion = partial(
|
|
||||||
_save_to_conversation_log,
|
|
||||||
q,
|
|
||||||
user_message_time=user_message_time,
|
|
||||||
compiled_references=compiled_references,
|
|
||||||
inferred_queries=inferred_queries,
|
|
||||||
meta_log=meta_log,
|
|
||||||
)
|
|
||||||
|
|
||||||
gpt_response = converse(
|
|
||||||
compiled_references, q, meta_log, model=chat_model, api_key=api_key, completion_func=partial_completion
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
gpt_response = str(e)
|
|
||||||
|
|
||||||
return StreamingResponse(gpt_response, media_type="text/event-stream", status_code=200)
|
|
||||||
|
|
82
src/khoj/routers/helpers.py
Normal file
82
src/khoj/routers/helpers.py
Normal file
|
@ -0,0 +1,82 @@
|
||||||
|
from fastapi import HTTPException
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from functools import partial
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from khoj.utils import state
|
||||||
|
from khoj.utils.helpers import timer
|
||||||
|
from khoj.processor.conversation.gpt import converse
|
||||||
|
from khoj.processor.conversation.utils import message_to_log, reciprocal_conversation_to_chatml
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def perform_chat_checks():
|
||||||
|
if (
|
||||||
|
state.processor_config is None
|
||||||
|
or state.processor_config.conversation is None
|
||||||
|
or state.processor_config.conversation.openai_api_key is None
|
||||||
|
):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500, detail="Set your OpenAI API key via Khoj settings and restart it to use Khoj Chat."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_chat_response(
|
||||||
|
q: str,
|
||||||
|
meta_log: dict,
|
||||||
|
compiled_references: List[str] = [],
|
||||||
|
inferred_queries: List[str] = [],
|
||||||
|
):
|
||||||
|
def _save_to_conversation_log(
|
||||||
|
q: str,
|
||||||
|
gpt_response: str,
|
||||||
|
user_message_time: str,
|
||||||
|
compiled_references: List[str],
|
||||||
|
inferred_queries: List[str],
|
||||||
|
meta_log,
|
||||||
|
):
|
||||||
|
state.processor_config.conversation.chat_session += reciprocal_conversation_to_chatml([q, gpt_response])
|
||||||
|
state.processor_config.conversation.meta_log["chat"] = message_to_log(
|
||||||
|
q,
|
||||||
|
gpt_response,
|
||||||
|
user_message_metadata={"created": user_message_time},
|
||||||
|
khoj_message_metadata={"context": compiled_references, "intent": {"inferred-queries": inferred_queries}},
|
||||||
|
conversation_log=meta_log.get("chat", []),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load Conversation History
|
||||||
|
meta_log = state.processor_config.conversation.meta_log
|
||||||
|
|
||||||
|
# Initialize Variables
|
||||||
|
api_key = state.processor_config.conversation.openai_api_key
|
||||||
|
chat_model = state.processor_config.conversation.chat_model
|
||||||
|
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
conversation_type = "general" if q.startswith("@general") else "notes"
|
||||||
|
|
||||||
|
# Switch to general conversation type if no relevant notes found for the given query
|
||||||
|
conversation_type = "notes" if compiled_references else "general"
|
||||||
|
logger.debug(f"Conversation Type: {conversation_type}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
with timer("Generating chat response took", logger):
|
||||||
|
partial_completion = partial(
|
||||||
|
_save_to_conversation_log,
|
||||||
|
q,
|
||||||
|
user_message_time=user_message_time,
|
||||||
|
compiled_references=compiled_references,
|
||||||
|
inferred_queries=inferred_queries,
|
||||||
|
meta_log=meta_log,
|
||||||
|
)
|
||||||
|
|
||||||
|
gpt_response = converse(
|
||||||
|
compiled_references, q, meta_log, model=chat_model, api_key=api_key, completion_func=partial_completion
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(e)
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
return gpt_response
|
|
@ -174,6 +174,8 @@ def test_generate_search_query_with_date_and_context_from_chat_history():
|
||||||
expected_responses = [
|
expected_responses = [
|
||||||
("dt>='2000-04-01'", "dt<'2000-05-01'"),
|
("dt>='2000-04-01'", "dt<'2000-05-01'"),
|
||||||
("dt>='2000-04-01'", "dt<='2000-04-30'"),
|
("dt>='2000-04-01'", "dt<='2000-04-30'"),
|
||||||
|
('dt>="2000-04-01"', 'dt<"2000-05-01"'),
|
||||||
|
('dt>="2000-04-01"', 'dt<="2000-04-30"'),
|
||||||
]
|
]
|
||||||
assert len(response) == 1
|
assert len(response) == 1
|
||||||
assert "Masai Mara" in response[0]
|
assert "Masai Mara" in response[0]
|
||||||
|
|
|
@ -40,7 +40,7 @@ def populate_chat_history(message_list):
|
||||||
@pytest.mark.chatquality
|
@pytest.mark.chatquality
|
||||||
def test_chat_with_no_chat_history_or_retrieved_content(chat_client):
|
def test_chat_with_no_chat_history_or_retrieved_content(chat_client):
|
||||||
# Act
|
# Act
|
||||||
response = chat_client.get(f'/api/chat?q="Hello, my name is Testatron. Who are you?"')
|
response = chat_client.get(f'/api/chat?q="Hello, my name is Testatron. Who are you?"&stream=true')
|
||||||
response_message = response.content.decode("utf-8")
|
response_message = response.content.decode("utf-8")
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
|
@ -62,7 +62,7 @@ def test_answer_from_chat_history(chat_client):
|
||||||
populate_chat_history(message_list)
|
populate_chat_history(message_list)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
response = chat_client.get(f'/api/chat?q="What is my name?"')
|
response = chat_client.get(f'/api/chat?q="What is my name?"&stream=true')
|
||||||
response_message = response.content.decode("utf-8")
|
response_message = response.content.decode("utf-8")
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
|
@ -157,7 +157,7 @@ def test_no_answer_in_chat_history_or_retrieved_content(chat_client):
|
||||||
populate_chat_history(message_list)
|
populate_chat_history(message_list)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
response = chat_client.get(f'/api/chat?q="Where was I born?"')
|
response = chat_client.get(f'/api/chat?q="Where was I born?"&stream=true')
|
||||||
response_message = response.content.decode("utf-8")
|
response_message = response.content.decode("utf-8")
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
|
@ -175,7 +175,7 @@ def test_no_answer_in_chat_history_or_retrieved_content(chat_client):
|
||||||
def test_answer_requires_current_date_awareness(chat_client):
|
def test_answer_requires_current_date_awareness(chat_client):
|
||||||
"Chat actor should be able to answer questions relative to current date using provided notes"
|
"Chat actor should be able to answer questions relative to current date using provided notes"
|
||||||
# Act
|
# Act
|
||||||
response = chat_client.get(f'/api/chat?q="Where did I have lunch today?"')
|
response = chat_client.get(f'/api/chat?q="Where did I have lunch today?"&stream=true')
|
||||||
response_message = response.content.decode("utf-8")
|
response_message = response.content.decode("utf-8")
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
|
@ -192,7 +192,7 @@ def test_answer_requires_current_date_awareness(chat_client):
|
||||||
def test_answer_requires_date_aware_aggregation_across_provided_notes(chat_client):
|
def test_answer_requires_date_aware_aggregation_across_provided_notes(chat_client):
|
||||||
"Chat director should be able to answer questions that require date aware aggregation across multiple notes"
|
"Chat director should be able to answer questions that require date aware aggregation across multiple notes"
|
||||||
# Act
|
# Act
|
||||||
response = chat_client.get(f'/api/chat?q="How much did I spend on dining this year?"')
|
response = chat_client.get(f'/api/chat?q="How much did I spend on dining this year?"&stream=true')
|
||||||
response_message = response.content.decode("utf-8")
|
response_message = response.content.decode("utf-8")
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
|
@ -212,7 +212,9 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c
|
||||||
populate_chat_history(message_list)
|
populate_chat_history(message_list)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
response = chat_client.get(f'/api/chat?q=""Write a haiku about unit testing. Do not say anything else."')
|
response = chat_client.get(
|
||||||
|
f'/api/chat?q=""Write a haiku about unit testing. Do not say anything else."&stream=true'
|
||||||
|
)
|
||||||
response_message = response.content.decode("utf-8")
|
response_message = response.content.decode("utf-8")
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
|
@ -229,7 +231,7 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c
|
||||||
@pytest.mark.chatquality
|
@pytest.mark.chatquality
|
||||||
def test_ask_for_clarification_if_not_enough_context_in_question(chat_client):
|
def test_ask_for_clarification_if_not_enough_context_in_question(chat_client):
|
||||||
# Act
|
# Act
|
||||||
response = chat_client.get(f'/api/chat?q="What is the name of Namitas older son"')
|
response = chat_client.get(f'/api/chat?q="What is the name of Namitas older son"&stream=true')
|
||||||
response_message = response.content.decode("utf-8")
|
response_message = response.content.decode("utf-8")
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
|
@ -258,7 +260,7 @@ def test_answer_in_chat_history_beyond_lookback_window(chat_client):
|
||||||
populate_chat_history(message_list)
|
populate_chat_history(message_list)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
response = chat_client.get(f'/api/chat?q="What is my name?"')
|
response = chat_client.get(f'/api/chat?q="What is my name?"&stream=true')
|
||||||
response_message = response.content.decode("utf-8")
|
response_message = response.content.decode("utf-8")
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
|
@ -274,11 +276,11 @@ def test_answer_in_chat_history_beyond_lookback_window(chat_client):
|
||||||
def test_answer_requires_multiple_independent_searches(chat_client):
|
def test_answer_requires_multiple_independent_searches(chat_client):
|
||||||
"Chat director should be able to answer by doing multiple independent searches for required information"
|
"Chat director should be able to answer by doing multiple independent searches for required information"
|
||||||
# Act
|
# Act
|
||||||
response = chat_client.get(f'/api/chat?q="Is Xi older than Namita?"')
|
response = chat_client.get(f'/api/chat?q="Is Xi older than Namita?"&stream=true')
|
||||||
response_message = response.content.decode("utf-8")
|
response_message = response.content.decode("utf-8")
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
expected_responses = ["he is older than namita", "xi is older than namita"]
|
expected_responses = ["he is older than namita", "xi is older than namita", "xi li is older than namita"]
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
|
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
|
||||||
"Expected Xi is older than Namita, but got: " + response_message
|
"Expected Xi is older than Namita, but got: " + response_message
|
||||||
|
|
Loading…
Reference in a new issue