Initial code with chat streaming working (warning: messy code)

This commit is contained in:
sabaimran 2023-07-04 10:14:39 -07:00
parent 89354def9b
commit 8f491d72de
4 changed files with 163 additions and 33 deletions

View file

@ -64,14 +64,47 @@
// Generate backend API URL to execute query
let url = `/api/chat?q=${encodeURIComponent(query)}&client=web`;
// Call specified Khoj API
let chat_body = document.getElementById("chat-body");
let new_response = document.createElement("div");
new_response.classList.add("chat-message", "khoj");
new_response.attributes["data-meta"] = "🏮 Khoj at " + formatDate(new Date());
chat_body.appendChild(new_response);
let new_response_text = document.createElement("div");
new_response_text.classList.add("chat-message-text", "khoj");
new_response.appendChild(new_response_text);
// Call specified Khoj API which returns a streamed response of type text/plain
fetch(url)
.then(response => response.json())
.then(data => {
// Render message by Khoj to chat body
console.log(data.response);
renderMessageWithReference(data.response, "khoj", data.context);
.then(response => {
console.log(response);
const reader = response.body.getReader();
const decoder = new TextDecoder();
function readStream() {
reader.read().then(({ done, value }) => {
if (done) {
console.log("Stream complete");
return;
}
const chunk = decoder.decode(value, { stream: true });
new_response_text.innerHTML += chunk;
console.log(`Received ${chunk.length} bytes of data`);
console.log(`Chunk: ${chunk}`);
readStream();
});
}
readStream();
});
// fetch(url)
// .then(data => {
// // Render message by Khoj to chat body
// console.log(data.response);
// renderMessageWithReference(data.response, "khoj", data.context);
// });
}
function incrementalChat(event) {
@ -82,7 +115,7 @@
}
window.onload = function () {
fetch('/api/chat?client=web')
fetch('/api/chat/init?client=web')
.then(response => response.json())
.then(data => {
if (data.detail) {

View file

@ -170,12 +170,18 @@ def converse(references, user_query, conversation_log={}, model="gpt-3.5-turbo",
# Get Response from GPT
logger.debug(f"Conversation Context for GPT: {messages}")
response = chat_completion_with_backoff(
return chat_completion_with_backoff(
messages=messages,
model_name=model,
temperature=temperature,
openai_api_key=api_key,
)
# Extract, Clean Message from GPT's Response
return response.strip(empty_escape_sequences)
# async for tokens in chat_completion_with_backoff(
# messages=messages,
# model_name=model,
# temperature=temperature,
# openai_api_key=api_key,
# ):
# logger.info(f"Tokens from GPT: {tokens}")
# yield tokens

View file

@ -2,11 +2,19 @@
import os
import logging
from datetime import datetime
from typing import Any, Optional
from uuid import UUID
import asyncio
from threading import Thread
# External Packages
from langchain.chat_models import ChatOpenAI
from langchain.llms import OpenAI
from langchain.schema import ChatMessage
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.callbacks.base import BaseCallbackManager, AsyncCallbackHandler
import openai
import tiktoken
from tenacity import (
@ -20,12 +28,43 @@ from tenacity import (
# Internal Packages
from khoj.utils.helpers import merge_dicts
import queue
logger = logging.getLogger(__name__)
max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192}
class ThreadedGenerator:
def __init__(self):
self.queue = queue.Queue()
def __iter__(self):
return self
def __next__(self):
item = self.queue.get()
if item is StopIteration:
raise item
return item
def send(self, data):
self.queue.put(data)
def close(self):
self.queue.put(StopIteration)
class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):
def __init__(self, gen: ThreadedGenerator):
super().__init__()
self.gen = gen
def on_llm_new_token(self, token: str, **kwargs) -> Any:
logger.debug(f"New Token: {token}")
self.gen.send(token)
@retry(
retry=(
retry_if_exception_type(openai.error.Timeout)
@ -63,14 +102,28 @@ def completion_with_backoff(**kwargs):
reraise=True,
)
def chat_completion_with_backoff(messages, model_name, temperature, openai_api_key=None):
g = ThreadedGenerator()
t = Thread(target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key))
t.start()
return g
def llm_thread(g, messages, model_name, temperature, openai_api_key=None):
callback_handler = StreamingChatCallbackHandler(g)
chat = ChatOpenAI(
streaming=True,
verbose=True,
callback_manager=BaseCallbackManager([callback_handler]),
model_name=model_name,
temperature=temperature,
openai_api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
request_timeout=20,
max_retries=1,
)
return chat(messages).content
chat(messages=messages)
g.close()
def generate_chatml_messages_with_context(

View file

@ -34,6 +34,7 @@ from khoj.utils.rawconfig import (
from khoj.utils.state import SearchType
from khoj.utils import state, constants
from khoj.utils.yaml import save_config_to_file_updated_state
from fastapi.responses import StreamingResponse
# Initialize Router
api = APIRouter()
@ -393,8 +394,8 @@ def update(
return {"status": "ok", "message": "khoj reloaded"}
@api.get("/chat")
async def chat(
@api.get("/chat/init")
def chat_init(
request: Request,
q: Optional[str] = None,
client: Optional[str] = None,
@ -411,13 +412,52 @@ async def chat(
status_code=500, detail="Set your OpenAI API key via Khoj settings and restart it to use Khoj Chat."
)
# Load Conversation History
meta_log = state.processor_config.conversation.meta_log
user_state = {
"client_host": request.client.host,
"user_agent": user_agent or "unknown",
"referer": referer or "unknown",
"host": host or "unknown",
}
state.telemetry += [
log_telemetry(
telemetry_type="api", api="chat", client=client, app_config=state.config.app, properties=user_state
)
]
# If user query is empty, return chat history
if not q:
return {"status": "ok", "response": meta_log.get("chat", [])}
@api.get("/chat", response_class=StreamingResponse)
async def chat(
request: Request,
q: Optional[str] = None,
client: Optional[str] = None,
user_agent: Optional[str] = Header(None),
referer: Optional[str] = Header(None),
host: Optional[str] = Header(None),
) -> StreamingResponse:
if (
state.processor_config is None
or state.processor_config.conversation is None
or state.processor_config.conversation.openai_api_key is None
):
raise HTTPException(
status_code=500, detail="Set your OpenAI API key via Khoj settings and restart it to use Khoj Chat."
)
# Load Conversation History
chat_session = state.processor_config.conversation.chat_session
meta_log = state.processor_config.conversation.meta_log
# If user query is empty, return chat history
if not q:
return {"status": "ok", "response": meta_log.get("chat", [])}
return StreamingResponse(None)
# Initialize Variables
api_key = state.processor_config.conversation.openai_api_key
@ -446,24 +486,6 @@ async def chat(
conversation_type = "notes" if compiled_references else "general"
logger.debug(f"Conversation Type: {conversation_type}")
try:
with timer("Generating chat response took", logger):
gpt_response = converse(compiled_references, q, meta_log, model=chat_model, api_key=api_key)
status = "ok"
except Exception as e:
gpt_response = str(e)
status = "error"
# Update Conversation History
state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response)
state.processor_config.conversation.meta_log["chat"] = message_to_log(
q,
gpt_response,
user_message_metadata={"created": user_message_time},
khoj_message_metadata={"context": compiled_references, "intent": {"inferred-queries": inferred_queries}},
conversation_log=meta_log.get("chat", []),
)
user_state = {
"client_host": request.client.host,
"user_agent": user_agent or "unknown",
@ -477,4 +499,20 @@ async def chat(
)
]
return {"status": status, "response": gpt_response, "context": compiled_references}
try:
with timer("Generating chat response took", logger):
gpt_response = converse(compiled_references, q, meta_log, model=chat_model, api_key=api_key)
except Exception as e:
gpt_response = str(e)
# Update Conversation History
# state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response)
# state.processor_config.conversation.meta_log["chat"] = message_to_log(
# q,
# gpt_response,
# user_message_metadata={"created": user_message_time},
# khoj_message_metadata={"context": compiled_references, "intent": {"inferred-queries": inferred_queries}},
# conversation_log=meta_log.get("chat", []),
# )
return StreamingResponse(gpt_response, media_type="text/event-stream", status_code=200)