diff --git a/pyproject.toml b/pyproject.toml index d0890e7b..193c0cc3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -112,7 +112,7 @@ warn_unused_ignores = false line-length = 120 [tool.pytest.ini_options] -addopts = "--strict-markers -n 4" +addopts = "--strict-markers" markers = [ "chatquality: Evaluate chatbot capabilities and quality", ] diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index ff2d88a2..780a6c57 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -702,10 +702,16 @@ async def chat( ) -> Response: perform_chat_checks() conversation_command = get_conversation_command(query=q, any_references=True) + + q = q.replace(f"/{conversation_command.value}", "").strip() + compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions( request, q, (n or 5), conversation_command ) - conversation_command = get_conversation_command(query=q, any_references=not is_none_or_empty(compiled_references)) + + if conversation_command == ConversationCommand.Default and is_none_or_empty(compiled_references): + conversation_command = ConversationCommand.General + if conversation_command == ConversationCommand.Help: model_type = "offline" if state.processor_config.conversation.enable_offline_chat else "openai" formatted_help = help_message.format(model=model_type, version=state.khoj_version) @@ -768,18 +774,16 @@ async def extract_references_and_questions( logger.warning( "No content index loaded, so cannot extract references from knowledge base. Please configure your data sources and update the index to chat with your notes." ) - return compiled_references, inferred_queries + return compiled_references, inferred_queries, q if conversation_type == ConversationCommand.General: return compiled_references, inferred_queries, q # Extract filter terms from user message defiltered_query = q - filter_terms = [] for filter in [DateFilter(), WordFilter(), FileFilter()]: - filter_terms += filter.get_filter_terms(q) - defiltered_query = filter.defilter(q) - filters_in_query = " ".join(filter_terms) + defiltered_query = filter.defilter(defiltered_query) + filters_in_query = q.replace(defiltered_query, "").strip() # Infer search queries from user message with timer("Extracting search queries took", logger):