diff --git a/src/khoj/configure.py b/src/khoj/configure.py index d3cd204c..27067a4b 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -73,7 +73,9 @@ def configure_search_types(config: FullConfig): # Extract core search types core_search_types = {e.name: e.value for e in SearchType} # Extract configured plugin search types - plugin_search_types = {plugin_type: plugin_type for plugin_type in config.content_type.plugins.keys()} + plugin_search_types = {} + if config.content_type.plugins: + plugin_search_types = {plugin_type: plugin_type for plugin_type in config.content_type.plugins.keys()} # Dynamically generate search type enum by merging core search types with configured plugin search types return Enum("SearchType", merge_dicts(core_search_types, plugin_search_types)) diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index b9f2a3b8..033ea220 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -29,11 +29,12 @@ def get_default_config_data(): @api.get("/config/types", response_model=List[str]) def get_config_types(): """Get configured content types""" + configured_content_types = state.config.content_type.dict(exclude_none=True) return [ search_type.value for search_type in SearchType - if any(search_type.value == ctype[0] and ctype[1] for ctype in state.config.content_type) - or search_type.name in state.config.content_type.plugins.keys() + if search_type.value in configured_content_types + or ("plugins" in configured_content_types and search_type.name in configured_content_types["plugins"]) ] diff --git a/tests/test_client.py b/tests/test_client.py index c9510c82..e7087e2c 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -9,8 +9,9 @@ from fastapi.testclient import TestClient # Internal Packages from khoj.main import app -from khoj.configure import configure_routes -from khoj.utils.state import model +from khoj.configure import configure_routes, configure_search_types +from khoj.utils import state +from khoj.utils.state import model, config from khoj.search_type import text_search, image_search from khoj.utils.rawconfig import ContentConfig, SearchConfig from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl @@ -86,6 +87,59 @@ def test_get_configured_types_via_api(client): assert response.json() == ["org", "image", "plugin1"] +# ---------------------------------------------------------------------------------------------------- +def test_get_configured_types_with_only_plugin_content_config(content_config): + # Arrange + config.content_type = ContentConfig() + config.content_type.plugins = content_config.plugins + state.SearchType = configure_search_types(config) + + configure_routes(app) + client = TestClient(app) + + # Act + response = client.get(f"/api/config/types") + + # Assert + assert response.status_code == 200 + assert response.json() == ["plugin1"] + + +# ---------------------------------------------------------------------------------------------------- +def test_get_configured_types_with_no_plugin_content_config(content_config): + # Arrange + config.content_type = content_config + config.content_type.plugins = None + state.SearchType = configure_search_types(config) + + configure_routes(app) + client = TestClient(app) + + # Act + response = client.get(f"/api/config/types") + + # Assert + assert response.status_code == 200 + assert "plugin1" not in response.json() + + +# ---------------------------------------------------------------------------------------------------- +def test_get_configured_types_with_no_content_config(): + # Arrange + config.content_type = ContentConfig() + state.SearchType = configure_search_types(config) + + configure_routes(app) + client = TestClient(app) + + # Act + response = client.get(f"/api/config/types") + + # Assert + assert response.status_code == 200 + assert response.json() == [] + + # ---------------------------------------------------------------------------------------------------- def test_image_search(client, content_config: ContentConfig, search_config: SearchConfig): # Arrange