diff --git a/src/interface/web/app/components/chatInputArea/chatInputArea.tsx b/src/interface/web/app/components/chatInputArea/chatInputArea.tsx
index 7f2baf1d..92a3b3ae 100644
--- a/src/interface/web/app/components/chatInputArea/chatInputArea.tsx
+++ b/src/interface/web/app/components/chatInputArea/chatInputArea.tsx
@@ -367,6 +367,11 @@ export const ChatInputArea = forwardRef((pr
                     e.preventDefault()}
                     className={`${props.isMobileWidth ? "w-[100vw]" : "w-full"} rounded-md`}
+                    side="top"
+                    align="center"
+                    /* Offset below text area on home page (i.e. where conversationId is unset) */
+                    sideOffset={props.conversationId ? 0 : 80}
+                    alignOffset={0}
                 >
                     openAgentEditCard(agents[index].slug)
                 }
diff --git a/src/khoj/database/management/commands/change_default_model.py b/src/khoj/database/management/commands/change_default_model.py
index cfa78581..d9a6359f 100644
--- a/src/khoj/database/management/commands/change_default_model.py
+++ b/src/khoj/database/management/commands/change_default_model.py
@@ -19,6 +19,8 @@ from khoj.processor.embeddings import EmbeddingsModel
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

+BATCH_SIZE = 1000  # Number of entries to embed and update per batch
+

 class Command(BaseCommand):
     help = "Convert all existing Entry objects to use a new default Search model."
@@ -42,22 +44,24 @@ class Command(BaseCommand):
     def handle(self, *args, **options):
         @transaction.atomic
         def regenerate_entries(entry_filter: Q, embeddings_model: EmbeddingsModel, search_model: SearchModelConfig):
-            entries = Entry.objects.filter(entry_filter).all()
-            compiled_entries = [entry.compiled for entry in entries]
-            updated_entries: List[Entry] = []
-            try:
-                embeddings = embeddings_model.embed_documents(compiled_entries)
+            total_entries = Entry.objects.filter(entry_filter).count()
+            for start in tqdm(range(0, total_entries, BATCH_SIZE)):
+                end = start + BATCH_SIZE
+                entries = Entry.objects.filter(entry_filter)[start:end]
+                compiled_entries = [entry.compiled for entry in entries]
+                updated_entries: List[Entry] = []
+                try:
+                    embeddings = embeddings_model.embed_documents(compiled_entries)
+                except Exception as e:
+                    logger.error(f"Error embedding documents: {e}")
+                    return

-            except Exception as e:
-                logger.error(f"Error embedding documents: {e}")
-                return
+                for i, entry in enumerate(entries):
+                    entry.embeddings = embeddings[i]
+                    entry.search_model_id = search_model.id
+                    updated_entries.append(entry)

-            for i, entry in enumerate(tqdm(entries)):
-                entry.embeddings = embeddings[i]
-                entry.search_model_id = search_model.id
-                updated_entries.append(entry)
-
-            Entry.objects.bulk_update(updated_entries, ["embeddings", "search_model_id", "file_path"])
+                Entry.objects.bulk_update(updated_entries, ["embeddings", "search_model_id", "file_path"])

         search_model_config_id = options.get("search_model_id")
         apply = options.get("apply")
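The command above swaps one all-at-once embedding pass for fixed-size batches. A minimal, self-contained sketch of the same pattern, with a stand-in `embed_documents` in place of Khoj's `EmbeddingsModel` and plain lists in place of the Django queryset:

```python
from typing import List

BATCH_SIZE = 1000  # matches the constant introduced above; tune to model throughput and memory


def embed_documents(docs: List[str]) -> List[List[float]]:
    """Stand-in embedder; the real command calls EmbeddingsModel.embed_documents."""
    return [[float(len(doc))] for doc in docs]


def regenerate_in_batches(documents: List[str]) -> List[List[float]]:
    """Embed documents slice by slice so peak memory stays bounded,
    mirroring how the command above slices its queryset per batch."""
    embeddings: List[List[float]] = []
    for start in range(0, len(documents), BATCH_SIZE):
        batch = documents[start : start + BATCH_SIZE]
        embeddings.extend(embed_documents(batch))
    return embeddings


if __name__ == "__main__":
    print(len(regenerate_in_batches(["note one", "note two"])))  # 2
```

Slicing the queryset per batch (`[start:end]`) also avoids materializing every Entry row at once, at the cost of re-issuing the filter query for each batch.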
diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py
index a435f343..268e21aa 100644
--- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py
+++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py
@@ -151,9 +151,7 @@ def converse_anthropic(
     """
     # Initialize Variables
     current_date = datetime.now()
-    compiled_references = "\n\n".join({f"# {item}" for item in references})
-
-    conversation_primer = prompts.query_prompt.format(query=user_query)
+    compiled_references = "\n\n".join({f"# File: {item['file']}\n## {item['compiled']}\n" for item in references})

     if agent and agent.personality:
         system_prompt = prompts.custom_personality.format(
@@ -184,16 +182,16 @@ def converse_anthropic(
             completion_func(chat_response=prompts.no_online_results_found.format())
         return iter([prompts.no_online_results_found.format()])

-    if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands:
-        conversation_primer = (
-            f"{prompts.online_search_conversation.format(online_results=str(online_results))}\n{conversation_primer}"
-        )
+    context_message = ""
     if not is_none_or_empty(compiled_references):
-        conversation_primer = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n\n{conversation_primer}"
+        context_message = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n\n"
+    if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands:
+        context_message += f"{prompts.online_search_conversation.format(online_results=str(online_results))}"

     # Setup Prompt with Primer or Conversation History
     messages = generate_chatml_messages_with_context(
-        conversation_primer,
+        user_query,
+        context_message=context_message,
         conversation_log=conversation_log,
         model_name=model,
         max_prompt_size=max_prompt_size,
diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py
index 4ff51c5e..ae33d40d 100644
--- a/src/khoj/processor/conversation/google/gemini_chat.py
+++ b/src/khoj/processor/conversation/google/gemini_chat.py
@@ -156,9 +156,7 @@ def converse_gemini(
     """
     # Initialize Variables
     current_date = datetime.now()
-    compiled_references = "\n\n".join({f"# {item}" for item in references})
-
-    conversation_primer = prompts.query_prompt.format(query=user_query)
+    compiled_references = "\n\n".join({f"# File: {item['file']}\n## {item['compiled']}\n" for item in references})

     if agent and agent.personality:
         system_prompt = prompts.custom_personality.format(
@@ -190,16 +188,16 @@ def converse_gemini(
             completion_func(chat_response=prompts.no_online_results_found.format())
         return iter([prompts.no_online_results_found.format()])

-    if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands:
-        conversation_primer = (
-            f"{prompts.online_search_conversation.format(online_results=str(online_results))}\n{conversation_primer}"
-        )
+    context_message = ""
     if not is_none_or_empty(compiled_references):
-        conversation_primer = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n\n{conversation_primer}"
+        context_message = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n\n"
+    if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands:
+        context_message += f"{prompts.online_search_conversation.format(online_results=str(online_results))}"

     # Setup Prompt with Primer or Conversation History
     messages = generate_chatml_messages_with_context(
-        conversation_primer,
+        user_query,
+        context_message=context_message,
        conversation_log=conversation_log,
        model_name=model,
        max_prompt_size=max_prompt_size,
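Both `converse_anthropic` and `converse_gemini` now fold notes and online results into a single `context_message` instead of prepending them to the user query. A rough sketch of the assembly, with literal strings standing in for `prompts.notes_conversation` and `prompts.online_search_conversation`:

```python
def build_context_message(compiled_references: str, online_results: dict, search_online: bool) -> str:
    """Collect retrieved context into one string, kept separate from the raw user query."""
    context_message = ""
    if compiled_references:
        # Stand-in for prompts.notes_conversation
        context_message = f"User's Notes:\n-----\n{compiled_references}\n\n"
    if search_online and online_results:
        # Stand-in for prompts.online_search_conversation
        context_message += f"Information from the internet:\n-----\n{online_results}"
    return context_message


print(build_context_message("# File: birth.org\n## Born 1st April 1984.\n", {}, search_online=False))
```

Keeping context out of the query string lets `generate_chatml_messages_with_context` carry it as its own chat message, so it can be trimmed or dropped separately from the question itself.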
diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py
index 3a2af64a..2d2354ed 100644
--- a/src/khoj/processor/conversation/offline/chat_model.py
+++ b/src/khoj/processor/conversation/offline/chat_model.py
@@ -157,9 +157,9 @@ def converse_offline(
    # Initialize Variables
    assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
    offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
-    compiled_references_message = "\n\n".join({f"{item['compiled']}" for item in references})

     tracer["chat_model"] = model
+    compiled_references = "\n\n".join({f"# File: {item['file']}\n## {item['compiled']}\n" for item in references})

     current_date = datetime.now()

     if agent and agent.personality:
@@ -175,8 +175,6 @@ def converse_offline(
         day_of_week=current_date.strftime("%A"),
     )

-    conversation_primer = prompts.query_prompt.format(query=user_query)
-
     if location_data:
         location_prompt = prompts.user_location.format(location=f"{location_data}")
         system_prompt = f"{system_prompt}\n{location_prompt}"
@@ -186,27 +184,31 @@ def converse_offline(
         system_prompt = f"{system_prompt}\n{user_name_prompt}"

     # Get Conversation Primer appropriate to Conversation Type
-    if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references_message):
+    if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references):
         return iter([prompts.no_notes_found.format()])
     elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
         completion_func(chat_response=prompts.no_online_results_found.format())
         return iter([prompts.no_online_results_found.format()])

-    if ConversationCommand.Online in conversation_commands:
+    context_message = ""
+    if not is_none_or_empty(compiled_references):
+        context_message += f"{prompts.notes_conversation_offline.format(references=compiled_references)}\n\n"
+    if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands:
         simplified_online_results = online_results.copy()
         for result in online_results:
             if online_results[result].get("webpages"):
                 simplified_online_results[result] = online_results[result]["webpages"]

-        conversation_primer = f"{prompts.online_search_conversation_offline.format(online_results=str(simplified_online_results))}\n{conversation_primer}"
-    if not is_none_or_empty(compiled_references_message):
-        conversation_primer = f"{prompts.notes_conversation_offline.format(references=compiled_references_message)}\n\n{conversation_primer}"
+        context_message += (
+            f"{prompts.online_search_conversation_offline.format(online_results=str(simplified_online_results))}"
+        )

     # Setup Prompt with Primer or Conversation History
     messages = generate_chatml_messages_with_context(
-        conversation_primer,
+        user_query,
         system_prompt,
         conversation_log,
+        context_message=context_message,
         model_name=model,
         loaded_model=offline_chat_model,
         max_prompt_size=max_prompt_size,
diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py
index 0d513268..3c4552d9 100644
--- a/src/khoj/processor/conversation/openai/gpt.py
+++ b/src/khoj/processor/conversation/openai/gpt.py
@@ -154,9 +154,7 @@ def converse(
     """
     # Initialize Variables
     current_date = datetime.now()
-    compiled_references = "\n\n".join({f"# {item['compiled']}" for item in references})
-
-    conversation_primer = prompts.query_prompt.format(query=user_query)
+    compiled_references = "\n\n".join({f"# File: {item['file']}\n## {item['compiled']}\n" for item in references})

     if agent and agent.personality:
         system_prompt = prompts.custom_personality.format(
@@ -187,18 +185,18 @@ def converse(
             completion_func(chat_response=prompts.no_online_results_found.format())
         return iter([prompts.no_online_results_found.format()])

-    if not is_none_or_empty(online_results):
-        conversation_primer = (
-            f"{prompts.online_search_conversation.format(online_results=str(online_results))}\n{conversation_primer}"
-        )
+    context_message = ""
     if not is_none_or_empty(compiled_references):
-        conversation_primer = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n\n{conversation_primer}"
+        context_message = f"{prompts.notes_conversation.format(references=compiled_references)}\n\n"
+    if not is_none_or_empty(online_results):
+        context_message += f"{prompts.online_search_conversation.format(online_results=str(online_results))}"

     # Setup Prompt with Primer or Conversation History
     messages = generate_chatml_messages_with_context(
-        conversation_primer,
+        user_query,
         system_prompt,
         conversation_log,
+        context_message=context_message,
         model_name=model,
         max_prompt_size=max_prompt_size,
         tokenizer_name=tokenizer_name,
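In `converse_offline` above, online results are pruned before being handed to the local model: when a search result already carries extracted `webpages`, only those are forwarded. A runnable illustration (the result shape here is an assumption based on the keys used in the diff):

```python
online_results = {
    "best pizza in rome": {
        "organic": [{"title": "Top 10 pizzerias", "link": "https://example.com/pizza"}],
        "webpages": [{"link": "https://example.com/pizza", "snippet": "Pizzarium tops the list..."}],
    },
    "rome weather": {"organic": [{"title": "Forecast", "link": "https://example.com/weather"}]},
}

simplified_online_results = online_results.copy()
for result in online_results:
    if online_results[result].get("webpages"):
        # Keep only the extracted webpage content; drop the bulkier raw search payload
        simplified_online_results[result] = online_results[result]["webpages"]

print(simplified_online_results["best pizza in rome"])  # just the webpages list
print("organic" in simplified_online_results["rome weather"])  # True: left untouched
```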
diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py
index 7988cc43..0cca1e37 100644
--- a/src/khoj/processor/conversation/prompts.py
+++ b/src/khoj/processor/conversation/prompts.py
@@ -118,6 +118,7 @@ Use my personal notes and our past conversations to inform your response.
 Ask crisp follow-up questions to get additional context, when a helpful response cannot be provided from the provided notes or past conversations.

 User's Notes:
+-----
 {references}
 """.strip()
 )
@@ -127,6 +128,7 @@ notes_conversation_offline = PromptTemplate.from_template(
     """
 Use my personal notes and our past conversations to inform your response.

 User's Notes:
+-----
 {references}
 """.strip()
 )
@@ -328,6 +330,7 @@ Use this up-to-date information from the internet to inform your response.
 Ask crisp follow-up questions to get additional context, when a helpful response cannot be provided from the online data or past conversations.

 Information from the internet:
+-----
 {online_results}
 """.strip()
 )
@@ -337,6 +340,7 @@ online_search_conversation_offline = PromptTemplate.from_template(
 Use this up-to-date information from the internet to inform your response.

 Information from the internet:
+-----
 {online_results}
 """.strip()
 )
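The only change to these four prompt templates is a `-----` rule between the section label and the injected content. A quick way to see the rendered effect, using plain `str.format` in place of the `PromptTemplate` wrapper:

```python
notes_conversation = """
Use my personal notes and our past conversations to inform your response.

User's Notes:
-----
{references}
""".strip()

print(notes_conversation.format(references="# File: birth.org\n## Testatron was born on 1st April 1984.\n"))
```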
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index 184de372..bc7a7858 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -21,6 +21,7 @@ from transformers import AutoTokenizer

 from khoj.database.adapters import ConversationAdapters
 from khoj.database.models import ChatModelOptions, ClientApplication, KhojUser
+from khoj.processor.conversation import prompts
 from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
 from khoj.utils import state
 from khoj.utils.helpers import in_debug_mode, is_none_or_empty, merge_dicts
@@ -188,8 +189,9 @@ def generate_chatml_messages_with_context(
     query_images=None,
     vision_enabled=False,
     model_type="",
+    context_message="",
 ):
-    """Generate messages for ChatGPT with context from previous conversation"""
+    """Generate chat messages with appropriate context from previous conversation to send to the chat model"""
     # Set max prompt size from user config or based on pre-configured for model and machine specs
     if not max_prompt_size:
         if loaded_model:
@@ -203,21 +205,27 @@
     # Extract Chat History for Context
     chatml_messages: List[ChatMessage] = []
     for chat in conversation_log.get("chat", []):
-        message_notes = f'\n\n Notes:\n{chat.get("context")}' if chat.get("context") else "\n"
+        message_context = ""
+        if chat["by"] == "khoj" and "excalidraw" in chat["intent"].get("type", ""):
+            message_context += chat.get("intent").get("inferred-queries")[0]
+        if not is_none_or_empty(chat.get("context")):
+            references = "\n\n".join(
+                {f"# File: {item['file']}\n## {item['compiled']}\n" for item in chat.get("context") or []}
+            )
+            message_context += f"{prompts.notes_conversation.format(references=references)}\n\n"
+        if not is_none_or_empty(chat.get("onlineContext")):
+            message_context += f"{prompts.online_search_conversation.format(online_results=chat.get('onlineContext'))}"
+        if not is_none_or_empty(message_context):
+            reconstructed_context_message = ChatMessage(content=message_context, role="user")
+            chatml_messages.insert(0, reconstructed_context_message)
+
         role = "user" if chat["by"] == "you" else "assistant"
-
-        if chat["by"] == "khoj" and "excalidraw" in chat["intent"].get("type"):
-            message_content = chat.get("intent").get("inferred-queries")[0] + message_notes
-        else:
-            message_content = chat["message"] + message_notes
-
-        message_content = construct_structured_message(message_content, chat.get("images"), model_type, vision_enabled)
+        message_content = construct_structured_message(chat["message"], chat.get("images"), model_type, vision_enabled)

         reconstructed_message = ChatMessage(content=message_content, role=role)
-
         chatml_messages.insert(0, reconstructed_message)
-        if len(chatml_messages) >= 2 * lookback_turns:
+        if len(chatml_messages) >= 3 * lookback_turns:
             break

     messages = []
@@ -228,6 +236,8 @@
             role="user",
         )
     )
+    if not is_none_or_empty(context_message):
+        messages.append(ChatMessage(content=context_message, role="user"))
     if len(chatml_messages) > 0:
         messages += chatml_messages
     if not is_none_or_empty(system_message):
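The reworked loop in `generate_chatml_messages_with_context` can now emit up to two messages per logged turn, an optional user-role context message plus the turn itself, which is why the lookback cap grows from `2 * lookback_turns` to `3 * lookback_turns`. A simplified sketch of the per-turn assembly (`ChatMessage` is a local stand-in; final ordering and truncation are handled elsewhere in the real function):

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ChatMessage:
    content: str
    role: str


chat_log = [
    {"by": "you", "message": "When was I born?", "context": "# File: birth.org\n## Born 1st April 1984.\n"},
    {"by": "khoj", "message": "You were born on 1st April 1984.", "context": ""},
]

chatml_messages: List[ChatMessage] = []
for chat in chat_log:
    if chat["context"]:
        # Context now travels as its own user message instead of being
        # appended to the turn's text
        chatml_messages.insert(0, ChatMessage(content=chat["context"], role="user"))
    role = "user" if chat["by"] == "you" else "assistant"
    chatml_messages.insert(0, ChatMessage(content=chat["message"], role=role))

for message in chatml_messages:
    print(f"{message.role}: {message.content.splitlines()[0]}")
```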
diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py
index 83881ddf..62a1f3b9 100644
--- a/src/khoj/routers/api_chat.py
+++ b/src/khoj/routers/api_chat.py
@@ -113,7 +113,7 @@ def add_files_filter(request: Request, filter: FilesFilterRequest):
         file_filters = ConversationAdapters.add_files_to_filter(request.user.object, conversation_id, files_filter)
         return Response(content=json.dumps(file_filters), media_type="application/json", status_code=200)
     except Exception as e:
-        logger.error(f"Error adding file filter {filter.filename}: {e}", exc_info=True)
+        logger.error(f"Error adding file filter {filter.filenames}: {e}", exc_info=True)
         raise HTTPException(status_code=422, detail=str(e))
diff --git a/tests/test_openai_chat_actors.py b/tests/test_openai_chat_actors.py
index fc253b50..b2ae2d34 100644
--- a/tests/test_openai_chat_actors.py
+++ b/tests/test_openai_chat_actors.py
@@ -214,7 +214,7 @@ def test_answer_from_chat_history_and_previously_retrieved_content():
         (
             "When was I born?",
             "You were born on 1st April 1984.",
-            ["Testatron was born on 1st April 1984 in Testville."],
+            [{"compiled": "Testatron was born on 1st April 1984 in Testville.", "file": "birth.org"}],
         ),
     ]
@@ -415,15 +415,18 @@ def test_ask_for_clarification_if_not_enough_context_in_question():
     context = [
         {
             "compiled": f"""# Ramya
-My sister, Ramya, is married to Kali Devi. They have 2 kids, Ravi and Rani."""
+My sister, Ramya, is married to Kali Devi. They have 2 kids, Ravi and Rani.""",
+            "file": "Family.md",
         },
         {
             "compiled": f"""# Fang
-My sister, Fang Liu is married to Xi Li. They have 1 kid, Xiao Li."""
+My sister, Fang Liu is married to Xi Li. They have 1 kid, Xiao Li.""",
+            "file": "Family.md",
         },
         {
             "compiled": f"""# Aiyla
-My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet."""
+My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""",
+            "file": "Family.md",
         },
     ]
@@ -608,9 +611,11 @@ async def test_infer_webpage_urls_actor_extracts_correct_links(chat_client, defa
         ),
     ],
 )
-async def test_infer_task_scheduling_request(chat_client, user_query, expected_crontime, expected_qs, unexpected_qs):
+async def test_infer_task_scheduling_request(
+    chat_client, user_query, expected_crontime, expected_qs, unexpected_qs, default_user2
+):
     # Act
-    crontime, inferred_query, _ = await schedule_query(user_query, {})
+    crontime, inferred_query, _ = await schedule_query(user_query, {}, default_user2)
     inferred_query = inferred_query.lower()

     # Assert
@@ -630,7 +635,7 @@ async def test_infer_task_scheduling_request(chat_client, user_query, expected_c
     "scheduling_query, executing_query, generated_response, expected_should_notify",
     [
         (
-            "Notify me if it is going to rain tomorrow?",
+            "Notify me only if it is going to rain tomorrow?",
             "What's the weather forecast for tomorrow?",
             "It is sunny and warm tomorrow.",
             False,
@@ -656,10 +661,10 @@ async def test_infer_task_scheduling_request(chat_client, user_query, expected_c
     ],
 )
 def test_decision_on_when_to_notify_scheduled_task_results(
-    chat_client, scheduling_query, executing_query, generated_response, expected_should_notify
+    chat_client, default_user2, scheduling_query, executing_query, generated_response, expected_should_notify
 ):
     # Act
-    generated_should_notify = should_notify(scheduling_query, executing_query, generated_response)
+    generated_should_notify = should_notify(scheduling_query, executing_query, generated_response, default_user2)

     # Assert
     assert generated_should_notify == expected_should_notify
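These fixtures change because references are now dicts carrying the compiled snippet plus its source file, the shape the `converse_*` functions join into `compiled_references`. Reproducing that join on the fixture from the first hunk:

```python
references = [
    {"compiled": "Testatron was born on 1st April 1984 in Testville.", "file": "birth.org"},
]

compiled_references = "\n\n".join(
    {f"# File: {item['file']}\n## {item['compiled']}\n" for item in references}
)
print(compiled_references)
# # File: birth.org
# ## Testatron was born on 1st April 1984 in Testville.
```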
diff --git a/tests/test_openai_chat_director.py b/tests/test_openai_chat_director.py
index 279d6e37..7d460408 100644
--- a/tests/test_openai_chat_director.py
+++ b/tests/test_openai_chat_director.py
@@ -307,7 +307,7 @@ def test_summarize_one_file(chat_client, default_user2: KhojUser):
         json={"filename": summarization_file, "conversation_id": str(conversation.id)},
     )
     query = "/summarize"
-    response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": conversation.id})
+    response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation.id)})
     response_message = response.json()["response"]
     # Assert
     assert response_message != ""
@@ -339,7 +339,7 @@ def test_summarize_extra_text(chat_client, default_user2: KhojUser):
         json={"filename": summarization_file, "conversation_id": str(conversation.id)},
     )
     query = "/summarize tell me about Xiu"
-    response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": conversation.id})
+    response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation.id)})
     response_message = response.json()["response"]
     # Assert
     assert response_message != ""
@@ -367,7 +367,7 @@ def test_summarize_multiple_files(chat_client, default_user2: KhojUser):
     )

     query = "/summarize"
-    response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": conversation.id})
+    response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation.id)})
     response_message = response.json()["response"]

     # Assert
@@ -383,7 +383,7 @@ def test_summarize_no_files(chat_client, default_user2: KhojUser):

     # Act
     query = "/summarize"
-    response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": conversation.id})
+    response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation.id)})
     response_message = response.json()["response"]

     # Assert
@@ -418,11 +418,11 @@ def test_summarize_different_conversation(chat_client, default_user2: KhojUser):

     # Act
     query = "/summarize"
-    response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": conversation2.id})
+    response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation2.id)})
     response_message_conv2 = response.json()["response"]

     # now make sure that the file filter is still in conversation 1
-    response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": conversation1.id})
+    response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation1.id)})
     response_message_conv1 = response.json()["response"]

     # Assert
@@ -449,7 +449,7 @@ def test_summarize_nonexistant_file(chat_client, default_user2: KhojUser):
         json={"filename": "imaginary.markdown", "conversation_id": str(conversation.id)},
     )
     query = urllib.parse.quote("/summarize")
-    response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": conversation.id})
+    response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation.id)})
     response_message = response.json()["response"]
     # Assert
     assert response_message == "No files selected for summarization. Please add files using the section on the left."
@@ -481,7 +481,7 @@ def test_summarize_diff_user_file(chat_client, default_user: KhojUser, pdf_confi

     # Act
     query = "/summarize"
-    response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": conversation.id})
+    response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation.id)})
     response_message = response.json()["response"]

     # Assert
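The repeated `conversation.id` to `str(conversation.id)` change in these tests is presumably needed because the conversation ids are UUIDs, which the json module refuses to serialize. A minimal reproduction:

```python
import json
import uuid

conversation_id = uuid.uuid4()

try:
    json.dumps({"q": "/summarize", "conversation_id": conversation_id})
except TypeError as e:
    print(f"TypeError: {e}")  # Object of type UUID is not JSON serializable

print(json.dumps({"q": "/summarize", "conversation_id": str(conversation_id)}))  # works
```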