diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py index 5fb900c9..826f0fa4 100644 --- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py +++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py @@ -142,7 +142,6 @@ def converse_anthropic( """ # Initialize Variables current_date = datetime.now() - conversation_primer = prompts.query_prompt.format(query=user_query) compiled_references = "\n\n".join({f"# File: {item['file']}\n## {item['compiled']}\n" for item in references}) if agent and agent.personality: @@ -174,16 +173,16 @@ def converse_anthropic( completion_func(chat_response=prompts.no_online_results_found.format()) return iter([prompts.no_online_results_found.format()]) - if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands: - conversation_primer = ( - f"{prompts.online_search_conversation.format(online_results=str(online_results))}\n{conversation_primer}" - ) + context_message = "" if not is_none_or_empty(compiled_references): - conversation_primer = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n\n{conversation_primer}" + context_message = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n\n" + if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands: + context_message += f"{prompts.online_search_conversation.format(online_results=str(online_results))}" # Setup Prompt with Primer or Conversation History messages = generate_chatml_messages_with_context( - conversation_primer, + user_query, + context_message=context_message, conversation_log=conversation_log, model_name=model, max_prompt_size=max_prompt_size, diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py index f7cfad31..4221aeb3 100644 --- 
a/src/khoj/processor/conversation/google/gemini_chat.py +++ b/src/khoj/processor/conversation/google/gemini_chat.py @@ -139,7 +139,6 @@ def converse_gemini( """ # Initialize Variables current_date = datetime.now() - conversation_primer = prompts.query_prompt.format(query=user_query) compiled_references = "\n\n".join({f"# File: {item['file']}\n## {item['compiled']}\n" for item in references}) if agent and agent.personality: @@ -172,16 +171,16 @@ def converse_gemini( completion_func(chat_response=prompts.no_online_results_found.format()) return iter([prompts.no_online_results_found.format()]) - if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands: - conversation_primer = ( - f"{prompts.online_search_conversation.format(online_results=str(online_results))}\n{conversation_primer}" - ) + context_message = "" if not is_none_or_empty(compiled_references): - conversation_primer = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n\n{conversation_primer}" + context_message = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n\n" + if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands: + context_message += f"{prompts.online_search_conversation.format(online_results=str(online_results))}" # Setup Prompt with Primer or Conversation History messages = generate_chatml_messages_with_context( - conversation_primer, + user_query, + context_message=context_message, conversation_log=conversation_log, model_name=model, max_prompt_size=max_prompt_size, diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index 293bdacd..2f5045c2 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -143,7 +143,6 @@ def converse( """ # Initialize Variables current_date = datetime.now() - 
conversation_primer = prompts.query_prompt.format(query=user_query) compiled_references = "\n\n".join({f"# File: {item['file']}\n## {item['compiled']}\n" for item in references}) if agent and agent.personality: @@ -175,18 +174,18 @@ def converse( completion_func(chat_response=prompts.no_online_results_found.format()) return iter([prompts.no_online_results_found.format()]) - if not is_none_or_empty(online_results): - conversation_primer = ( - f"{prompts.online_search_conversation.format(online_results=str(online_results))}\n{conversation_primer}" - ) + context_message = "" if not is_none_or_empty(compiled_references): - conversation_primer = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n\n{conversation_primer}" + context_message = f"{prompts.notes_conversation.format(references=compiled_references)}\n\n" + if not is_none_or_empty(online_results): + context_message += f"{prompts.online_search_conversation.format(online_results=str(online_results))}" # Setup Prompt with Primer or Conversation History messages = generate_chatml_messages_with_context( - conversation_primer, + user_query, system_prompt, conversation_log, + context_message=context_message, model_name=model, max_prompt_size=max_prompt_size, tokenizer_name=tokenizer_name, diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 56e9e9db..75f17963 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -12,6 +12,7 @@ from transformers import AutoTokenizer from khoj.database.adapters import ConversationAdapters from khoj.database.models import ChatModelOptions, ClientApplication, KhojUser +from khoj.processor.conversation import prompts from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens from khoj.utils import state from khoj.utils.helpers import is_none_or_empty, merge_dicts @@ -163,6 +164,7 @@ def generate_chatml_messages_with_context( 
uploaded_image_url=None, vision_enabled=False, model_type="", + context_message="", ): """Generate messages for ChatGPT with context from previous conversation""" # Set max prompt size from user config or based on pre-configured for model and machine specs @@ -178,24 +180,22 @@ def generate_chatml_messages_with_context( # Extract Chat History for Context chatml_messages: List[ChatMessage] = [] for chat in conversation_log.get("chat", []): - references = "\n\n".join( - {f"# File: {item['file']}\n## {item['compiled']}\n" for item in chat.get("context") or []} - ) - message_notes = f"\n\n Notes:\n{references}" if chat.get("context") else "\n" + if not is_none_or_empty(chat.get("context")): + references = "\n\n".join( + {f"# File: {item['file']}\n## {item['compiled']}\n" for item in chat.get("context") or []} + ) + message_context = f"{prompts.notes_conversation.format(references=references)}\n\n" + reconstructed_context_message = ChatMessage(content=message_context, role="context") + chatml_messages.insert(0, reconstructed_context_message) role = "user" if chat["by"] == "you" else "assistant" - - message_content = chat["message"] + message_notes - message_content = construct_structured_message( - message_content, chat.get("uploadedImageData"), model_type, vision_enabled + chat["message"], chat.get("uploadedImageData"), model_type, vision_enabled ) - reconstructed_message = ChatMessage(content=message_content, role=role) - chatml_messages.insert(0, reconstructed_message) - if len(chatml_messages) >= 2 * lookback_turns: + if len(chatml_messages) >= 3 * lookback_turns: break messages = [] @@ -206,6 +206,8 @@ def generate_chatml_messages_with_context( role="user", ) ) + if not is_none_or_empty(context_message): + messages.append(ChatMessage(content=context_message, role="context")) if len(chatml_messages) > 0: messages += chatml_messages if not is_none_or_empty(system_message): @@ -214,6 +216,11 @@ def generate_chatml_messages_with_context( # Truncate oldest messages from 
conversation history until under max supported prompt size by model messages = truncate_messages(messages, max_prompt_size, model_name, loaded_model, tokenizer_name) + # Reset context message role to user + for message in messages: + if message.role == "context": + message.role = "user" + # Return message in chronological order return messages[::-1]