From a3022b75562cb3c8d2562239de73eacb840ba96a Mon Sep 17 00:00:00 2001
From: Debanjum Singh Solanky
Date: Thu, 24 Oct 2024 14:26:57 -0700
Subject: [PATCH] Allow Offline Chat model calling functions to save conversation traces

---
 .../conversation/offline/chat_model.py | 36 +++++++++++++++----
 1 file changed, 30 insertions(+), 6 deletions(-)

diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py
index 4eafae00..3a2af64a 100644
--- a/src/khoj/processor/conversation/offline/chat_model.py
+++ b/src/khoj/processor/conversation/offline/chat_model.py
@@ -12,11 +12,12 @@ from khoj.processor.conversation import prompts
 from khoj.processor.conversation.offline.utils import download_model
 from khoj.processor.conversation.utils import (
     ThreadedGenerator,
+    commit_conversation_trace,
     generate_chatml_messages_with_context,
 )
 from khoj.utils import state
 from khoj.utils.constants import empty_escape_sequences
-from khoj.utils.helpers import ConversationCommand, is_none_or_empty
+from khoj.utils.helpers import ConversationCommand, in_debug_mode, is_none_or_empty
 from khoj.utils.rawconfig import LocationData
 
 logger = logging.getLogger(__name__)
@@ -34,6 +35,7 @@ def extract_questions_offline(
     max_prompt_size: int = None,
     temperature: float = 0.7,
     personality_context: Optional[str] = None,
+    tracer: dict = {},
 ) -> List[str]:
     """
     Infer search queries to retrieve relevant notes to answer user query
@@ -94,6 +96,7 @@ def extract_questions_offline(
             max_prompt_size=max_prompt_size,
             temperature=temperature,
             response_type="json_object",
+            tracer=tracer,
         )
     finally:
         state.chat_lock.release()
@@ -146,6 +149,7 @@ def converse_offline(
     location_data: LocationData = None,
     user_name: str = None,
     agent: Agent = None,
+    tracer: dict = {},
 ) -> Union[ThreadedGenerator, Iterator[str]]:
     """
     Converse with user using Llama
@@ -154,6 +158,7 @@ def converse_offline(
     assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
     offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
     compiled_references_message = "\n\n".join({f"{item['compiled']}" for item in references})
+    tracer["chat_model"] = model
 
     current_date = datetime.now()
 
@@ -213,13 +218,14 @@ def converse_offline(
     logger.debug(f"Conversation Context for {model}: {truncated_messages}")
 
     g = ThreadedGenerator(references, online_results, completion_func=completion_func)
-    t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size))
+    t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size, tracer))
     t.start()
     return g
 
 
-def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None):
+def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None, tracer: dict = {}):
     stop_phrases = ["<s>", "INST]", "Notes:"]
+    aggregated_response = ""
 
     state.chat_lock.acquire()
     try:
@@ -227,7 +233,14 @@ def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int
             messages, loaded_model=model, stop=stop_phrases, max_prompt_size=max_prompt_size, streaming=True
         )
         for response in response_iterator:
-            g.send(response["choices"][0]["delta"].get("content", ""))
+            response_delta = response["choices"][0]["delta"].get("content", "")
+            aggregated_response += response_delta
+            g.send(response_delta)
+
+        # Save conversation trace
+        if in_debug_mode() or state.verbose > 1:
+            commit_conversation_trace(messages, aggregated_response, tracer)
+
     finally:
         state.chat_lock.release()
         g.close()
@@ -242,6 +255,7 @@ def send_message_to_model_offline(
     stop=[],
     max_prompt_size: int = None,
     response_type: str = "text",
+    tracer: dict = {},
 ):
     assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
     offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
@@ -249,7 +263,17 @@ def send_message_to_model_offline(
     response = offline_chat_model.create_chat_completion(
         messages_dict, stop=stop, stream=streaming, temperature=temperature, response_format={"type": response_type}
     )
+
     if streaming:
         return response
-    else:
-        return response["choices"][0]["message"].get("content", "")
+
+    response_text = response["choices"][0]["message"].get("content", "")
+
+    # Save conversation trace for non-streaming responses
+    # Streamed responses need to be saved by the calling function
+    tracer["chat_model"] = model
+    tracer["temperature"] = temperature
+    if in_debug_mode() or state.verbose > 1:
+        commit_conversation_trace(messages, response_text, tracer)
+
+    return response_text
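
Usage note (not part of the patch): the sketch below illustrates how a caller might thread a shared tracer dict through the patched functions. It relies only on parameter names visible in the diff (messages, temperature, response_type, tracer); the caller-supplied tracer keys, the prompt text, and the import path for ChatMessage are assumptions for illustration.

# Illustrative sketch, not part of this patch. Parameter names are inferred
# from the diff above; tracer keys "uid" and "conversation_id" are hypothetical.
from langchain.schema import ChatMessage  # assumed import for the ChatMessage type used by these functions

from khoj.processor.conversation.offline.chat_model import send_message_to_model_offline

# Caller-side metadata to record alongside the trace (hypothetical keys)
tracer = {"uid": "user-123", "conversation_id": "conv-456"}

response_text = send_message_to_model_offline(
    messages=[ChatMessage(role="user", content="Summarize my notes on gardening.")],
    temperature=0.2,
    response_type="text",
    tracer=tracer,
)

# For this non-streaming call, the patched function adds "chat_model" and
# "temperature" to the tracer and calls commit_conversation_trace() when debug
# mode is on or state.verbose > 1; streamed responses are committed by
# llm_thread() instead.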