Allow Offline Chat model calling functions to save conversation traces

commit a3022b7556
parent eb6424f14d
Author: Debanjum Singh Solanky
Date:   2024-10-24 14:26:57 -07:00


@@ -12,11 +12,12 @@ from khoj.processor.conversation import prompts
 from khoj.processor.conversation.offline.utils import download_model
 from khoj.processor.conversation.utils import (
     ThreadedGenerator,
+    commit_conversation_trace,
     generate_chatml_messages_with_context,
 )
 from khoj.utils import state
 from khoj.utils.constants import empty_escape_sequences
-from khoj.utils.helpers import ConversationCommand, is_none_or_empty
+from khoj.utils.helpers import ConversationCommand, in_debug_mode, is_none_or_empty
 from khoj.utils.rawconfig import LocationData
 
 logger = logging.getLogger(__name__)
@@ -34,6 +35,7 @@ def extract_questions_offline(
     max_prompt_size: int = None,
     temperature: float = 0.7,
     personality_context: Optional[str] = None,
+    tracer: dict = {},
 ) -> List[str]:
     """
     Infer search queries to retrieve relevant notes to answer user query
@@ -94,6 +96,7 @@ def extract_questions_offline(
             max_prompt_size=max_prompt_size,
             temperature=temperature,
             response_type="json_object",
+            tracer=tracer,
         )
     finally:
         state.chat_lock.release()
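
To see how the new tracer parameter is meant to flow, here is a minimal caller sketch, assuming khoj is installed and these functions live in khoj.processor.conversation.offline.chat_model as the imports above suggest; the tracer keys are illustrative, not defined by this commit:

from khoj.processor.conversation.offline.chat_model import extract_questions_offline

# Hypothetical caller: seed one tracer dict per request and pass it down.
# Downstream calls mutate this same dict (e.g. setting "chat_model"), so the
# caller can inspect or persist it after the call. Keys here are illustrative.
tracer = {"uid": "user-1", "cid": "conversation-1"}

questions = extract_questions_offline(
    "What did I note down about my trip to Lisbon?",
    max_prompt_size=4096,
    tracer=tracer,  # forwarded to send_message_to_model_offline internally
)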
@@ -146,6 +149,7 @@ def converse_offline(
     location_data: LocationData = None,
     user_name: str = None,
     agent: Agent = None,
+    tracer: dict = {},
 ) -> Union[ThreadedGenerator, Iterator[str]]:
     """
     Converse with user using Llama
@@ -154,6 +158,7 @@ def converse_offline(
     assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
     offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
     compiled_references_message = "\n\n".join({f"{item['compiled']}" for item in references})
+    tracer["chat_model"] = model
 
     current_date = datetime.now()
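
converse_offline seeds the tracer with the model name here before handing it to the background thread in the next hunk. Since Python passes the dict by reference, writes from either thread land in the same object; a stripped-down sketch of that sharing, with illustrative names:

from threading import Thread

def worker(chunks: list, tracer: dict):
    # Mutations here are visible to the spawning thread: it is the same dict object.
    tracer["status"] = "started"
    chunks.append("response chunk")

tracer = {"chat_model": "some-offline-model"}  # seeded before the thread starts
chunks: list = []
t = Thread(target=worker, args=(chunks, tracer))
t.start()
t.join()
print(tracer)  # {'chat_model': 'some-offline-model', 'status': 'started'}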
@@ -213,13 +218,14 @@ def converse_offline(
     logger.debug(f"Conversation Context for {model}: {truncated_messages}")
 
     g = ThreadedGenerator(references, online_results, completion_func=completion_func)
-    t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size))
+    t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size, tracer))
     t.start()
     return g
 
 
-def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None):
+def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None, tracer: dict = {}):
     stop_phrases = ["<s>", "INST]", "Notes:"]
+    aggregated_response = ""
 
     state.chat_lock.acquire()
     try:
@ -227,7 +233,14 @@ def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int
messages, loaded_model=model, stop=stop_phrases, max_prompt_size=max_prompt_size, streaming=True
)
for response in response_iterator:
g.send(response["choices"][0]["delta"].get("content", ""))
response_delta = response["choices"][0]["delta"].get("content", "")
aggregated_response += response_delta
g.send(response_delta)
# Save conversation trace
if in_debug_mode() or state.verbose > 1:
commit_conversation_trace(messages, aggregated_response, tracer)
finally:
state.chat_lock.release()
g.close()
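
The streaming path follows an aggregate-then-commit pattern: forward each delta as it arrives, accumulate the full reply, and record a single trace at the end. A self-contained sketch of the pattern, where the tracer["response"] key and the helper name are assumptions for illustration (the commit itself delegates persistence to commit_conversation_trace):

from typing import Callable, Iterable

def stream_with_trace(deltas: Iterable[str], send: Callable[[str], None], tracer: dict) -> str:
    """Forward each streamed delta to the consumer, then record one trace."""
    aggregated = ""
    for delta in deltas:
        aggregated += delta
        send(delta)  # consumer sees tokens as they arrive
    tracer["response"] = aggregated  # single trace entry for the whole reply
    return aggregated

trace: dict = {}
full = stream_with_trace(["Hel", "lo ", "world"], print, trace)
assert full == "Hello world"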
@@ -242,6 +255,7 @@ def send_message_to_model_offline(
     stop=[],
     max_prompt_size: int = None,
     response_type: str = "text",
+    tracer: dict = {},
 ):
     assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
     offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
@@ -249,7 +263,17 @@ def send_message_to_model_offline(
     response = offline_chat_model.create_chat_completion(
         messages_dict, stop=stop, stream=streaming, temperature=temperature, response_format={"type": response_type}
     )
     if streaming:
         return response
-    else:
-        return response["choices"][0]["message"].get("content", "")
+
+    response_text = response["choices"][0]["message"].get("content", "")
+
+    # Save conversation trace for non-streaming responses
+    # Streamed responses need to be saved by the calling function
+    tracer["chat_model"] = model
+    tracer["temperature"] = temperature
+    if in_debug_mode() or state.verbose > 1:
+        commit_conversation_trace(messages, response_text, tracer)
+
+    return response_text
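
In the non-streaming case the function returns the text and annotates the tracer as a side effect, so callers can persist the dict afterwards. A hedged usage sketch, assuming khoj and its langchain dependency are installed and an offline llama.cpp model can be downloaded:

import json

from langchain.schema import ChatMessage

from khoj.processor.conversation.offline.chat_model import send_message_to_model_offline

# Hypothetical call site: start with an empty tracer and let the call fill it.
tracer: dict = {}
messages = [ChatMessage(role="user", content="Say hello in one word.")]
response_text = send_message_to_model_offline(messages, tracer=tracer)

# Side effect of the call above: tracer now records the chat model and temperature.
print(response_text)
print(json.dumps(tracer, default=str))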