From ffdb291fe037d26c4d7207ab2a33a3f26d0b5d3b Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Wed, 10 Jan 2024 22:30:00 +0530 Subject: [PATCH 1/8] Fix error message rendering in khoj.el, khoj obsidian chat - Fix failed to index error message in khoj.el - Fix chat model not configured message in khoj obsidian chat --- src/interface/emacs/khoj.el | 2 +- src/interface/obsidian/src/chat_modal.ts | 2 +- src/khoj/routers/api.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/interface/emacs/khoj.el b/src/interface/emacs/khoj.el index ea2275f0..52cfba88 100644 --- a/src/interface/emacs/khoj.el +++ b/src/interface/emacs/khoj.el @@ -432,7 +432,7 @@ Auto invokes setup steps on calling main entrypoint." (khoj--delete-open-network-connections-to-server) (with-current-buffer (current-buffer) (search-forward "\n\n" nil t) - (message "khoj.el: Failed to %supdate %s content index. Status: %s%s" + (message "khoj.el: Failed to %supdate %scontent index. Status: %s%s" (if force "force " "") (if content-type (format "%s " content-type) "all") (string-trim (format "%s %s" (nth 1 (nth 1 status)) (nth 2 (nth 1 status)))) diff --git a/src/interface/obsidian/src/chat_modal.ts b/src/interface/obsidian/src/chat_modal.ts index 57f0fa40..d433ee15 100644 --- a/src/interface/obsidian/src/chat_modal.ts +++ b/src/interface/obsidian/src/chat_modal.ts @@ -245,7 +245,7 @@ export class KhojChatModal extends Modal { if (responseJson.detail) { // If the server returns error details in response, render a setup hint. 
let setupMsg = "Hi 👋🏾, to start chatting add available chat models options via [the Django Admin panel](/server/admin) on the Server"; - this.renderMessage(chatBodyEl, setupMsg, "khoj", undefined, true); + this.renderMessage(chatBodyEl, setupMsg, "khoj", undefined); return false; } else if (responseJson.response) { diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 575f094c..190fef19 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -55,7 +55,7 @@ from khoj.search_filter.file_filter import FileFilter from khoj.search_filter.word_filter import WordFilter from khoj.search_type import image_search, text_search from khoj.utils import constants, state -from khoj.utils.config import GPT4AllProcessorModel, TextSearchModel +from khoj.utils.config import GPT4AllProcessorModel from khoj.utils.helpers import ( AsyncIteratorWrapper, ConversationCommand, From b1269fdad25c986800a4b83a60530cf9df4c4742 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Thu, 11 Jan 2024 21:56:36 +0530 Subject: [PATCH 2/8] Remove trailing slash to normalize khoj url in obsidian plugin settings --- src/interface/obsidian/src/settings.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/interface/obsidian/src/settings.ts b/src/interface/obsidian/src/settings.ts index 9150c438..198251ad 100644 --- a/src/interface/obsidian/src/settings.ts +++ b/src/interface/obsidian/src/settings.ts @@ -42,7 +42,7 @@ export class KhojSettingTab extends PluginSettingTab { .addText(text => text .setValue(`${this.plugin.settings.khojUrl}`) .onChange(async (value) => { - this.plugin.settings.khojUrl = value.trim(); + this.plugin.settings.khojUrl = value.trim().replace(/\/$/, ''); await this.plugin.saveSettings(); containerEl.firstElementChild?.setText(this.getBackendStatusMessage()); })); From ba99089a124c25fb20a0fe54679b98aa530fc7b7 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Thu, 11 Jan 2024 22:21:57 +0530 Subject: [PATCH 3/8] Short-circuit 
API rate limiter for unauthenticated user Calls by unauthenticated users were failing at API rate limiter as it failed to access user info object. This is a bug. API rate limiter should short-circuit for unauthenticated users so a proper Forbidden response can be returned by API Add regression test to verify that unauthenticated users get 403 response when calling the /chat API endpoint --- src/khoj/routers/helpers.py | 5 ++++ tests/conftest.py | 54 +++++++++++++++++-------------------- tests/test_client.py | 14 ++++++++++ 3 files changed, 43 insertions(+), 30 deletions(-) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 104426a3..d7a92a20 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -299,6 +299,11 @@ class ApiUserRateLimiter: self.cache: dict[str, list[float]] = defaultdict(list) def __call__(self, request: Request): + # Rate limiting is disabled if user unauthenticated. + # Other systems handle authentication + if not request.user.is_authenticated: + return + user: KhojUser = request.user.object subscribed = has_required_scope(request, ["premium"]) user_requests = self.cache[user.uuid] diff --git a/tests/conftest.py b/tests/conftest.py index f7756d99..a7ff1512 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -254,51 +254,45 @@ def md_content_config(): @pytest.fixture(scope="function") def chat_client(search_config: SearchConfig, default_user2: KhojUser): - # Initialize app state - state.config.search_type = search_config - state.SearchType = configure_search_types() + return chat_client_builder(search_config, default_user2, require_auth=False) - LocalMarkdownConfig.objects.create( - input_files=None, - input_filter=["tests/data/markdown/*.markdown"], - user=default_user2, - ) - # Index Markdown Content for Search - all_files = fs_syncer.collect_files(user=default_user2) - state.content_index, _ = configure_content( - state.content_index, state.config.content_type, all_files, state.search_models, 
user=default_user2 - ) - - # Initialize Processor from Config - if os.getenv("OPENAI_API_KEY"): - chat_model = ChatModelOptionsFactory(chat_model="gpt-3.5-turbo", model_type="openai") - OpenAIProcessorConversationConfigFactory() - UserConversationProcessorConfigFactory(user=default_user2, setting=chat_model) - - state.anonymous_mode = True - - app = FastAPI() - - configure_routes(app) - configure_middleware(app) - app.mount("/static", StaticFiles(directory=web_directory), name="static") - return TestClient(app) +@pytest.fixture(scope="function") +def chat_client_with_auth(search_config: SearchConfig, default_user2: KhojUser): + return chat_client_builder(search_config, default_user2, require_auth=True) @pytest.fixture(scope="function") def chat_client_no_background(search_config: SearchConfig, default_user2: KhojUser): + return chat_client_builder(search_config, default_user2, index_content=False, require_auth=False) + + +@pytest.mark.django_db +def chat_client_builder(search_config, user, index_content=True, require_auth=False): # Initialize app state state.config.search_type = search_config state.SearchType = configure_search_types() + if index_content: + LocalMarkdownConfig.objects.create( + input_files=None, + input_filter=["tests/data/markdown/*.markdown"], + user=user, + ) + + # Index Markdown Content for Search + all_files = fs_syncer.collect_files(user=user) + state.content_index, _ = configure_content( + state.content_index, state.config.content_type, all_files, state.search_models, user=user + ) + # Initialize Processor from Config if os.getenv("OPENAI_API_KEY"): chat_model = ChatModelOptionsFactory(chat_model="gpt-3.5-turbo", model_type="openai") OpenAIProcessorConversationConfigFactory() - UserConversationProcessorConfigFactory(user=default_user2, setting=chat_model) + UserConversationProcessorConfigFactory(user=user, setting=chat_model) - state.anonymous_mode = True + state.anonymous_mode = not require_auth app = FastAPI() diff --git 
a/tests/test_client.py b/tests/test_client.py index 0bc3c02f..3954254a 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -461,6 +461,20 @@ def test_user_no_data_returns_empty(client, sample_org_data, api_user3: KhojApiU assert response.json() == [] +@pytest.mark.django_db(transaction=True) +def test_chat_with_unauthenticated_user(chat_client_with_auth, api_user2: KhojApiUser): + # Arrange + headers = {"Authorization": f"Bearer {api_user2.token}"} + + # Act + auth_response = chat_client_with_auth.get(f'/api/chat?q="Hello!"&stream=true', headers=headers) + no_auth_response = chat_client_with_auth.get(f'/api/chat?q="Hello!"&stream=true') + + # Assert + assert auth_response.status_code == 200 + assert no_auth_response.status_code == 403 + + def get_sample_files_data(): return [ ("files", ("path/to/filename.org", "* practicing piano", "text/org")), From bb1c1b39d8c495b786ed97dd0d6fed381531eb3d Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Fri, 12 Jan 2024 01:07:32 +0530 Subject: [PATCH 4/8] Move /api/config API controllers into separate module for code modularity --- src/khoj/configure.py | 2 + src/khoj/routers/api.py | 346 +------------------------------ src/khoj/routers/api_config.py | 359 +++++++++++++++++++++++++++++++++ 3 files changed, 363 insertions(+), 344 deletions(-) create mode 100644 src/khoj/routers/api_config.py diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 0f9e5fef..f5b6ec17 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -183,6 +183,7 @@ def configure_routes(app): # Import APIs here to setup search types before while configuring server from khoj.routers.api import api from khoj.routers.api_beta import api_beta + from khoj.routers.api_config import api_config from khoj.routers.auth import auth_router from khoj.routers.indexer import indexer from khoj.routers.subscription import subscription_router @@ -190,6 +191,7 @@ def configure_routes(app): app.include_router(api, prefix="/api") 
app.include_router(api_beta, prefix="/api/beta") + app.include_router(api_config, prefix="/api/config") app.include_router(indexer, prefix="/api/v1/index") if state.billing_enabled: logger.info("💳 Enabled Billing") diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 190fef19..fb125f1d 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -14,24 +14,12 @@ from fastapi.responses import Response, StreamingResponse from starlette.authentication import requires from khoj.configure import configure_server -from khoj.database import adapters from khoj.database.adapters import ( ConversationAdapters, EntryAdapters, get_user_search_model_or_default, ) -from khoj.database.models import ChatModelOptions -from khoj.database.models import Entry as DbEntry -from khoj.database.models import ( - GithubConfig, - KhojUser, - LocalMarkdownConfig, - LocalOrgConfig, - LocalPdfConfig, - LocalPlaintextConfig, - NotionConfig, - SpeechToTextModelOptions, -) +from khoj.database.models import ChatModelOptions, KhojUser, SpeechToTextModelOptions from khoj.processor.conversation.offline.chat_model import extract_questions_offline from khoj.processor.conversation.offline.whisper import transcribe_audio_offline from khoj.processor.conversation.openai.gpt import extract_questions @@ -64,13 +52,7 @@ from khoj.utils.helpers import ( is_none_or_empty, timer, ) -from khoj.utils.rawconfig import ( - FullConfig, - GithubContentConfig, - NotionContentConfig, - SearchConfig, - SearchResponse, -) +from khoj.utils.rawconfig import SearchResponse from khoj.utils.state import SearchType # Initialize Router @@ -79,330 +61,6 @@ logger = logging.getLogger(__name__) conversation_command_rate_limiter = ConversationCommandRateLimiter(trial_rate_limit=5, subscribed_rate_limit=100) -def map_config_to_object(content_source: str): - if content_source == DbEntry.EntrySource.GITHUB: - return GithubConfig - if content_source == DbEntry.EntrySource.GITHUB: - return NotionConfig - if 
content_source == DbEntry.EntrySource.COMPUTER: - return "Computer" - - -async def map_config_to_db(config: FullConfig, user: KhojUser): - if config.content_type: - if config.content_type.org: - await LocalOrgConfig.objects.filter(user=user).adelete() - await LocalOrgConfig.objects.acreate( - input_files=config.content_type.org.input_files, - input_filter=config.content_type.org.input_filter, - index_heading_entries=config.content_type.org.index_heading_entries, - user=user, - ) - if config.content_type.markdown: - await LocalMarkdownConfig.objects.filter(user=user).adelete() - await LocalMarkdownConfig.objects.acreate( - input_files=config.content_type.markdown.input_files, - input_filter=config.content_type.markdown.input_filter, - index_heading_entries=config.content_type.markdown.index_heading_entries, - user=user, - ) - if config.content_type.pdf: - await LocalPdfConfig.objects.filter(user=user).adelete() - await LocalPdfConfig.objects.acreate( - input_files=config.content_type.pdf.input_files, - input_filter=config.content_type.pdf.input_filter, - index_heading_entries=config.content_type.pdf.index_heading_entries, - user=user, - ) - if config.content_type.plaintext: - await LocalPlaintextConfig.objects.filter(user=user).adelete() - await LocalPlaintextConfig.objects.acreate( - input_files=config.content_type.plaintext.input_files, - input_filter=config.content_type.plaintext.input_filter, - index_heading_entries=config.content_type.plaintext.index_heading_entries, - user=user, - ) - if config.content_type.github: - await adapters.set_user_github_config( - user=user, - pat_token=config.content_type.github.pat_token, - repos=config.content_type.github.repos, - ) - if config.content_type.notion: - await adapters.set_notion_config( - user=user, - token=config.content_type.notion.token, - ) - - -def _initialize_config(): - if state.config is None: - state.config = FullConfig() - state.config.search_type = 
SearchConfig.model_validate(constants.default_config["search-type"]) - - -@api.get("/config/data", response_model=FullConfig) -@requires(["authenticated"]) -def get_config_data(request: Request): - user = request.user.object - EntryAdapters.get_unique_file_types(user) - - return state.config - - -@api.post("/config/data") -@requires(["authenticated"]) -async def set_config_data( - request: Request, - updated_config: FullConfig, - client: Optional[str] = None, -): - user = request.user.object - await map_config_to_db(updated_config, user) - - configuration_update_metadata = {} - - enabled_content = await sync_to_async(EntryAdapters.get_unique_file_types)(user) - - if state.config.content_type is not None: - configuration_update_metadata["github"] = "github" in enabled_content - configuration_update_metadata["notion"] = "notion" in enabled_content - configuration_update_metadata["org"] = "org" in enabled_content - configuration_update_metadata["pdf"] = "pdf" in enabled_content - configuration_update_metadata["markdown"] = "markdown" in enabled_content - - if state.config.processor is not None: - configuration_update_metadata["conversation_processor"] = state.config.processor.conversation is not None - - update_telemetry_state( - request=request, - telemetry_type="api", - api="set_config", - client=client, - metadata=configuration_update_metadata, - ) - return state.config - - -@api.post("/config/data/content-source/github", status_code=200) -@requires(["authenticated"]) -async def set_content_config_github_data( - request: Request, - updated_config: Union[GithubContentConfig, None], - client: Optional[str] = None, -): - _initialize_config() - - user = request.user.object - - try: - await adapters.set_user_github_config( - user=user, - pat_token=updated_config.pat_token, - repos=updated_config.repos, - ) - except Exception as e: - logger.error(e, exc_info=True) - raise HTTPException(status_code=500, detail="Failed to set Github config") - - update_telemetry_state( - 
request=request, - telemetry_type="api", - api="set_content_config", - client=client, - metadata={"content_type": "github"}, - ) - - return {"status": "ok"} - - -@api.post("/config/data/content-source/notion", status_code=200) -@requires(["authenticated"]) -async def set_content_config_notion_data( - request: Request, - updated_config: Union[NotionContentConfig, None], - client: Optional[str] = None, -): - _initialize_config() - - user = request.user.object - - try: - await adapters.set_notion_config( - user=user, - token=updated_config.token, - ) - except Exception as e: - logger.error(e, exc_info=True) - raise HTTPException(status_code=500, detail="Failed to set Github config") - - update_telemetry_state( - request=request, - telemetry_type="api", - api="set_content_config", - client=client, - metadata={"content_type": "notion"}, - ) - - return {"status": "ok"} - - -@api.delete("/config/data/content-source/{content_source}", status_code=200) -@requires(["authenticated"]) -async def remove_content_source_data( - request: Request, - content_source: str, - client: Optional[str] = None, -): - user = request.user.object - - update_telemetry_state( - request=request, - telemetry_type="api", - api="delete_content_config", - client=client, - metadata={"content_source": content_source}, - ) - - content_object = map_config_to_object(content_source) - if content_object is None: - raise ValueError(f"Invalid content source: {content_source}") - elif content_object != "Computer": - await content_object.objects.filter(user=user).adelete() - await sync_to_async(EntryAdapters.delete_all_entries)(user, content_source) - - enabled_content = await sync_to_async(EntryAdapters.get_unique_file_types)(user) - return {"status": "ok"} - - -@api.delete("/config/data/file", status_code=200) -@requires(["authenticated"]) -async def remove_file_data( - request: Request, - filename: str, - client: Optional[str] = None, -): - user = request.user.object - - update_telemetry_state( - 
request=request, - telemetry_type="api", - api="delete_file", - client=client, - ) - - await EntryAdapters.adelete_entry_by_file(user, filename) - - return {"status": "ok"} - - -@api.get("/config/data/{content_source}", response_model=List[str]) -@requires(["authenticated"]) -async def get_all_filenames( - request: Request, - content_source: str, - client: Optional[str] = None, -): - user = request.user.object - - update_telemetry_state( - request=request, - telemetry_type="api", - api="get_all_filenames", - client=client, - ) - - return await sync_to_async(list)(EntryAdapters.aget_all_filenames_by_source(user, content_source)) # type: ignore[call-arg] - - -@api.post("/config/data/conversation/model", status_code=200) -@requires(["authenticated"]) -async def update_chat_model( - request: Request, - id: str, - client: Optional[str] = None, -): - user = request.user.object - - new_config = await ConversationAdapters.aset_user_conversation_processor(user, int(id)) - - update_telemetry_state( - request=request, - telemetry_type="api", - api="set_conversation_chat_model", - client=client, - metadata={"processor_conversation_type": "conversation"}, - ) - - if new_config is None: - return {"status": "error", "message": "Model not found"} - - return {"status": "ok"} - - -@api.post("/config/data/search/model", status_code=200) -@requires(["authenticated"]) -async def update_search_model( - request: Request, - id: str, - client: Optional[str] = None, -): - user = request.user.object - - new_config = await adapters.aset_user_search_model(user, int(id)) - - if new_config is None: - return {"status": "error", "message": "Model not found"} - else: - update_telemetry_state( - request=request, - telemetry_type="api", - api="set_search_model", - client=client, - metadata={"search_model": new_config.setting.name}, - ) - - return {"status": "ok"} - - -# Create Routes -@api.get("/config/data/default") -def get_default_config_data(): - return constants.empty_config - - 
-@api.get("/config/index/size", response_model=Dict[str, int]) -@requires(["authenticated"]) -async def get_indexed_data_size(request: Request, common: CommonQueryParams): - user = request.user.object - indexed_data_size_in_mb = await sync_to_async(EntryAdapters.get_size_of_indexed_data_in_mb)(user) - return Response( - content=json.dumps({"indexed_data_size_in_mb": math.ceil(indexed_data_size_in_mb)}), - media_type="application/json", - status_code=200, - ) - - -@api.get("/config/types", response_model=List[str]) -@requires(["authenticated"]) -def get_config_types( - request: Request, -): - user = request.user.object - enabled_file_types = EntryAdapters.get_unique_file_types(user) - configured_content_types = list(enabled_file_types) - - if state.config and state.config.content_type: - for ctype in state.config.content_type.dict(exclude_none=True): - configured_content_types.append(ctype) - - return [ - search_type.value - for search_type in SearchType - if (search_type.value in configured_content_types) or search_type == SearchType.All - ] - - @api.get("/search", response_model=List[SearchResponse]) @requires(["authenticated"]) async def search( diff --git a/src/khoj/routers/api_config.py b/src/khoj/routers/api_config.py new file mode 100644 index 00000000..f7e4575e --- /dev/null +++ b/src/khoj/routers/api_config.py @@ -0,0 +1,359 @@ +import json +import logging +import math +from typing import Dict, List, Optional, Union + +from asgiref.sync import sync_to_async +from fastapi import APIRouter, HTTPException, Request +from fastapi.requests import Request +from fastapi.responses import Response +from starlette.authentication import requires + +from khoj.database import adapters +from khoj.database.adapters import ConversationAdapters, EntryAdapters +from khoj.database.models import Entry as DbEntry +from khoj.database.models import ( + GithubConfig, + KhojUser, + LocalMarkdownConfig, + LocalOrgConfig, + LocalPdfConfig, + LocalPlaintextConfig, + NotionConfig, +) 
+from khoj.routers.helpers import CommonQueryParams, update_telemetry_state +from khoj.utils import constants, state +from khoj.utils.rawconfig import ( + FullConfig, + GithubContentConfig, + NotionContentConfig, + SearchConfig, +) +from khoj.utils.state import SearchType + +api_config = APIRouter() +logger = logging.getLogger(__name__) + + +def map_config_to_object(content_source: str): + if content_source == DbEntry.EntrySource.GITHUB: + return GithubConfig + if content_source == DbEntry.EntrySource.GITHUB: + return NotionConfig + if content_source == DbEntry.EntrySource.COMPUTER: + return "Computer" + + +async def map_config_to_db(config: FullConfig, user: KhojUser): + if config.content_type: + if config.content_type.org: + await LocalOrgConfig.objects.filter(user=user).adelete() + await LocalOrgConfig.objects.acreate( + input_files=config.content_type.org.input_files, + input_filter=config.content_type.org.input_filter, + index_heading_entries=config.content_type.org.index_heading_entries, + user=user, + ) + if config.content_type.markdown: + await LocalMarkdownConfig.objects.filter(user=user).adelete() + await LocalMarkdownConfig.objects.acreate( + input_files=config.content_type.markdown.input_files, + input_filter=config.content_type.markdown.input_filter, + index_heading_entries=config.content_type.markdown.index_heading_entries, + user=user, + ) + if config.content_type.pdf: + await LocalPdfConfig.objects.filter(user=user).adelete() + await LocalPdfConfig.objects.acreate( + input_files=config.content_type.pdf.input_files, + input_filter=config.content_type.pdf.input_filter, + index_heading_entries=config.content_type.pdf.index_heading_entries, + user=user, + ) + if config.content_type.plaintext: + await LocalPlaintextConfig.objects.filter(user=user).adelete() + await LocalPlaintextConfig.objects.acreate( + input_files=config.content_type.plaintext.input_files, + input_filter=config.content_type.plaintext.input_filter, + 
index_heading_entries=config.content_type.plaintext.index_heading_entries, + user=user, + ) + if config.content_type.github: + await adapters.set_user_github_config( + user=user, + pat_token=config.content_type.github.pat_token, + repos=config.content_type.github.repos, + ) + if config.content_type.notion: + await adapters.set_notion_config( + user=user, + token=config.content_type.notion.token, + ) + + +def _initialize_config(): + if state.config is None: + state.config = FullConfig() + state.config.search_type = SearchConfig.model_validate(constants.default_config["search-type"]) + + +@api_config.get("/data", response_model=FullConfig) +@requires(["authenticated"]) +def get_config_data(request: Request): + user = request.user.object + EntryAdapters.get_unique_file_types(user) + + return state.config + + +@api_config.post("/data") +@requires(["authenticated"]) +async def set_config_data( + request: Request, + updated_config: FullConfig, + client: Optional[str] = None, +): + user = request.user.object + await map_config_to_db(updated_config, user) + + configuration_update_metadata = {} + + enabled_content = await sync_to_async(EntryAdapters.get_unique_file_types)(user) + + if state.config.content_type is not None: + configuration_update_metadata["github"] = "github" in enabled_content + configuration_update_metadata["notion"] = "notion" in enabled_content + configuration_update_metadata["org"] = "org" in enabled_content + configuration_update_metadata["pdf"] = "pdf" in enabled_content + configuration_update_metadata["markdown"] = "markdown" in enabled_content + + if state.config.processor is not None: + configuration_update_metadata["conversation_processor"] = state.config.processor.conversation is not None + + update_telemetry_state( + request=request, + telemetry_type="api", + api="set_config", + client=client, + metadata=configuration_update_metadata, + ) + return state.config + + +@api_config.post("/data/content-source/github", status_code=200) 
+@requires(["authenticated"]) +async def set_content_config_github_data( + request: Request, + updated_config: Union[GithubContentConfig, None], + client: Optional[str] = None, +): + _initialize_config() + + user = request.user.object + + try: + await adapters.set_user_github_config( + user=user, + pat_token=updated_config.pat_token, + repos=updated_config.repos, + ) + except Exception as e: + logger.error(e, exc_info=True) + raise HTTPException(status_code=500, detail="Failed to set Github config") + + update_telemetry_state( + request=request, + telemetry_type="api", + api="set_content_config", + client=client, + metadata={"content_type": "github"}, + ) + + return {"status": "ok"} + + +@api_config.post("/data/content-source/notion", status_code=200) +@requires(["authenticated"]) +async def set_content_config_notion_data( + request: Request, + updated_config: Union[NotionContentConfig, None], + client: Optional[str] = None, +): + _initialize_config() + + user = request.user.object + + try: + await adapters.set_notion_config( + user=user, + token=updated_config.token, + ) + except Exception as e: + logger.error(e, exc_info=True) + raise HTTPException(status_code=500, detail="Failed to set Github config") + + update_telemetry_state( + request=request, + telemetry_type="api", + api="set_content_config", + client=client, + metadata={"content_type": "notion"}, + ) + + return {"status": "ok"} + + +@api_config.delete("/data/content-source/{content_source}", status_code=200) +@requires(["authenticated"]) +async def remove_content_source_data( + request: Request, + content_source: str, + client: Optional[str] = None, +): + user = request.user.object + + update_telemetry_state( + request=request, + telemetry_type="api", + api="delete_content_config", + client=client, + metadata={"content_source": content_source}, + ) + + content_object = map_config_to_object(content_source) + if content_object is None: + raise ValueError(f"Invalid content source: {content_source}") + elif 
content_object != "Computer": + await content_object.objects.filter(user=user).adelete() + await sync_to_async(EntryAdapters.delete_all_entries)(user, content_source) + + enabled_content = await sync_to_async(EntryAdapters.get_unique_file_types)(user) + return {"status": "ok"} + + +@api_config.delete("/data/file", status_code=200) +@requires(["authenticated"]) +async def remove_file_data( + request: Request, + filename: str, + client: Optional[str] = None, +): + user = request.user.object + + update_telemetry_state( + request=request, + telemetry_type="api", + api="delete_file", + client=client, + ) + + await EntryAdapters.adelete_entry_by_file(user, filename) + + return {"status": "ok"} + + +@api_config.get("/data/{content_source}", response_model=List[str]) +@requires(["authenticated"]) +async def get_all_filenames( + request: Request, + content_source: str, + client: Optional[str] = None, +): + user = request.user.object + + update_telemetry_state( + request=request, + telemetry_type="api", + api="get_all_filenames", + client=client, + ) + + return await sync_to_async(list)(EntryAdapters.aget_all_filenames_by_source(user, content_source)) # type: ignore[call-arg] + + +@api_config.post("/data/conversation/model", status_code=200) +@requires(["authenticated"]) +async def update_chat_model( + request: Request, + id: str, + client: Optional[str] = None, +): + user = request.user.object + + new_config = await ConversationAdapters.aset_user_conversation_processor(user, int(id)) + + update_telemetry_state( + request=request, + telemetry_type="api", + api="set_conversation_chat_model", + client=client, + metadata={"processor_conversation_type": "conversation"}, + ) + + if new_config is None: + return {"status": "error", "message": "Model not found"} + + return {"status": "ok"} + + +@api_config.post("/data/search/model", status_code=200) +@requires(["authenticated"]) +async def update_search_model( + request: Request, + id: str, + client: Optional[str] = None, +): + user 
= request.user.object + + new_config = await adapters.aset_user_search_model(user, int(id)) + + if new_config is None: + return {"status": "error", "message": "Model not found"} + else: + update_telemetry_state( + request=request, + telemetry_type="api", + api="set_search_model", + client=client, + metadata={"search_model": new_config.setting.name}, + ) + + return {"status": "ok"} + + +# Create Routes +@api_config.get("/data/default") +def get_default_config_data(): + return constants.empty_config + + +@api_config.get("/index/size", response_model=Dict[str, int]) +@requires(["authenticated"]) +async def get_indexed_data_size(request: Request, common: CommonQueryParams): + user = request.user.object + indexed_data_size_in_mb = await sync_to_async(EntryAdapters.get_size_of_indexed_data_in_mb)(user) + return Response( + content=json.dumps({"indexed_data_size_in_mb": math.ceil(indexed_data_size_in_mb)}), + media_type="application/json", + status_code=200, + ) + + +@api_config.get("/types", response_model=List[str]) +@requires(["authenticated"]) +def get_config_types( + request: Request, +): + user = request.user.object + enabled_file_types = EntryAdapters.get_unique_file_types(user) + configured_content_types = list(enabled_file_types) + + if state.config and state.config.content_type: + for ctype in state.config.content_type.dict(exclude_none=True): + configured_content_types.append(ctype) + + return [ + search_type.value + for search_type in SearchType + if (search_type.value in configured_content_types) or search_type == SearchType.All + ] From 5f97357fe02581cb60e2309edddb29ee72c6e345 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Fri, 12 Jan 2024 01:08:51 +0530 Subject: [PATCH 5/8] Delete unused /api/beta API endpoint --- src/khoj/configure.py | 2 -- src/khoj/routers/api_beta.py | 7 ------- 2 files changed, 9 deletions(-) delete mode 100644 src/khoj/routers/api_beta.py diff --git a/src/khoj/configure.py b/src/khoj/configure.py index f5b6ec17..a67b1e3a 
100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -182,7 +182,6 @@ def initialize_content(regenerate: bool, search_type: Optional[SearchType] = Non def configure_routes(app): # Import APIs here to setup search types before while configuring server from khoj.routers.api import api - from khoj.routers.api_beta import api_beta from khoj.routers.api_config import api_config from khoj.routers.auth import auth_router from khoj.routers.indexer import indexer @@ -190,7 +189,6 @@ def configure_routes(app): from khoj.routers.web_client import web_client app.include_router(api, prefix="/api") - app.include_router(api_beta, prefix="/api/beta") app.include_router(api_config, prefix="/api/config") app.include_router(indexer, prefix="/api/v1/index") if state.billing_enabled: diff --git a/src/khoj/routers/api_beta.py b/src/khoj/routers/api_beta.py deleted file mode 100644 index 0d87ec62..00000000 --- a/src/khoj/routers/api_beta.py +++ /dev/null @@ -1,7 +0,0 @@ -import logging - -from fastapi import APIRouter - -# Initialize Router -api_beta = APIRouter() -logger = logging.getLogger(__name__) From 7dfbcd2e5a12a75daf9bcceffa9eb2763ac0654a Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Fri, 12 Jan 2024 01:32:46 +0530 Subject: [PATCH 6/8] Handle subscribe renew date, langchain, pydantic & logger.warn warnings - Ensure langchain less than 0.2.0 is used, to prevent breaking ChatOpenAI, PyMuPDF usage due to their deprecation after 0.2.0 - Set subscription renewal date to a timezone aware datetime - Use logger.warning instead of logger.warn as latter is deprecated - Use `model_dump' not deprecated dict to get all configured content_types --- pyproject.toml | 2 +- src/khoj/configure.py | 5 ++++- src/khoj/routers/api_config.py | 2 +- src/khoj/routers/auth.py | 2 +- src/khoj/routers/subscription.py | 2 +- tests/helpers.py | 4 +++- 6 files changed, 11 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 693415d4..caf3e410 100644 --- 
a/pyproject.toml +++ b/pyproject.toml @@ -55,7 +55,7 @@ dependencies = [ "torch == 2.0.1", "uvicorn == 0.17.6", "aiohttp ~= 3.9.0", - "langchain >= 0.0.331", + "langchain <= 0.2.0", "requests >= 2.26.0", "bs4 >= 0.0.1", "anyio == 3.7.1", diff --git a/src/khoj/configure.py b/src/khoj/configure.py index a67b1e3a..786eccb3 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -1,12 +1,14 @@ import json import logging import os +from datetime import datetime from enum import Enum from typing import Optional import openai import requests import schedule +from django.utils.timezone import make_aware from starlette.authentication import ( AuthCredentials, AuthenticationBackend, @@ -59,7 +61,8 @@ class UserAuthenticationBackend(AuthenticationBackend): email="default@example.com", password="default", ) - Subscription.objects.create(user=default_user, type="standard", renewal_date="2100-04-01") + renewal_date = make_aware(datetime.strptime("2100-04-01", "%Y-%m-%d")) + Subscription.objects.create(user=default_user, type="standard", renewal_date=renewal_date) async def authenticate(self, request: HTTPConnection): current_user = request.session.get("user") diff --git a/src/khoj/routers/api_config.py b/src/khoj/routers/api_config.py index f7e4575e..fb53b638 100644 --- a/src/khoj/routers/api_config.py +++ b/src/khoj/routers/api_config.py @@ -349,7 +349,7 @@ def get_config_types( configured_content_types = list(enabled_file_types) if state.config and state.config.content_type: - for ctype in state.config.content_type.dict(exclude_none=True): + for ctype in state.config.content_type.model_dump(exclude_none=True): configured_content_types.append(ctype) return [ diff --git a/src/khoj/routers/auth.py b/src/khoj/routers/auth.py index c7296316..02cd073d 100644 --- a/src/khoj/routers/auth.py +++ b/src/khoj/routers/auth.py @@ -26,7 +26,7 @@ logger = logging.getLogger(__name__) auth_router = APIRouter() if not state.anonymous_mode and not (os.environ.get("GOOGLE_CLIENT_ID") and 
os.environ.get("GOOGLE_CLIENT_SECRET")): - logger.warn( + logger.warning( "🚨 Use --anonymous-mode flag to disable Google OAuth or set GOOGLE_CLIENT_ID, GOOGLE_CLIENT_SECRET environment variables to enable it" ) else: diff --git a/src/khoj/routers/subscription.py b/src/khoj/routers/subscription.py index 580fb225..c4542c1c 100644 --- a/src/khoj/routers/subscription.py +++ b/src/khoj/routers/subscription.py @@ -37,7 +37,7 @@ async def subscribe(request: Request): "customer.subscription.updated", "customer.subscription.deleted", }: - logger.warn(f"Unhandled Stripe event type: {event['type']}") + logger.warning(f"Unhandled Stripe event type: {event['type']}") return {"success": False} # Retrieve the customer's details diff --git a/tests/helpers.py b/tests/helpers.py index cc0f1ed3..321e08cf 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -1,6 +1,8 @@ import os +from datetime import datetime import factory +from django.utils.timezone import make_aware from khoj.database.models import ( ChatModelOptions, @@ -90,4 +92,4 @@ class SubscriptionFactory(factory.django.DjangoModelFactory): user = factory.SubFactory(UserFactory) type = "standard" is_recurring = False - renewal_date = "2100-04-01" + renewal_date = make_aware(datetime.strptime("2100-04-01", "%Y-%m-%d")) From 8917228dbbc941e50847924792f62babffaba191 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Tue, 16 Jan 2024 18:15:06 +0530 Subject: [PATCH 7/8] Remove unused, deprecated /api/config/data API endpoints - Use /api/health for server up check instead of api/config/default - Remove unused `khoj--post-new-config' method - Remove the now unused /config/data GET, POST API endpoints --- src/interface/emacs/khoj.el | 18 +------------ src/khoj/routers/api_config.py | 48 ---------------------------------- 2 files changed, 1 insertion(+), 65 deletions(-) diff --git a/src/interface/emacs/khoj.el b/src/interface/emacs/khoj.el index 52cfba88..622fd209 100644 --- a/src/interface/emacs/khoj.el +++ 
b/src/interface/emacs/khoj.el @@ -348,7 +348,7 @@ Auto invokes setup steps on calling main entrypoint." t ;; else general check via ping to khoj-server-url (if (ignore-errors - (url-retrieve-synchronously (format "%s/api/config/data/default" khoj-server-url))) + (url-retrieve-synchronously (format "%s/api/health" khoj-server-url))) ;; Successful ping to non-emacs khoj server indicates it is started and ready. ;; So update ready state tracker variable (and implicitly return true for started) (setq khoj--server-ready? t) @@ -603,22 +603,6 @@ Use `BOUNDARY' to separate files. This is sent to Khoj server as a POST request. ;; -------------- ;; Query Khoj API ;; -------------- - -(defun khoj--post-new-config (config) - "Configure khoj server with provided CONFIG." - ;; POST provided config to khoj server - (let ((url-request-method "POST") - (url-request-extra-headers `(("Content-Type" . "application/json") - ("Authorization" . ,(format "Bearer %s" khoj-api-key)))) - (url-request-data (encode-coding-string (json-encode-alist config) 'utf-8)) - (config-url (format "%s/api/config/data" khoj-server-url))) - (with-current-buffer (url-retrieve-synchronously config-url) - (buffer-string))) - ;; Update index on khoj server after configuration update - (let ((khoj--server-ready? nil) - (url-request-extra-headers `(("Authorization" . ,(format "\"Bearer %s\"" khoj-api-key))))) - (url-retrieve (format "%s/api/update?client=emacs" khoj-server-url) #'identity))) - (defun khoj--get-enabled-content-types () "Get content types enabled for search from API." 
(let ((config-url (format "%s/api/config/types" khoj-server-url)) diff --git a/src/khoj/routers/api_config.py b/src/khoj/routers/api_config.py index fb53b638..6169eb4a 100644 --- a/src/khoj/routers/api_config.py +++ b/src/khoj/routers/api_config.py @@ -97,49 +97,6 @@ def _initialize_config(): state.config.search_type = SearchConfig.model_validate(constants.default_config["search-type"]) -@api_config.get("/data", response_model=FullConfig) -@requires(["authenticated"]) -def get_config_data(request: Request): - user = request.user.object - EntryAdapters.get_unique_file_types(user) - - return state.config - - -@api_config.post("/data") -@requires(["authenticated"]) -async def set_config_data( - request: Request, - updated_config: FullConfig, - client: Optional[str] = None, -): - user = request.user.object - await map_config_to_db(updated_config, user) - - configuration_update_metadata = {} - - enabled_content = await sync_to_async(EntryAdapters.get_unique_file_types)(user) - - if state.config.content_type is not None: - configuration_update_metadata["github"] = "github" in enabled_content - configuration_update_metadata["notion"] = "notion" in enabled_content - configuration_update_metadata["org"] = "org" in enabled_content - configuration_update_metadata["pdf"] = "pdf" in enabled_content - configuration_update_metadata["markdown"] = "markdown" in enabled_content - - if state.config.processor is not None: - configuration_update_metadata["conversation_processor"] = state.config.processor.conversation is not None - - update_telemetry_state( - request=request, - telemetry_type="api", - api="set_config", - client=client, - metadata=configuration_update_metadata, - ) - return state.config - - @api_config.post("/data/content-source/github", status_code=200) @requires(["authenticated"]) async def set_content_config_github_data( @@ -322,11 +279,6 @@ async def update_search_model( # Create Routes -@api_config.get("/data/default") -def get_default_config_data(): - return 
constants.empty_config - - @api_config.get("/index/size", response_model=Dict[str, int]) @requires(["authenticated"]) async def get_indexed_data_size(request: Request, common: CommonQueryParams): From d26a4ffcea29c4145d3dcd4bdb122b04e8c3f854 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Wed, 17 Jan 2024 00:36:03 +0530 Subject: [PATCH 8/8] Only run the OpenAI chat client, /online test when API keys are set --- tests/test_client.py | 2 ++ tests/test_openai_chat_director.py | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_client.py b/tests/test_client.py index 3954254a..bfa9b578 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,4 +1,5 @@ # Standard Modules +import os from io import BytesIO from urllib.parse import quote @@ -461,6 +462,7 @@ def test_user_no_data_returns_empty(client, sample_org_data, api_user3: KhojApiU assert response.json() == [] +@pytest.mark.skipif(os.getenv("OPENAI_API_KEY") is None, reason="requires OPENAI_API_KEY") @pytest.mark.django_db(transaction=True) def test_chat_with_unauthenticated_user(chat_client_with_auth, api_user2: KhojApiUser): # Arrange diff --git a/tests/test_openai_chat_director.py b/tests/test_openai_chat_director.py index 181080ef..a1ece5a7 100644 --- a/tests/test_openai_chat_director.py +++ b/tests/test_openai_chat_director.py @@ -53,6 +53,7 @@ def test_chat_with_no_chat_history_or_retrieved_content(chat_client): # ---------------------------------------------------------------------------------------------------- +@pytest.mark.skipif(os.getenv("SERPER_DEV_API_KEY") is None, reason="requires SERPER_DEV_API_KEY") @pytest.mark.chatquality @pytest.mark.django_db(transaction=True) def test_chat_with_online_content(chat_client): @@ -65,7 +66,7 @@ def test_chat_with_online_content(chat_client): response_message = response_message.split("### compiled references")[0] # Assert - expected_responses = ["http://www.paulgraham.com/greatwork.html", "Please set your 
SERPER_DEV_API_KEY"] + expected_responses = ["http://www.paulgraham.com/greatwork.html"] assert response.status_code == 200 assert any([expected_response in response_message for expected_response in expected_responses]), ( "Expected links or serper not setup in response but got: " + response_message