Save streamed chat conversations via a partial function passed to the ThreadedGenerator

sabaimran 2023-07-04 17:33:52 -07:00
parent afd162de01
commit 79b1b1d350
3 changed files with 50 additions and 21 deletions
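
In short: `ThreadedGenerator` now accumulates every streamed token in `self.response` and, once the stream is exhausted, invokes an optional `completion_func` with the full response; the chat endpoint builds that callback with `functools.partial`, pre-binding everything a conversation-log entry needs except the response text. Below is a minimal, self-contained sketch of the pattern — the class body mirrors the diff (minus `compiled_references`), while `save_conversation` and `produce` are hypothetical stand-ins for khoj's save function and LLM thread:

import queue
from functools import partial
from threading import Thread


class ThreadedGenerator:
    def __init__(self, completion_func=None):
        self.queue = queue.Queue()
        self.completion_func = completion_func
        self.response = ""  # accumulates the streamed tokens

    def __iter__(self):
        return self

    def __next__(self):
        item = self.queue.get()
        if item is StopIteration:
            if self.completion_func:
                # Stream is done: hand the full response to the callback
                self.completion_func(gpt_response=self.response)
            raise StopIteration
        return item

    def send(self, data):
        self.response += data
        self.queue.put(data)

    def close(self):
        self.queue.put(StopIteration)


def save_conversation(gpt_response, user_query):
    print(f"\nsaved: {user_query!r} -> {gpt_response!r}")


g = ThreadedGenerator(completion_func=partial(save_conversation, user_query="hi"))


def produce():
    for token in ["Hello", ", ", "world"]:
        g.send(token)
    g.close()


Thread(target=produce).start()
for token in g:           # consumer side: streams tokens as they arrive
    print(token, end="")  # save_conversation runs after the last token

Because the callback fires inside `__next__`, the save happens on the consuming side, at the exact moment the final token has been handed out — no second pass over the response is needed.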

View file

@@ -144,7 +144,15 @@ def extract_search_type(text, model, api_key=None, temperature=0.5, max_tokens=1
     return json.loads(response.strip(empty_escape_sequences))
 
 
-def converse(references, user_query, conversation_log={}, model="gpt-3.5-turbo", api_key=None, temperature=0.2):
+def converse(
+    references,
+    user_query,
+    conversation_log={},
+    model="gpt-3.5-turbo",
+    api_key=None,
+    temperature=0.2,
+    completion_func=None,
+):
     """
     Converse with user using OpenAI's ChatGPT
     """
@@ -176,6 +184,7 @@ def converse(references, user_query, conversation_log={}, model="gpt-3.5-turbo",
         model_name=model,
         temperature=temperature,
         openai_api_key=api_key,
+        completion_func=completion_func,
     )
     # async for tokens in chat_completion_with_backoff(
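
In this first file, `converse` only gains a pass-through parameter; the callback travels down to `chat_completion_with_backoff` (second file below). A hypothetical call site — the reference text, query, API key, and lambda are illustrative placeholders, not code from this commit:

response_stream = converse(
    references=["Note to self: I moved to Berlin in 2021."],
    user_query="Where do I live?",
    api_key="sk-...",
    completion_func=lambda gpt_response: print(f"final response: {gpt_response}"),
)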

View file

@@ -37,9 +37,11 @@ max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192}
 
 class ThreadedGenerator:
-    def __init__(self, compiled_references):
+    def __init__(self, compiled_references, completion_func=None):
         self.queue = queue.Queue()
         self.compiled_references = compiled_references
+        self.completion_func = completion_func
+        self.response = ""
 
     def __iter__(self):
         return self
@@ -47,10 +49,13 @@ class ThreadedGenerator:
     def __next__(self):
         item = self.queue.get()
         if item is StopIteration:
+            if self.completion_func:
+                self.completion_func(gpt_response=self.response)
             raise StopIteration
         return item
 
     def send(self, data):
+        self.response += data
         self.queue.put(data)
 
     def close(self):
@@ -65,7 +70,6 @@ class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):
         self.gen = gen
 
     def on_llm_new_token(self, token: str, **kwargs) -> Any:
-        logger.debug(f"New Token: {token}")
         self.gen.send(token)
@@ -105,8 +109,10 @@ def completion_with_backoff(**kwargs):
     before_sleep=before_sleep_log(logger, logging.DEBUG),
     reraise=True,
 )
-def chat_completion_with_backoff(messages, compiled_references, model_name, temperature, openai_api_key=None):
-    g = ThreadedGenerator(compiled_references)
+def chat_completion_with_backoff(
+    messages, compiled_references, model_name, temperature, openai_api_key=None, completion_func=None
+):
+    g = ThreadedGenerator(compiled_references, completion_func=completion_func)
     t = Thread(target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key))
     t.start()
     return g
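
`llm_thread`, the producer feeding the generator, is referenced here but not shown in this commit. A rough sketch of its presumed shape follows — the `ChatOpenAI(callbacks=...)` style and the message handling are assumptions from langchain's API of that era, not code from this diff:

# Assumed shape of llm_thread -- it is referenced but not shown in this commit.
from langchain.chat_models import ChatOpenAI


def llm_thread(g, messages, model_name, temperature, openai_api_key=None):
    try:
        chat = ChatOpenAI(
            streaming=True,
            callbacks=[StreamingChatCallbackHandler(g)],  # each token -> g.send(token)
            model_name=model_name,
            temperature=temperature,
            openai_api_key=openai_api_key,
        )
        chat(messages)  # messages: langchain chat messages built from the prompt
    finally:
        g.close()  # enqueue StopIteration so __next__ fires completion_func

Running the model on a separate thread is what makes the generator useful: `chat_completion_with_backoff` returns `g` immediately, the endpoint streams from it, and the producer fills the queue concurrently.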

View file

@@ -6,6 +6,7 @@ import yaml
 import logging
 from datetime import datetime
 from typing import List, Optional, Union
+from functools import partial
 
 # External Packages
 from fastapi import APIRouter, HTTPException, Header, Request
@@ -442,6 +443,24 @@ async def chat(
     referer: Optional[str] = Header(None),
     host: Optional[str] = Header(None),
 ) -> StreamingResponse:
+    def _save_to_conversation_log(
+        q: str,
+        gpt_response: str,
+        user_message_time: str,
+        compiled_references: List[str],
+        inferred_queries: List[str],
+        chat_session: str,
+        meta_log,
+    ):
+        state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response)
+        state.processor_config.conversation.meta_log["chat"] = message_to_log(
+            q,
+            gpt_response,
+            user_message_metadata={"created": user_message_time},
+            khoj_message_metadata={"context": compiled_references, "intent": {"inferred-queries": inferred_queries}},
+            conversation_log=meta_log.get("chat", []),
+        )
+
     if (
         state.processor_config is None
         or state.processor_config.conversation is None
@@ -501,26 +520,21 @@ async def chat(
     try:
         with timer("Generating chat response took", logger):
-            gpt_response = converse(
-                compiled_references,
+            partial_completion = partial(
+                _save_to_conversation_log,
                 q,
-                meta_log,
-                model=chat_model,
-                api_key=api_key,
-                chat_session=chat_session,
+                user_message_time=user_message_time,
+                compiled_references=compiled_references,
+                inferred_queries=inferred_queries,
+                chat_session=chat_session,
+                meta_log=meta_log,
             )
+
+            gpt_response = converse(
+                compiled_references, q, meta_log, model=chat_model, api_key=api_key, completion_func=partial_completion
+            )
     except Exception as e:
         gpt_response = str(e)
 
-    # Update Conversation History
-    # state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response)
-    # state.processor_config.conversation.meta_log["chat"] = message_to_log(
-    #     q,
-    #     gpt_response,
-    #     user_message_metadata={"created": user_message_time},
-    #     khoj_message_metadata={"context": compiled_references, "intent": {"inferred-queries": inferred_queries}},
-    #     conversation_log=meta_log.get("chat", []),
-    # )
-
     return StreamingResponse(gpt_response, media_type="text/event-stream", status_code=200)
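
The `functools.partial` mechanics are what let a generic generator invoke a route-specific save function: the endpoint binds everything except `gpt_response` up front, and `ThreadedGenerator.__next__` supplies `gpt_response` as a keyword when the stream ends. In miniature, with hypothetical `_save` arguments (standard-library behavior, not khoj-specific code):

from functools import partial


def _save(q, gpt_response, chat_session):
    print(f"q={q!r} response={gpt_response!r} session={chat_session!r}")


cb = partial(_save, "where do I live?", chat_session="demo")  # q bound positionally
cb(gpt_response="Berlin")
# equivalent to: _save("where do I live?", gpt_response="Berlin", chat_session="demo")

One consequence worth noting: because the callback fires only when the consumer pulls `StopIteration` off the queue, a client that disconnects mid-stream never triggers the save — persistence is tied to the response being fully drained, not to generation finishing.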