Mirror of https://github.com/khoj-ai/khoj.git (synced 2024-11-27 17:35:07 +01:00)
Initial code with chat streaming working (warning: messy code)
commit 8f491d72de
parent 89354def9b
4 changed files with 163 additions and 33 deletions
Web chat interface (JavaScript):

@@ -64,14 +64,47 @@
     // Generate backend API URL to execute query
     let url = `/api/chat?q=${encodeURIComponent(query)}&client=web`;
 
-    // Call specified Khoj API
+    let chat_body = document.getElementById("chat-body");
+    let new_response = document.createElement("div");
+    new_response.classList.add("chat-message", "khoj");
+    new_response.attributes["data-meta"] = "🏮 Khoj at " + formatDate(new Date());
+    chat_body.appendChild(new_response);
+
+    let new_response_text = document.createElement("div");
+    new_response_text.classList.add("chat-message-text", "khoj");
+    new_response.appendChild(new_response_text);
+
+    // Call specified Khoj API which returns a streamed response of type text/plain
     fetch(url)
-        .then(response => response.json())
-        .then(data => {
-            // Render message by Khoj to chat body
-            console.log(data.response);
-            renderMessageWithReference(data.response, "khoj", data.context);
+        .then(response => {
+            console.log(response);
+            const reader = response.body.getReader();
+            const decoder = new TextDecoder();
+
+            function readStream() {
+                reader.read().then(({ done, value }) => {
+                    if (done) {
+                        console.log("Stream complete");
+                        return;
+                    }
+
+                    const chunk = decoder.decode(value, { stream: true });
+                    new_response_text.innerHTML += chunk;
+                    console.log(`Received ${chunk.length} bytes of data`);
+                    console.log(`Chunk: ${chunk}`);
+                    readStream();
+                });
+            }
+            readStream();
         });
+
+
+    // fetch(url)
+    //     .then(data => {
+    //         // Render message by Khoj to chat body
+    //         console.log(data.response);
+    //         renderMessageWithReference(data.response, "khoj", data.context);
+    //     });
 }
 
 function incrementalChat(event) {
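The getReader()/readStream() loop above can also be exercised outside the browser. Below is a minimal Python sketch that consumes the same streamed /api/chat response; the server address and the `requests` dependency are assumptions for illustration, not part of this commit.

```python
# Minimal consumer for the streamed /api/chat endpoint, mirroring the
# browser's getReader()/readStream() loop above.
import requests

url = "http://localhost:42110/api/chat"  # assumed local Khoj server address
params = {"q": "hello", "client": "web"}

with requests.get(url, params=params, stream=True) as response:
    response.raise_for_status()
    # iter_content with no fixed size yields chunks as they arrive,
    # analogous to decoder.decode(value, { stream: true }) above.
    for chunk in response.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)
```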
@@ -82,7 +115,7 @@
 }
 
 window.onload = function () {
-    fetch('/api/chat?client=web')
+    fetch('/api/chat/init?client=web')
         .then(response => response.json())
         .then(data => {
            if (data.detail) {
Chat conversation module (converse):

@@ -170,12 +170,18 @@ def converse(references, user_query, conversation_log={}, model="gpt-3.5-turbo",
 
     # Get Response from GPT
     logger.debug(f"Conversation Context for GPT: {messages}")
-    response = chat_completion_with_backoff(
+    return chat_completion_with_backoff(
         messages=messages,
         model_name=model,
         temperature=temperature,
         openai_api_key=api_key,
     )
-
-    # Extract, Clean Message from GPT's Response
-    return response.strip(empty_escape_sequences)
+    # async for tokens in chat_completion_with_backoff(
+    #     messages=messages,
+    #     model_name=model,
+    #     temperature=temperature,
+    #     openai_api_key=api_key,
+    # ):
+    #     logger.info(f"Tokens from GPT: {tokens}")
+    #     yield tokens
+
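The change to converse() swaps a blocking call that returned a cleaned string for a direct return of the generator produced by chat_completion_with_backoff(), so tokens reach the caller as they arrive. A toy sketch of that behavioral difference (the function names below are illustrative, not from the codebase):

```python
# Contrast between the old and new shapes of converse(): returning a final
# string blocks until completion; returning a generator lets the caller
# render tokens as they are produced.
def blocking_reply() -> str:
    return "the whole response, delivered only once it is complete"


def streaming_reply():
    for token in ["the ", "response, ", "one ", "token ", "at ", "a ", "time"]:
        yield token


print(blocking_reply())
for token in streaming_reply():
    print(token, end="", flush=True)
```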
Conversation utilities (GPT helpers):

@@ -2,11 +2,19 @@
 import os
 import logging
 from datetime import datetime
 from typing import Any, Optional
 from uuid import UUID
+import asyncio
+from threading import Thread
 
 # External Packages
+from langchain.chat_models import ChatOpenAI
+from langchain.llms import OpenAI
 from langchain.schema import ChatMessage
+from langchain.callbacks.base import BaseCallbackHandler
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+from langchain.callbacks import AsyncIteratorCallbackHandler
+from langchain.callbacks.base import BaseCallbackManager, AsyncCallbackHandler
 import openai
 import tiktoken
 from tenacity import (
@@ -20,12 +28,43 @@ from tenacity import (
 
 # Internal Packages
 from khoj.utils.helpers import merge_dicts
+import queue
 
 
 logger = logging.getLogger(__name__)
 max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192}
 
 
+class ThreadedGenerator:
+    def __init__(self):
+        self.queue = queue.Queue()
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        item = self.queue.get()
+        if item is StopIteration:
+            raise item
+        return item
+
+    def send(self, data):
+        self.queue.put(data)
+
+    def close(self):
+        self.queue.put(StopIteration)
+
+
+class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):
+    def __init__(self, gen: ThreadedGenerator):
+        super().__init__()
+        self.gen = gen
+
+    def on_llm_new_token(self, token: str, **kwargs) -> Any:
+        logger.debug(f"New Token: {token}")
+        self.gen.send(token)
+
+
 @retry(
     retry=(
         retry_if_exception_type(openai.error.Timeout)
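ThreadedGenerator is the bridge at the heart of this commit: LangChain's synchronous callback pushes tokens into a queue.Queue from a worker thread, while the consumer iterates on another thread, blocking on queue.get() until the next token or the StopIteration sentinel arrives. A self-contained demo of the same pattern (only the toy producer and consumer are invented here):

```python
import queue
from threading import Thread


class ThreadedGenerator:
    """Queue-backed iterator, as defined in the diff above."""

    def __init__(self):
        self.queue = queue.Queue()

    def __iter__(self):
        return self

    def __next__(self):
        item = self.queue.get()  # blocks until the producer sends or closes
        if item is StopIteration:
            raise item
        return item

    def send(self, data):
        self.queue.put(data)

    def close(self):
        self.queue.put(StopIteration)


def produce(g: ThreadedGenerator):
    # Stand-in for StreamingChatCallbackHandler.on_llm_new_token()
    for token in ["streamed ", "one ", "token ", "at ", "a ", "time"]:
        g.send(token)
    g.close()  # sentinel that terminates the consumer's iteration


g = ThreadedGenerator()
Thread(target=produce, args=(g,)).start()
for token in g:
    print(token, end="", flush=True)
```

Putting StopIteration itself on the queue doubles as the end-of-stream marker; raising it from __next__ is what cleanly ends the consumer's for loop.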
@@ -63,14 +102,28 @@ def completion_with_backoff(**kwargs):
     reraise=True,
 )
 def chat_completion_with_backoff(messages, model_name, temperature, openai_api_key=None):
+    g = ThreadedGenerator()
+    t = Thread(target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key))
+    t.start()
+    return g
+
+
+def llm_thread(g, messages, model_name, temperature, openai_api_key=None):
+    callback_handler = StreamingChatCallbackHandler(g)
     chat = ChatOpenAI(
+        streaming=True,
+        verbose=True,
+        callback_manager=BaseCallbackManager([callback_handler]),
         model_name=model_name,
         temperature=temperature,
         openai_api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
         request_timeout=20,
         max_retries=1,
     )
-    return chat(messages).content
+
+    chat(messages=messages)
+
+    g.close()
 
 
 def generate_chatml_messages_with_context(
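One hazard in llm_thread() as written: if the ChatOpenAI call raises (timeout, rate limit), g.close() never runs, and any consumer blocked in queue.get() waits forever. A defensive variant, a sketch under the same names as the diff rather than part of the commit, wraps the call in try/finally:

```python
def llm_thread(g, messages, model_name, temperature, openai_api_key=None):
    try:
        callback_handler = StreamingChatCallbackHandler(g)
        chat = ChatOpenAI(
            streaming=True,
            verbose=True,
            callback_manager=BaseCallbackManager([callback_handler]),
            model_name=model_name,
            temperature=temperature,
            openai_api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
            request_timeout=20,
            max_retries=1,
        )
        chat(messages=messages)
    finally:
        g.close()  # always unblock the consumer, even if the LLM call fails
```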
API router:

@@ -34,6 +34,7 @@ from khoj.utils.rawconfig import (
 from khoj.utils.state import SearchType
 from khoj.utils import state, constants
 from khoj.utils.yaml import save_config_to_file_updated_state
+from fastapi.responses import StreamingResponse
 
 # Initialize Router
 api = APIRouter()
@@ -393,8 +394,8 @@ def update(
     return {"status": "ok", "message": "khoj reloaded"}
 
 
-@api.get("/chat")
-async def chat(
+@api.get("/chat/init")
+def chat_init(
     request: Request,
     q: Optional[str] = None,
     client: Optional[str] = None,
@@ -411,13 +412,52 @@ async def chat(
             status_code=500, detail="Set your OpenAI API key via Khoj settings and restart it to use Khoj Chat."
         )
 
+    # Load Conversation History
+    meta_log = state.processor_config.conversation.meta_log
+
+    user_state = {
+        "client_host": request.client.host,
+        "user_agent": user_agent or "unknown",
+        "referer": referer or "unknown",
+        "host": host or "unknown",
+    }
+
+    state.telemetry += [
+        log_telemetry(
+            telemetry_type="api", api="chat", client=client, app_config=state.config.app, properties=user_state
+        )
+    ]
+
+    # If user query is empty, return chat history
+    if not q:
+        return {"status": "ok", "response": meta_log.get("chat", [])}
+
+
+@api.get("/chat", response_class=StreamingResponse)
+async def chat(
+    request: Request,
+    q: Optional[str] = None,
+    client: Optional[str] = None,
+    user_agent: Optional[str] = Header(None),
+    referer: Optional[str] = Header(None),
+    host: Optional[str] = Header(None),
+) -> StreamingResponse:
+    if (
+        state.processor_config is None
+        or state.processor_config.conversation is None
+        or state.processor_config.conversation.openai_api_key is None
+    ):
+        raise HTTPException(
+            status_code=500, detail="Set your OpenAI API key via Khoj settings and restart it to use Khoj Chat."
+        )
+
     # Load Conversation History
     chat_session = state.processor_config.conversation.chat_session
     meta_log = state.processor_config.conversation.meta_log
 
     # If user query is empty, return chat history
     if not q:
-        return {"status": "ok", "response": meta_log.get("chat", [])}
+        return StreamingResponse(None)
 
     # Initialize Variables
     api_key = state.processor_config.conversation.openai_api_key
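The new /chat endpoint leans on the fact that StreamingResponse accepts any iterator, so the ThreadedGenerator returned by converse() can be handed straight to FastAPI. A minimal, self-contained sketch of that pattern (the toy app and token generator below are illustrative, not part of the commit):

```python
import time

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


def fake_gpt_response():
    # Stand-in for the ThreadedGenerator returned by converse()
    for token in ["Hello ", "from ", "a ", "streamed ", "response."]:
        time.sleep(0.2)  # simulate waiting on the next LLM token
        yield token


@app.get("/chat")
async def chat() -> StreamingResponse:
    return StreamingResponse(fake_gpt_response(), media_type="text/event-stream")
```

Worth flagging: the empty-query branch above returns StreamingResponse(None), and None is not iterable, so that line reads as a placeholder covered by the commit message's "messy code" warning.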
@@ -446,24 +486,6 @@ async def chat(
     conversation_type = "notes" if compiled_references else "general"
     logger.debug(f"Conversation Type: {conversation_type}")
 
-    try:
-        with timer("Generating chat response took", logger):
-            gpt_response = converse(compiled_references, q, meta_log, model=chat_model, api_key=api_key)
-        status = "ok"
-    except Exception as e:
-        gpt_response = str(e)
-        status = "error"
-
-    # Update Conversation History
-    state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response)
-    state.processor_config.conversation.meta_log["chat"] = message_to_log(
-        q,
-        gpt_response,
-        user_message_metadata={"created": user_message_time},
-        khoj_message_metadata={"context": compiled_references, "intent": {"inferred-queries": inferred_queries}},
-        conversation_log=meta_log.get("chat", []),
-    )
-
     user_state = {
         "client_host": request.client.host,
         "user_agent": user_agent or "unknown",
@@ -477,4 +499,20 @@ async def chat(
         )
     ]
 
-    return {"status": status, "response": gpt_response, "context": compiled_references}
+    try:
+        with timer("Generating chat response took", logger):
+            gpt_response = converse(compiled_references, q, meta_log, model=chat_model, api_key=api_key)
+    except Exception as e:
+        gpt_response = str(e)
+
+    # Update Conversation History
+    # state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response)
+    # state.processor_config.conversation.meta_log["chat"] = message_to_log(
+    #     q,
+    #     gpt_response,
+    #     user_message_metadata={"created": user_message_time},
+    #     khoj_message_metadata={"context": compiled_references, "intent": {"inferred-queries": inferred_queries}},
+    #     conversation_log=meta_log.get("chat", []),
+    # )
+
+    return StreamingResponse(gpt_response, media_type="text/event-stream", status_code=200)
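A final observation on this return (an inference from Starlette's behavior, not something the diff addresses): StreamingResponse simply iterates its content. The ThreadedGenerator yields whole tokens, but in the exception path gpt_response is a plain string, and iterating a string yields it character by character:

```python
def tokens():
    yield "streamed "
    yield "tokens"


print(list(tokens()))       # ['streamed ', 'tokens']   -- chunked per token
print(list(iter("error")))  # ['e', 'r', 'r', 'o', 'r'] -- chunked per character
```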