mirror of https://github.com/khoj-ai/khoj.git (synced 2024-11-23 15:38:55 +01:00)
Allow Offline Chat model calling functions to save conversation traces
parent eb6424f14d
commit a3022b7556
1 changed file with 30 additions and 6 deletions
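
In outline, the change threads one mutable tracer dict through the offline chat entry points: each layer stamps metadata (such as the chat model) onto it, and whichever layer sees the finished response hands everything to commit_conversation_trace. A minimal, self-contained sketch of that pattern; the helper names below are illustrative stand-ins, not khoj's actual functions:

# Illustrative only: a shared mutable dict accumulates trace metadata as it is
# passed down the call chain, mirroring how tracer flows in this commit.
def _commit_trace(messages: list, response: str, tracer: dict) -> None:
    # stand-in for commit_conversation_trace
    print(f"trace: model={tracer.get('chat_model')}, turns={len(messages)}, chars={len(response)}")

def _chat(messages: list, model: str, tracer: dict = {}) -> str:
    tracer["chat_model"] = model  # stamp metadata, as converse_offline does
    response = "stubbed model output"
    _commit_trace(messages, response, tracer)
    return response

tracer: dict = {}
_chat([{"role": "user", "content": "hi"}], model="some-offline-model", tracer=tracer)
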
@@ -12,11 +12,12 @@ from khoj.processor.conversation import prompts
 from khoj.processor.conversation.offline.utils import download_model
 from khoj.processor.conversation.utils import (
     ThreadedGenerator,
+    commit_conversation_trace,
     generate_chatml_messages_with_context,
 )
 from khoj.utils import state
 from khoj.utils.constants import empty_escape_sequences
-from khoj.utils.helpers import ConversationCommand, is_none_or_empty
+from khoj.utils.helpers import ConversationCommand, in_debug_mode, is_none_or_empty
 from khoj.utils.rawconfig import LocationData

 logger = logging.getLogger(__name__)
@@ -34,6 +35,7 @@ def extract_questions_offline(
     max_prompt_size: int = None,
     temperature: float = 0.7,
     personality_context: Optional[str] = None,
+    tracer: dict = {},
 ) -> List[str]:
     """
     Infer search queries to retrieve relevant notes to answer user query
@@ -94,6 +96,7 @@ def extract_questions_offline(
             max_prompt_size=max_prompt_size,
             temperature=temperature,
             response_type="json_object",
+            tracer=tracer,
         )
     finally:
         state.chat_lock.release()
@@ -146,6 +149,7 @@ def converse_offline(
     location_data: LocationData = None,
     user_name: str = None,
     agent: Agent = None,
+    tracer: dict = {},
 ) -> Union[ThreadedGenerator, Iterator[str]]:
     """
     Converse with user using Llama
@@ -154,6 +158,7 @@ def converse_offline(
     assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
     offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
     compiled_references_message = "\n\n".join({f"{item['compiled']}" for item in references})
+    tracer["chat_model"] = model

     current_date = datetime.now()

@@ -213,13 +218,14 @@ def converse_offline(
     logger.debug(f"Conversation Context for {model}: {truncated_messages}")

     g = ThreadedGenerator(references, online_results, completion_func=completion_func)
-    t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size))
+    t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size, tracer))
     t.start()
     return g


-def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None):
+def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None, tracer: dict = {}):
     stop_phrases = ["<s>", "INST]", "Notes:"]
+    aggregated_response = ""

     state.chat_lock.acquire()
     try:
@@ -227,7 +233,14 @@ def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int
             messages, loaded_model=model, stop=stop_phrases, max_prompt_size=max_prompt_size, streaming=True
         )
         for response in response_iterator:
-            g.send(response["choices"][0]["delta"].get("content", ""))
+            response_delta = response["choices"][0]["delta"].get("content", "")
+            aggregated_response += response_delta
+            g.send(response_delta)
+
+        # Save conversation trace
+        if in_debug_mode() or state.verbose > 1:
+            commit_conversation_trace(messages, aggregated_response, tracer)
+
     finally:
         state.chat_lock.release()
     g.close()
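
The streaming path has to rebuild the full reply inside llm_thread because the ThreadedGenerator consumer only ever receives individual deltas, so only this loop sees the complete text to trace. A self-contained sketch of that aggregation pattern; fake_stream below is invented for illustration and merely mimics the shape of llama-cpp-python style streaming chat chunks:

# Illustrative only: accumulate streamed deltas so the full response can be
# traced once the stream is exhausted, while still forwarding each chunk.
def fake_stream():
    for chunk in ["Hello", ", ", "world", "!"]:
        yield {"choices": [{"delta": {"content": chunk}}]}

aggregated_response = ""
for response in fake_stream():
    response_delta = response["choices"][0]["delta"].get("content", "")
    aggregated_response += response_delta
    # in the real code each delta is forwarded with g.send(response_delta)

print(aggregated_response)  # -> Hello, world!
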
@@ -242,6 +255,7 @@ def send_message_to_model_offline(
     stop=[],
     max_prompt_size: int = None,
     response_type: str = "text",
+    tracer: dict = {},
 ):
     assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
     offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
@@ -249,7 +263,17 @@ def send_message_to_model_offline(
     response = offline_chat_model.create_chat_completion(
         messages_dict, stop=stop, stream=streaming, temperature=temperature, response_format={"type": response_type}
     )
+
     if streaming:
         return response
-    else:
-        return response["choices"][0]["message"].get("content", "")
+
+    response_text = response["choices"][0]["message"].get("content", "")
+
+    # Save conversation trace for non-streaming responses
+    # Streamed responses need to be saved by the calling function
+    tracer["chat_model"] = model
+    tracer["temperature"] = temperature
+    if in_debug_mode() or state.verbose > 1:
+        commit_conversation_trace(messages, response_text, tracer)
+
+    return response_text
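
For the non-streaming path, the trace is stamped and committed inside send_message_to_model_offline itself. A hedged usage sketch: only keyword arguments visible in this diff are used, while the ChatMessage construction, the 4096 token budget, the preloaded Llama instance, and the assumption that streaming defaults to False are illustrative, not values taken from the commit:

# Hedged sketch, not khoj's documented API surface.
tracer: dict = {}
messages = [ChatMessage(role="user", content="Summarize my notes on gardening.")]  # assumed message shape

response_text = send_message_to_model_offline(
    messages,
    loaded_model=offline_chat_model,  # assumed: a Llama instance, e.g. from download_model(...)
    max_prompt_size=4096,             # assumed budget
    response_type="text",
    tracer=tracer,
)

# Per the diff, the non-streaming branch returns the full text, stamps
# tracer["chat_model"] and tracer["temperature"], and commits the trace only
# when in_debug_mode() or state.verbose > 1.
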