Mirror of https://github.com/khoj-ai/khoj.git (synced 2024-11-30 19:03:01 +01:00)
Put context into separate user message before sending to chat model
The document and online search context are now passed to the chat model as separate user messages instead of being appended to the final user message. This will:

- Improve the model's ability to differentiate retrieved data from the user query, which should improve response quality and reduce the probability of prompt injection.
- Make the truncation logic simpler and more robust. When the context window is hit, messages can simply be popped to auto-truncate the history, in the order context, user, assistant message for each conversation turn, until only the current user query remains. The complex, brittle logic to extract the user query from the context embedded in the last user message is no longer required.

Marking the context message with the assistant role doesn't translate well across chat models:

- Gemini can't handle consecutive messages with role = model well.
- Claude merges consecutive messages by the same role. With the current message ordering, the context message would get merged into the previous assistant response; and if the context message were moved after the user query, the truncation logic would have to hop and skip while doing deletions.
- GPT seems to handle consecutive messages of any role fine.

Using role = user for the context message generalizes better across chat models for now and aligns with previous behavior.
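To illustrate the simpler truncation this enables, here is a minimal sketch, not the khoj implementation: the Message type and fits_in_window callback are hypothetical stand-ins. With context in its own message, each history turn is just [context?, user, assistant], so whole messages can be popped oldest-first until the prompt fits, without parsing the user query back out of a combined prompt.

    # Minimal sketch only; Message and fits_in_window are illustrative, not khoj's actual types.
    from dataclasses import dataclass
    from typing import Callable, List


    @dataclass
    class Message:
        role: str  # "user" or "assistant"; context messages are sent with role "user"
        content: str


    def truncate_history(messages: List[Message], fits_in_window: Callable[[List[Message]], bool]) -> List[Message]:
        """Drop whole messages, oldest first, until the prompt fits the model's context window."""
        truncated = list(messages)
        # Oldest turn's context message goes first, then its user query, then the assistant reply.
        # Keep at least the last message so the current user query is never dropped.
        while len(truncated) > 1 and not fits_in_window(truncated):
            truncated.pop(0)
        return truncated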
This commit is contained in:
parent
7ac241b766
commit
0c52a1169a
4 changed files with 36 additions and 32 deletions
@@ -142,7 +142,6 @@ def converse_anthropic(
     """
     # Initialize Variables
     current_date = datetime.now()
-    conversation_primer = prompts.query_prompt.format(query=user_query)
     compiled_references = "\n\n".join({f"# File: {item['file']}\n## {item['compiled']}\n" for item in references})

     if agent and agent.personality:
@@ -174,16 +173,16 @@ def converse_anthropic(
         completion_func(chat_response=prompts.no_online_results_found.format())
         return iter([prompts.no_online_results_found.format()])

-    if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands:
-        conversation_primer = (
-            f"{prompts.online_search_conversation.format(online_results=str(online_results))}\n{conversation_primer}"
-        )
+    context_message = ""
     if not is_none_or_empty(compiled_references):
-        conversation_primer = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n\n{conversation_primer}"
+        context_message = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n\n"
+    if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands:
+        context_message += f"{prompts.online_search_conversation.format(online_results=str(online_results))}"

     # Setup Prompt with Primer or Conversation History
     messages = generate_chatml_messages_with_context(
-        conversation_primer,
+        user_query,
+        context_message=context_message,
         conversation_log=conversation_log,
         model_name=model,
         max_prompt_size=max_prompt_size,
@@ -139,7 +139,6 @@ def converse_gemini(
     """
     # Initialize Variables
     current_date = datetime.now()
-    conversation_primer = prompts.query_prompt.format(query=user_query)
     compiled_references = "\n\n".join({f"# File: {item['file']}\n## {item['compiled']}\n" for item in references})

     if agent and agent.personality:
@@ -172,16 +171,16 @@ def converse_gemini(
         completion_func(chat_response=prompts.no_online_results_found.format())
         return iter([prompts.no_online_results_found.format()])

-    if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands:
-        conversation_primer = (
-            f"{prompts.online_search_conversation.format(online_results=str(online_results))}\n{conversation_primer}"
-        )
+    context_message = ""
     if not is_none_or_empty(compiled_references):
-        conversation_primer = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n\n{conversation_primer}"
+        context_message = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n\n"
+    if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands:
+        context_message += f"{prompts.online_search_conversation.format(online_results=str(online_results))}"

     # Setup Prompt with Primer or Conversation History
     messages = generate_chatml_messages_with_context(
-        conversation_primer,
+        user_query,
+        context_message=context_message,
         conversation_log=conversation_log,
         model_name=model,
         max_prompt_size=max_prompt_size,
@@ -143,7 +143,6 @@ def converse(
     """
     # Initialize Variables
    current_date = datetime.now()
-    conversation_primer = prompts.query_prompt.format(query=user_query)
     compiled_references = "\n\n".join({f"# File: {item['file']}\n## {item['compiled']}\n" for item in references})

     if agent and agent.personality:
@@ -175,18 +174,18 @@ def converse(
         completion_func(chat_response=prompts.no_online_results_found.format())
         return iter([prompts.no_online_results_found.format()])

-    if not is_none_or_empty(online_results):
-        conversation_primer = (
-            f"{prompts.online_search_conversation.format(online_results=str(online_results))}\n{conversation_primer}"
-        )
+    context_message = ""
     if not is_none_or_empty(compiled_references):
-        conversation_primer = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n\n{conversation_primer}"
+        context_message = f"{prompts.notes_conversation.format(references=compiled_references)}\n\n"
+    if not is_none_or_empty(online_results):
+        context_message += f"{prompts.online_search_conversation.format(online_results=str(online_results))}"

     # Setup Prompt with Primer or Conversation History
     messages = generate_chatml_messages_with_context(
-        conversation_primer,
+        user_query,
         system_prompt,
         conversation_log,
+        context_message=context_message,
         model_name=model,
         max_prompt_size=max_prompt_size,
         tokenizer_name=tokenizer_name,
@@ -12,6 +12,7 @@ from transformers import AutoTokenizer

 from khoj.database.adapters import ConversationAdapters
 from khoj.database.models import ChatModelOptions, ClientApplication, KhojUser
+from khoj.processor.conversation import prompts
 from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
 from khoj.utils import state
 from khoj.utils.helpers import is_none_or_empty, merge_dicts
@@ -163,6 +164,7 @@ def generate_chatml_messages_with_context(
     uploaded_image_url=None,
     vision_enabled=False,
     model_type="",
+    context_message="",
 ):
     """Generate messages for ChatGPT with context from previous conversation"""
     # Set max prompt size from user config or based on pre-configured for model and machine specs
@@ -178,24 +180,22 @@ def generate_chatml_messages_with_context(
     # Extract Chat History for Context
     chatml_messages: List[ChatMessage] = []
     for chat in conversation_log.get("chat", []):
+        if not is_none_or_empty(chat.get("context")):
             references = "\n\n".join(
                 {f"# File: {item['file']}\n## {item['compiled']}\n" for item in chat.get("context") or []}
             )
-        message_notes = f"\n\n Notes:\n{references}" if chat.get("context") else "\n"
+            message_context = f"{prompts.notes_conversation.format(references=references)}\n\n"
+            reconstructed_context_message = ChatMessage(content=message_context, role="context")
+            chatml_messages.insert(0, reconstructed_context_message)

         role = "user" if chat["by"] == "you" else "assistant"
-
-        message_content = chat["message"] + message_notes
-
         message_content = construct_structured_message(
-            message_content, chat.get("uploadedImageData"), model_type, vision_enabled
+            chat["message"], chat.get("uploadedImageData"), model_type, vision_enabled
         )
-
         reconstructed_message = ChatMessage(content=message_content, role=role)
-
         chatml_messages.insert(0, reconstructed_message)

-        if len(chatml_messages) >= 2 * lookback_turns:
+        if len(chatml_messages) >= 3 * lookback_turns:
             break

     messages = []
@@ -206,6 +206,8 @@ def generate_chatml_messages_with_context(
                 role="user",
             )
         )
+    if not is_none_or_empty(context_message):
+        messages.append(ChatMessage(content=context_message, role="context"))
     if len(chatml_messages) > 0:
         messages += chatml_messages
     if not is_none_or_empty(system_message):
@@ -214,6 +216,11 @@ def generate_chatml_messages_with_context(
     # Truncate oldest messages from conversation history until under max supported prompt size by model
     messages = truncate_messages(messages, max_prompt_size, model_name, loaded_model, tokenizer_name)

+    # Reset context message role to user
+    for message in messages:
+        if message.role == "context":
+            message.role = "user"
+
     # Return message in chronological order
     return messages[::-1]

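Each conversation turn in history can now contribute up to three messages (context, user, assistant) instead of two, which is why the lookback limit above moves from 2 * lookback_turns to 3 * lookback_turns. For orientation, here is an illustrative sketch of the chronological message list generate_chatml_messages_with_context now returns for a turn with notes context, after the temporary "context" role has been reset to "user"; the content strings are made up, not taken from the repo:

    # Illustrative shape of the output (hypothetical content).
    # The retrieved context travels as its own user-role message just before the query,
    # so truncation can drop it without parsing the query back out of a combined prompt.
    messages = [
        ChatMessage(content=system_prompt, role="system"),
        ChatMessage(content="Use my notes below to inform your response.\n# File: journal.org\n## ...", role="user"),  # context message
        ChatMessage(content="What did I plan for this week?", role="user"),  # current user query
    ]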