Mirror of https://github.com/khoj-ai/khoj.git
Separate notes, online context from user message sent to chat models (#950)
Overview
---
- Put context into a separate user message before sending it to the chat model. This should improve model response quality and simplify the truncation logic in the code.
- Pass online context from the chat history to the chat model for response. This should improve response speed when previous online context can be reused.
- Improve the format of the notes and online context passed to chat models in the prompt. This should improve model response quality.

Details
---
The document and online search context are now passed as separate user messages to the chat model, instead of being added to the final user message. This will improve:
- The model's ability to differentiate retrieved data from the user query. That should improve response quality and reduce the probability of prompt injection.
- Truncation logic, by making it simpler and more robust. When the context window is hit, messages can simply be popped to auto-truncate context, in order of context, user, assistant message for each conversation turn in history, until the current user query is reached (see the sketch below). The complex, brittle logic to extract the user query from the context in the last user message is no longer required.
commit 3e17ab438a
9 changed files with 80 additions and 65 deletions
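To make the new truncation behaviour concrete, here is a minimal sketch of the pop-based approach this change enables. It is an illustration, not code from this commit: count_tokens, the token budget, and the plain-dict message shape are all assumptions.

def truncate_messages(messages: list[dict], max_tokens: int, count_tokens) -> list[dict]:
    # messages is ordered newest-first: the current query, its context message,
    # then the assistant, user, and context messages of each prior turn.
    # Popping from the tail drops whole messages, oldest turn first, so the
    # current user query is always the last message standing and no string
    # surgery on a combined query-plus-context message is needed.
    while len(messages) > 1 and sum(count_tokens(m["content"]) for m in messages) > max_tokens:
        messages.pop()
    return messages

Called as truncate_messages(messages, max_tokens=2000, count_tokens=lambda s: len(s.split())), this sheds the oldest history first and keeps the current query intact.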
|
@@ -147,9 +147,7 @@ def converse_anthropic(
     """
     # Initialize Variables
     current_date = datetime.now()
-    compiled_references = "\n\n".join({f"# {item}" for item in references})
-
-    conversation_primer = prompts.query_prompt.format(query=user_query)
+    compiled_references = "\n\n".join({f"# File: {item['file']}\n## {item['compiled']}\n" for item in references})

     if agent and agent.personality:
         system_prompt = prompts.custom_personality.format(
@@ -180,16 +178,16 @@ def converse_anthropic(
             completion_func(chat_response=prompts.no_online_results_found.format())
             return iter([prompts.no_online_results_found.format()])

-    if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands:
-        conversation_primer = (
-            f"{prompts.online_search_conversation.format(online_results=str(online_results))}\n{conversation_primer}"
-        )
+    context_message = ""
     if not is_none_or_empty(compiled_references):
-        conversation_primer = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n\n{conversation_primer}"
+        context_message = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n\n"
+    if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands:
+        context_message += f"{prompts.online_search_conversation.format(online_results=str(online_results))}"

     # Setup Prompt with Primer or Conversation History
     messages = generate_chatml_messages_with_context(
-        conversation_primer,
+        user_query,
+        context_message=context_message,
         conversation_log=conversation_log,
         model_name=model,
         max_prompt_size=max_prompt_size,
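The same restructuring repeats in converse_gemini, converse_offline, and converse below. A condensed before/after sketch of the shared pattern; notes_context and online_context are illustrative stand-ins for the rendered prompt sections, and the plain-dict messages are an assumption:

def build_messages_before(user_query: str, notes_context: str, online_context: str) -> list[dict]:
    # Old approach: all context was folded into the final user message string.
    primer = user_query
    if online_context:
        primer = f"{online_context}\n{primer}"
    if notes_context:
        primer = f"{notes_context}\n\n{primer}"
    return [{"role": "user", "content": primer}]

def build_messages_after(user_query: str, notes_context: str, online_context: str) -> list[dict]:
    # New approach: context travels in its own user message, so the model can
    # tell retrieved data apart from the query, and truncation can drop the
    # context message whole.
    context_message = (f"{notes_context}\n\n" if notes_context else "") + online_context
    messages = [{"role": "user", "content": user_query}]
    if context_message:
        messages.append({"role": "user", "content": context_message})
    return messages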
@@ -151,9 +151,7 @@ def converse_gemini(
     """
     # Initialize Variables
     current_date = datetime.now()
-    compiled_references = "\n\n".join({f"# {item}" for item in references})
-
-    conversation_primer = prompts.query_prompt.format(query=user_query)
+    compiled_references = "\n\n".join({f"# File: {item['file']}\n## {item['compiled']}\n" for item in references})

     if agent and agent.personality:
         system_prompt = prompts.custom_personality.format(
@@ -185,16 +183,16 @@ def converse_gemini(
             completion_func(chat_response=prompts.no_online_results_found.format())
             return iter([prompts.no_online_results_found.format()])

-    if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands:
-        conversation_primer = (
-            f"{prompts.online_search_conversation.format(online_results=str(online_results))}\n{conversation_primer}"
-        )
+    context_message = ""
     if not is_none_or_empty(compiled_references):
-        conversation_primer = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n\n{conversation_primer}"
+        context_message = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n\n"
+    if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands:
+        context_message += f"{prompts.online_search_conversation.format(online_results=str(online_results))}"

     # Setup Prompt with Primer or Conversation History
     messages = generate_chatml_messages_with_context(
-        conversation_primer,
+        user_query,
+        context_message=context_message,
         conversation_log=conversation_log,
         model_name=model,
         max_prompt_size=max_prompt_size,
@@ -153,7 +153,7 @@ def converse_offline(
     # Initialize Variables
     assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
     offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
-    compiled_references_message = "\n\n".join({f"{item['compiled']}" for item in references})
+    compiled_references = "\n\n".join({f"# File: {item['file']}\n## {item['compiled']}\n" for item in references})

     current_date = datetime.now()

@@ -170,8 +170,6 @@ def converse_offline(
         day_of_week=current_date.strftime("%A"),
     )

-    conversation_primer = prompts.query_prompt.format(query=user_query)
-
     if location_data:
         location_prompt = prompts.user_location.format(location=f"{location_data}")
         system_prompt = f"{system_prompt}\n{location_prompt}"
@@ -181,27 +179,31 @@ def converse_offline(
         system_prompt = f"{system_prompt}\n{user_name_prompt}"

     # Get Conversation Primer appropriate to Conversation Type
-    if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references_message):
+    if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references):
         return iter([prompts.no_notes_found.format()])
     elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
         completion_func(chat_response=prompts.no_online_results_found.format())
         return iter([prompts.no_online_results_found.format()])

-    if ConversationCommand.Online in conversation_commands:
+    context_message = ""
+    if not is_none_or_empty(compiled_references):
+        context_message += f"{prompts.notes_conversation_offline.format(references=compiled_references)}\n\n"
+    if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands:
         simplified_online_results = online_results.copy()
         for result in online_results:
             if online_results[result].get("webpages"):
                 simplified_online_results[result] = online_results[result]["webpages"]

-        conversation_primer = f"{prompts.online_search_conversation_offline.format(online_results=str(simplified_online_results))}\n{conversation_primer}"
-    if not is_none_or_empty(compiled_references_message):
-        conversation_primer = f"{prompts.notes_conversation_offline.format(references=compiled_references_message)}\n\n{conversation_primer}"
+        context_message += (
+            f"{prompts.online_search_conversation_offline.format(online_results=str(simplified_online_results))}"
+        )

     # Setup Prompt with Primer or Conversation History
     messages = generate_chatml_messages_with_context(
-        conversation_primer,
+        user_query,
         system_prompt,
         conversation_log,
+        context_message=context_message,
         model_name=model,
         loaded_model=offline_chat_model,
         max_prompt_size=max_prompt_size,
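The webpages flattening above trims the search response before it reaches the smaller offline model. A worked example with a hypothetical online_results shape (the schema here is assumed for illustration):

online_results = {
    "paris weather": {
        "organic": [{"title": "Forecast", "link": "https://example.com/forecast"}],
        "webpages": [{"query": "paris weather", "snippet": "Sunny, high of 22C."}],
    }
}
simplified_online_results = online_results.copy()
for result in online_results:
    if online_results[result].get("webpages"):
        # Keep only the extracted webpage content; drop the rest of the search
        # response to conserve the offline model's limited context window.
        simplified_online_results[result] = online_results[result]["webpages"]

print(simplified_online_results)
# {'paris weather': [{'query': 'paris weather', 'snippet': 'Sunny, high of 22C.'}]}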
@@ -143,9 +143,7 @@ def converse(
     """
     # Initialize Variables
     current_date = datetime.now()
-    compiled_references = "\n\n".join({f"# {item['compiled']}" for item in references})
-
-    conversation_primer = prompts.query_prompt.format(query=user_query)
+    compiled_references = "\n\n".join({f"# File: {item['file']}\n## {item['compiled']}\n" for item in references})

     if agent and agent.personality:
         system_prompt = prompts.custom_personality.format(
@@ -176,18 +174,18 @@ def converse(
             completion_func(chat_response=prompts.no_online_results_found.format())
             return iter([prompts.no_online_results_found.format()])

-    if not is_none_or_empty(online_results):
-        conversation_primer = (
-            f"{prompts.online_search_conversation.format(online_results=str(online_results))}\n{conversation_primer}"
-        )
+    context_message = ""
     if not is_none_or_empty(compiled_references):
-        conversation_primer = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n\n{conversation_primer}"
+        context_message = f"{prompts.notes_conversation.format(references=compiled_references)}\n\n"
+    if not is_none_or_empty(online_results):
+        context_message += f"{prompts.online_search_conversation.format(online_results=str(online_results))}"

     # Setup Prompt with Primer or Conversation History
     messages = generate_chatml_messages_with_context(
-        conversation_primer,
+        user_query,
         system_prompt,
         conversation_log,
+        context_message=context_message,
         model_name=model,
         max_prompt_size=max_prompt_size,
         tokenizer_name=tokenizer_name,
@@ -118,6 +118,7 @@ Use my personal notes and our past conversations to inform your response.
 Ask crisp follow-up questions to get additional context, when a helpful response cannot be provided from the provided notes or past conversations.

+User's Notes:
 -----
 {references}
 """.strip()
 )
@@ -127,6 +128,7 @@ notes_conversation_offline = PromptTemplate.from_template(
 Use my personal notes and our past conversations to inform your response.

+User's Notes:
 -----
 {references}
 """.strip()
 )
@@ -328,6 +330,7 @@ Use this up-to-date information from the internet to inform your response.
 Ask crisp follow-up questions to get additional context, when a helpful response cannot be provided from the online data or past conversations.

+Information from the internet:
 -----
 {online_results}
 """.strip()
 )
@@ -337,6 +340,7 @@ online_search_conversation_offline = PromptTemplate.from_template(
 Use this up-to-date information from the internet to inform your response.

+Information from the internet:
 -----
 {online_results}
 """.strip()
 )
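These templates render into the context message via PromptTemplate.format(). A small usage sketch of the updated offline notes template (the references value is invented):

from langchain.prompts import PromptTemplate

notes_conversation_offline = PromptTemplate.from_template(
    """
Use my personal notes and our past conversations to inform your response.

User's Notes:
-----
{references}
""".strip()
)

context_message = notes_conversation_offline.format(
    references="# File: travel.org\n## Visited Paris in June 2023.\n"
)
print(context_message)

The added "User's Notes:" and "Information from the internet:" headers label each section, which helps the model tell the sources apart when notes and online results share one context message.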
@@ -18,6 +18,7 @@ from transformers import AutoTokenizer

 from khoj.database.adapters import ConversationAdapters
 from khoj.database.models import ChatModelOptions, ClientApplication, KhojUser
+from khoj.processor.conversation import prompts
 from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
 from khoj.utils import state
 from khoj.utils.helpers import is_none_or_empty, merge_dicts
@@ -181,8 +182,9 @@ def generate_chatml_messages_with_context(
     query_images=None,
     vision_enabled=False,
     model_type="",
+    context_message="",
 ):
-    """Generate messages for ChatGPT with context from previous conversation"""
+    """Generate chat messages with appropriate context from previous conversation to send to the chat model"""
     # Set max prompt size from user config or based on pre-configured for model and machine specs
     if not max_prompt_size:
         if loaded_model:
@@ -196,21 +198,27 @@ def generate_chatml_messages_with_context(
     # Extract Chat History for Context
     chatml_messages: List[ChatMessage] = []
     for chat in conversation_log.get("chat", []):
-        message_notes = f'\n\n Notes:\n{chat.get("context")}' if chat.get("context") else "\n"
+        message_context = ""
+        if chat["by"] == "khoj" and "excalidraw" in chat["intent"].get("type", ""):
+            message_context += chat.get("intent").get("inferred-queries")[0]
+        if not is_none_or_empty(chat.get("context")):
+            references = "\n\n".join(
+                {f"# File: {item['file']}\n## {item['compiled']}\n" for item in chat.get("context") or []}
+            )
+            message_context += f"{prompts.notes_conversation.format(references=references)}\n\n"
+        if not is_none_or_empty(chat.get("onlineContext")):
+            message_context += f"{prompts.online_search_conversation.format(online_results=chat.get('onlineContext'))}"
+        if not is_none_or_empty(message_context):
+            reconstructed_context_message = ChatMessage(content=message_context, role="user")
+            chatml_messages.insert(0, reconstructed_context_message)
+
         role = "user" if chat["by"] == "you" else "assistant"
-
-        if chat["by"] == "khoj" and "excalidraw" in chat["intent"].get("type"):
-            message_content = chat.get("intent").get("inferred-queries")[0] + message_notes
-        else:
-            message_content = chat["message"] + message_notes
-
-        message_content = construct_structured_message(message_content, chat.get("images"), model_type, vision_enabled)
+        message_content = construct_structured_message(chat["message"], chat.get("images"), model_type, vision_enabled)

         reconstructed_message = ChatMessage(content=message_content, role=role)
-
         chatml_messages.insert(0, reconstructed_message)

-        if len(chatml_messages) >= 2 * lookback_turns:
+        if len(chatml_messages) >= 3 * lookback_turns:
             break

     messages = []
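With this change, each stored turn can contribute up to two messages during history reconstruction: a context message rebuilt from the turn's saved notes and online results, followed by the turn's actual message. A toy conversation_log entry with the fields the loop reads (values invented):

chat_turn = {
    "by": "khoj",
    "message": "You visited the Louvre in June.",
    "intent": {"type": "remember"},
    "context": [{"file": "travel.org", "compiled": "Visited Paris in June 2023."}],
    "onlineContext": {"louvre hours": [{"snippet": "Open 9am to 6pm."}]},
}
# The loop above rebuilds this turn as two chat messages, in chronological order:
#   1. role="user": the notes_conversation prompt rendered with "# File: travel.org ..."
#      plus the online_search_conversation prompt rendered with the onlineContext dict
#   2. role="assistant": "You visited the Louvre in June."

Replaying the saved onlineContext this way is what lets responses reuse prior online context instead of re-fetching it.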
@@ -221,6 +229,8 @@ def generate_chatml_messages_with_context(
                 role="user",
             )
         )
+    if not is_none_or_empty(context_message):
+        messages.append(ChatMessage(content=context_message, role="user"))
     if len(chatml_messages) > 0:
         messages += chatml_messages
     if not is_none_or_empty(system_message):
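Putting the assembly together, the outgoing list is built newest-first. A schematic for a conversation with one prior turn; the exact reversal into chronological order before the API call and the handling of the system message during truncation are not shown in this diff:

messages = [
    "user: <current query>",
    "user: <notes and online context for the current query>",  # only when non-empty
    "assistant: <previous reply>",
    "user: <context reconstructed for that reply>",
    "user: <previous query>",
    "system: <system and personality prompt>",
]
# Every piece of context is now a whole, droppable message, which is what makes
# the pop-based truncation described in the commit message possible.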
@@ -112,7 +112,7 @@ def add_files_filter(request: Request, filter: FilesFilterRequest):
         file_filters = ConversationAdapters.add_files_to_filter(request.user.object, conversation_id, files_filter)
         return Response(content=json.dumps(file_filters), media_type="application/json", status_code=200)
     except Exception as e:
-        logger.error(f"Error adding file filter {filter.filename}: {e}", exc_info=True)
+        logger.error(f"Error adding file filter {filter.filenames}: {e}", exc_info=True)
         raise HTTPException(status_code=422, detail=str(e))


@@ -214,7 +214,7 @@ def test_answer_from_chat_history_and_previously_retrieved_content():
         (
             "When was I born?",
             "You were born on 1st April 1984.",
-            ["Testatron was born on 1st April 1984 in Testville."],
+            [{"compiled": "Testatron was born on 1st April 1984 in Testville.", "file": "birth.org"}],
         ),
     ]

@@ -415,15 +415,18 @@ def test_ask_for_clarification_if_not_enough_context_in_question():
     context = [
         {
             "compiled": f"""# Ramya
-My sister, Ramya, is married to Kali Devi. They have 2 kids, Ravi and Rani."""
+My sister, Ramya, is married to Kali Devi. They have 2 kids, Ravi and Rani.""",
+            "file": "Family.md",
         },
         {
             "compiled": f"""# Fang
-My sister, Fang Liu is married to Xi Li. They have 1 kid, Xiao Li."""
+My sister, Fang Liu is married to Xi Li. They have 1 kid, Xiao Li.""",
+            "file": "Family.md",
         },
         {
             "compiled": f"""# Aiyla
-My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet."""
+My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""",
+            "file": "Family.md",
         },
     ]

@@ -608,9 +611,11 @@ async def test_infer_webpage_urls_actor_extracts_correct_links(chat_client, defa
         ),
     ],
 )
-async def test_infer_task_scheduling_request(chat_client, user_query, expected_crontime, expected_qs, unexpected_qs):
+async def test_infer_task_scheduling_request(
+    chat_client, user_query, expected_crontime, expected_qs, unexpected_qs, default_user2
+):
     # Act
-    crontime, inferred_query, _ = await schedule_query(user_query, {})
+    crontime, inferred_query, _ = await schedule_query(user_query, {}, default_user2)
     inferred_query = inferred_query.lower()

     # Assert
@@ -630,7 +635,7 @@ async def test_infer_task_scheduling_request(chat_client, user_query, expected_c
     "scheduling_query, executing_query, generated_response, expected_should_notify",
     [
         (
-            "Notify me if it is going to rain tomorrow?",
+            "Notify me only if it is going to rain tomorrow?",
             "What's the weather forecast for tomorrow?",
             "It is sunny and warm tomorrow.",
             False,
@@ -656,10 +661,10 @@ async def test_infer_task_scheduling_request(chat_client, user_query, expected_c
     ],
 )
 def test_decision_on_when_to_notify_scheduled_task_results(
-    chat_client, scheduling_query, executing_query, generated_response, expected_should_notify
+    chat_client, default_user2, scheduling_query, executing_query, generated_response, expected_should_notify
 ):
     # Act
-    generated_should_notify = should_notify(scheduling_query, executing_query, generated_response)
+    generated_should_notify = should_notify(scheduling_query, executing_query, generated_response, default_user2)

     # Assert
     assert generated_should_notify == expected_should_notify
@@ -307,7 +307,7 @@ def test_summarize_one_file(chat_client, default_user2: KhojUser):
         json={"filename": summarization_file, "conversation_id": str(conversation.id)},
     )
     query = "/summarize"
-    response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": conversation.id})
+    response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation.id)})
     response_message = response.json()["response"]
     # Assert
     assert response_message != ""
@@ -339,7 +339,7 @@ def test_summarize_extra_text(chat_client, default_user2: KhojUser):
         json={"filename": summarization_file, "conversation_id": str(conversation.id)},
     )
     query = "/summarize tell me about Xiu"
-    response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": conversation.id})
+    response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation.id)})
     response_message = response.json()["response"]
     # Assert
     assert response_message != ""
@@ -367,7 +367,7 @@ def test_summarize_multiple_files(chat_client, default_user2: KhojUser):
     )

     query = "/summarize"
-    response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": conversation.id})
+    response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation.id)})
     response_message = response.json()["response"]

     # Assert
@@ -383,7 +383,7 @@ def test_summarize_no_files(chat_client, default_user2: KhojUser):

     # Act
     query = "/summarize"
-    response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": conversation.id})
+    response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation.id)})
     response_message = response.json()["response"]

     # Assert
@@ -418,11 +418,11 @@ def test_summarize_different_conversation(chat_client, default_user2: KhojUser):

     # Act
     query = "/summarize"
-    response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": conversation2.id})
+    response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation2.id)})
     response_message_conv2 = response.json()["response"]

     # now make sure that the file filter is still in conversation 1
-    response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": conversation1.id})
+    response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation1.id)})
     response_message_conv1 = response.json()["response"]

     # Assert
@@ -449,7 +449,7 @@ def test_summarize_nonexistant_file(chat_client, default_user2: KhojUser):
         json={"filename": "imaginary.markdown", "conversation_id": str(conversation.id)},
     )
     query = urllib.parse.quote("/summarize")
-    response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": conversation.id})
+    response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation.id)})
     response_message = response.json()["response"]
     # Assert
     assert response_message == "No files selected for summarization. Please add files using the section on the left."
@@ -481,7 +481,7 @@ def test_summarize_diff_user_file(chat_client, default_user: KhojUser, pdf_confi

     # Act
     query = "/summarize"
-    response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": conversation.id})
+    response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation.id)})
     response_message = response.json()["response"]

     # Assert