diff --git a/src/app/README.md b/src/app/README.md index 7a93ee8b..cbfe5356 100644 --- a/src/app/README.md +++ b/src/app/README.md @@ -37,6 +37,12 @@ make install # may need sudo ``` 3. Create a database +### Create the khoj database + +```bash +createdb khoj -U postgres +``` + ### Make migrations This command will create the migrations for the database app. This command should be run whenever a new model is added to the database app or an existing model is modified (updated or deleted). diff --git a/src/app/settings.py b/src/app/settings.py index cfd7cd3c..9a8b427b 100644 --- a/src/app/settings.py +++ b/src/app/settings.py @@ -14,7 +14,7 @@ from pathlib import Path import os # Build paths inside the project like this: BASE_DIR / 'subdir'. -BASE_DIR = Path(__file__).resolve().parent.parent +BASE_DIR = Path(__file__).resolve().parent.parent.parent # Quick-start development settings - unsuitable for production @@ -123,8 +123,8 @@ USE_TZ = True # Static files (CSS, JavaScript, Images) # https://docs.djangoproject.com/en/4.2/howto/static-files/ -STATIC_ROOT = os.path.join(BASE_DIR, "static") -STATICFILES_DIRS = [os.path.join(BASE_DIR, "khoj/interface/web")] +STATIC_ROOT = BASE_DIR / "static" +STATICFILES_DIRS = [BASE_DIR / "src/khoj/interface/web"] STATIC_URL = "/static/" # Default primary key field type diff --git a/src/app/urls.py b/src/app/urls.py index fbd67a4e..39b4b1ef 100644 --- a/src/app/urls.py +++ b/src/app/urls.py @@ -15,7 +15,7 @@ Including another URLconf 2. Add a URL to urlpatterns: path('blog/', include('blog.urls')) """ from django.contrib import admin -from django.urls import path, include +from django.urls import path from django.contrib.staticfiles.urls import staticfiles_urlpatterns urlpatterns = [ diff --git a/src/database/adapters/__init__.py b/src/database/adapters/__init__.py index db5e9f77..52debdc4 100644 --- a/src/database/adapters/__init__.py +++ b/src/database/adapters/__init__.py @@ -1,3 +1,4 @@ +import secrets from typing import Type, TypeVar, List from datetime import date @@ -16,6 +17,7 @@ from fastapi import HTTPException from database.models import ( KhojUser, GoogleUser, + KhojApiUser, NotionConfig, GithubConfig, Embeddings, @@ -25,6 +27,7 @@ from database.models import ( OpenAIProcessorConversationConfig, OfflineChatProcessorConversationConfig, ) +from khoj.utils.helpers import generate_random_name from khoj.utils.rawconfig import ( ConversationProcessorConfig as UserConversationProcessorConfig, ) @@ -52,6 +55,25 @@ async def set_notion_config(token: str, user: KhojUser): return notion_config +async def create_khoj_token(user: KhojUser, name=None): + "Create Khoj API key for user" + token = f"kk-{secrets.token_urlsafe(32)}" + name = name or f"{generate_random_name().title()}'s Secret Key" + api_config = await KhojApiUser.objects.acreate(token=token, user=user, name=name) + await api_config.asave() + return api_config + + +def get_khoj_tokens(user: KhojUser): + "Get all Khoj API keys for user" + return list(KhojApiUser.objects.filter(user=user)) + + +async def delete_khoj_token(user: KhojUser, token: str): + "Delete Khoj API Key for user" + await KhojApiUser.objects.filter(token=token, user=user).adelete() + + async def get_or_create_user(token: dict) -> KhojUser: user = await get_user_by_token(token) if not user: diff --git a/src/database/migrations/0009_khojapiuser.py b/src/database/migrations/0009_khojapiuser.py new file mode 100644 index 00000000..86b09ab3 --- /dev/null +++ b/src/database/migrations/0009_khojapiuser.py @@ -0,0 +1,24 @@ +# Generated by Django 4.2.5 on 2023-10-26 17:02 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0008_alter_conversation_conversation_log"), + ] + + operations = [ + migrations.CreateModel( + name="KhojApiUser", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("token", models.CharField(max_length=50, unique=True)), + ("name", models.CharField(max_length=50)), + ("accessed_at", models.DateTimeField(default=None, null=True)), + ("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)), + ], + ), + ] diff --git a/src/database/models/__init__.py b/src/database/models/__init__.py index a9d41e0d..7c9c3822 100644 --- a/src/database/models/__init__.py +++ b/src/database/models/__init__.py @@ -37,6 +37,15 @@ class GoogleUser(models.Model): return self.name +class KhojApiUser(models.Model): + """User issued API tokens to authenticate Khoj clients""" + + user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) + token = models.CharField(max_length=50, unique=True) + name = models.CharField(max_length=50) + accessed_at = models.DateTimeField(null=True, default=None) + + class NotionConfig(BaseModel): token = models.CharField(max_length=200) user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) diff --git a/src/interface/desktop/assets/icons/favicon-20x20.png b/src/interface/desktop/assets/icons/favicon-20x20.png new file mode 100644 index 00000000..1a4ee0be Binary files /dev/null and b/src/interface/desktop/assets/icons/favicon-20x20.png differ diff --git a/src/interface/desktop/assets/icons/key.svg b/src/interface/desktop/assets/icons/key.svg new file mode 100644 index 00000000..437688fb --- /dev/null +++ b/src/interface/desktop/assets/icons/key.svg @@ -0,0 +1,4 @@ + + + + diff --git a/src/interface/desktop/assets/icons/link.svg b/src/interface/desktop/assets/icons/link.svg index ef484368..43852d95 100644 --- a/src/interface/desktop/assets/icons/link.svg +++ b/src/interface/desktop/assets/icons/link.svg @@ -1,5 +1,4 @@ - + - - + diff --git a/src/interface/desktop/chat.html b/src/interface/desktop/chat.html index 21a1a416..9ae3b365 100644 --- a/src/interface/desktop/chat.html +++ b/src/interface/desktop/chat.html @@ -89,6 +89,8 @@ // Generate backend API URL to execute query let url = `${hostURL}/api/chat?q=${encodeURIComponent(query)}&n=${resultsCount}&client=web&stream=true`; + const khojToken = await window.tokenAPI.getToken(); + const headers = { 'Authorization': `Bearer ${khojToken}` }; let chat_body = document.getElementById("chat-body"); let new_response = document.createElement("div"); @@ -113,7 +115,7 @@ chatInput.classList.remove("option-enabled"); // Call specified Khoj API which returns a streamed response of type text/plain - fetch(url) + fetch(url, { headers }) .then(response => { const reader = response.body.getReader(); const decoder = new TextDecoder(); @@ -217,7 +219,10 @@ async function loadChat() { const hostURL = await window.hostURLAPI.getURL(); - fetch(`${hostURL}/api/chat/history?client=web`) + const khojToken = await window.tokenAPI.getToken(); + const headers = { 'Authorization': `Bearer ${khojToken}` }; + + fetch(`${hostURL}/api/chat/history?client=web`, { headers }) .then(response => response.json()) .then(data => { if (data.detail) { @@ -243,7 +248,7 @@ return; }); - fetch(`${hostURL}/api/chat/options`) + fetch(`${hostURL}/api/chat/options`, { headers }) .then(response => response.json()) .then(data => { // Render chat options, if any @@ -272,9 +277,9 @@ diff --git a/src/interface/desktop/config.html b/src/interface/desktop/config.html index 04599bb1..4b79f1a1 100644 --- a/src/interface/desktop/config.html +++ b/src/interface/desktop/config.html @@ -12,66 +12,85 @@ -
- -
- - -
+ +
+ +
-
-
- File -

- Host -

+
+
+
+ Khoj Server URL +

+ Server URL +

+
+
+ +
+
+ Khoj Access Key +

+ Access Key +

+
+
+ +
-
- -
-
- File -

- Files -

+
+
+
+ File +

+ Files + +

+
+
+
+
+
+ - +
-
-
-
-
- -
-
- Folder -

- Folders -

+
+
+
+ Folder +

+ Folders + +

+
+
+
+
+
+ - -
-
-
-
-
- +
+
+
@@ -79,11 +98,10 @@
- -
-
-
+
+ +
+
@@ -93,7 +111,7 @@ body { display: grid; grid-template-columns: 1fr; - grid-template-rows: 1fr auto auto auto minmax(80px, 100%); + grid-template-rows: 1fr auto; font-size: small!important; } body > * { @@ -104,8 +122,7 @@ body { display: grid; grid-template-columns: 1fr min(70vw, 100%) 1fr; - grid-template-rows: 1fr auto auto auto minmax(80px, 100%); - padding-top: 60vw; + grid-template-rows: 80px auto; } body > * { grid-column: 2; @@ -126,11 +143,6 @@ margin: 10px; } - div.page { - padding: 0px; - margin: 0px; - } - svg { transition: transform 0.3s ease-in-out; } @@ -167,18 +179,18 @@ } } - #khoj-host-url { + .card-input { padding: 4px; box-shadow: 0 0 2px 1px rgba(0, 0, 0, 0.2); border: none; + width: 450px; } .card { display: grid; - /* grid-template-rows: repeat(3, 1fr); */ gap: 8px; padding: 24px 16px; - width: 100%; + width: 450px; background: white; border: 1px solid rgb(229, 229, 229); border-radius: 4px; @@ -188,15 +200,15 @@ .section-cards { display: grid; - grid-template-columns: repeat(1, 1fr); gap: 16px; - justify-items: start; + justify-items: center; margin: 0; - width: auto; } - - div.configuration { - width: auto; + .section-action-row { + display: grid; + grid-auto-flow: column; + gap: 16px; + height: fit-content; } .card-title-row { @@ -302,7 +314,6 @@ } div.content-name { - width: 500px; overflow-wrap: break-word; } @@ -347,6 +358,12 @@ background-color: #ffcc00; box-shadow: 0px 3px 0px #f9f5de; } + .sync-force-toggle { + align-content: center; + display: grid; + grid-auto-flow: column; + gap: 4px; + } {% endblock %} diff --git a/src/khoj/main.py b/src/khoj/main.py index 804e71e5..8fe40e76 100644 --- a/src/khoj/main.py +++ b/src/khoj/main.py @@ -98,9 +98,10 @@ def run(): # Mount Django and Static Files app.mount("/django", django_app, name="django") - if not os.path.exists("static"): - os.mkdir("static") - app.mount("/static", StaticFiles(directory="static"), name="static") + static_dir = "static" + if not os.path.exists(static_dir): + os.mkdir(static_dir) + app.mount(f"/{static_dir}", StaticFiles(directory=static_dir), name=static_dir) # Configure Middleware configure_middleware(app) diff --git a/src/khoj/processor/conversation/gpt4all/utils.py b/src/khoj/processor/conversation/gpt4all/utils.py index d5201780..cd9bc9e2 100644 --- a/src/khoj/processor/conversation/gpt4all/utils.py +++ b/src/khoj/processor/conversation/gpt4all/utils.py @@ -6,17 +6,23 @@ logger = logging.getLogger(__name__) def download_model(model_name: str): try: - from gpt4all import GPT4All + import gpt4all except ModuleNotFoundError as e: logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.") raise e - # Use GPU for Chat Model, if available - try: - model = GPT4All(model_name=model_name, device="gpu") - logger.debug("Loaded chat model to GPU.") - except ValueError: - model = GPT4All(model_name=model_name) - logger.debug("Loaded chat model to CPU.") + # Download the chat model + chat_model_config = gpt4all.GPT4All.retrieve_model(model_name=model_name, allow_download=True) - return model + # Decide whether to load model to GPU or CPU + try: + # Check if machine has GPU and GPU has enough free memory to load the chat model + device = "gpu" if gpt4all.pyllmodel.LLModel().list_gpu(chat_model_config["path"]) else "cpu" + except ValueError: + device = "cpu" + + # Now load the downloaded chat model onto appropriate device + chat_model = gpt4all.GPT4All(model_name=model_name, device=device, allow_download=False) + logger.debug(f"Loaded chat model to {device.upper()}.") + + return chat_model diff --git a/src/khoj/processor/embeddings.py b/src/khoj/processor/embeddings.py index f0e2df77..fbcddb67 100644 --- a/src/khoj/processor/embeddings.py +++ b/src/khoj/processor/embeddings.py @@ -1,28 +1,17 @@ from typing import List -import torch from langchain.embeddings import HuggingFaceEmbeddings from sentence_transformers import CrossEncoder +from khoj.utils.helpers import get_device from khoj.utils.rawconfig import SearchResponse class EmbeddingsModel: def __init__(self): - self.model_name = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1" - encode_kwargs = {"normalize_embeddings": True} - # encode_kwargs = {} - - if torch.cuda.is_available(): - # Use CUDA GPU - device = torch.device("cuda:0") - elif torch.backends.mps.is_available(): - # Use Apple M1 Metal Acceleration - device = torch.device("mps") - else: - device = torch.device("cpu") - - model_kwargs = {"device": device} + self.model_name = "thenlper/gte-small" + encode_kwargs = {"normalize_embeddings": True, "show_progress_bar": True} + model_kwargs = {"device": get_device()} self.embeddings_model = HuggingFaceEmbeddings( model_name=self.model_name, encode_kwargs=encode_kwargs, model_kwargs=model_kwargs ) @@ -37,19 +26,7 @@ class EmbeddingsModel: class CrossEncoderModel: def __init__(self): self.model_name = "cross-encoder/ms-marco-MiniLM-L-6-v2" - - if torch.cuda.is_available(): - # Use CUDA GPU - device = torch.device("cuda:0") - - elif torch.backends.mps.is_available(): - # Use Apple M1 Metal Acceleration - device = torch.device("mps") - - else: - device = torch.device("cpu") - - self.cross_encoder_model = CrossEncoder(model_name=self.model_name, device=device) + self.cross_encoder_model = CrossEncoder(model_name=self.model_name, device=get_device()) def predict(self, query, hits: List[SearchResponse]): cross__inp = [[query, hit.additional["compiled"]] for hit in hits] diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index b7ba66b6..4984ea4c 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -126,21 +126,21 @@ if not state.demo: state.config.search_type = SearchConfig.parse_obj(constants.default_config["search-type"]) @api.get("/config/data", response_model=FullConfig) - @requires(["authenticated"], redirect="login_page") + @requires(["authenticated"]) def get_config_data(request: Request): - user = request.user.object if request.user.is_authenticated else None - enabled_content = EmbeddingsAdapters.get_unique_file_types(user) + user = request.user.object + EmbeddingsAdapters.get_unique_file_types(user) return state.config @api.post("/config/data") - @requires(["authenticated"], redirect="login_page") + @requires(["authenticated"]) async def set_config_data( request: Request, updated_config: FullConfig, client: Optional[str] = None, ): - user = request.user.object if request.user.is_authenticated else None + user = request.user.object await map_config_to_db(updated_config, user) configuration_update_metadata = {} @@ -167,7 +167,7 @@ if not state.demo: return state.config @api.post("/config/data/content_type/github", status_code=200) - @requires(["authenticated"], redirect="login_page") + @requires(["authenticated"]) async def set_content_config_github_data( request: Request, updated_config: Union[GithubContentConfig, None], @@ -175,7 +175,7 @@ if not state.demo: ): _initialize_config() - user = request.user.object if request.user.is_authenticated else None + user = request.user.object await adapters.set_user_github_config( user=user, @@ -194,7 +194,7 @@ if not state.demo: return {"status": "ok"} @api.post("/config/data/content_type/notion", status_code=200) - @requires(["authenticated"], redirect="login_page") + @requires(["authenticated"]) async def set_content_config_notion_data( request: Request, updated_config: Union[NotionContentConfig, None], @@ -202,7 +202,7 @@ if not state.demo: ): _initialize_config() - user = request.user.object if request.user.is_authenticated else None + user = request.user.object await adapters.set_notion_config( user=user, @@ -220,13 +220,13 @@ if not state.demo: return {"status": "ok"} @api.post("/delete/config/data/content_type/{content_type}", status_code=200) - @requires(["authenticated"], redirect="login_page") + @requires(["authenticated"]) async def remove_content_config_data( request: Request, content_type: str, client: Optional[str] = None, ): - user = request.user.object if request.user.is_authenticated else None + user = request.user.object update_telemetry_state( request=request, @@ -247,7 +247,7 @@ if not state.demo: return {"status": "ok"} @api.post("/delete/config/data/processor/conversation/openai", status_code=200) - @requires(["authenticated"], redirect="login_page") + @requires(["authenticated"]) async def remove_processor_conversation_config_data( request: Request, client: Optional[str] = None, @@ -267,7 +267,7 @@ if not state.demo: return {"status": "ok"} @api.post("/config/data/content_type/{content_type}", status_code=200) - @requires(["authenticated"], redirect="login_page") + @requires(["authenticated"]) async def set_content_config_data( request: Request, content_type: str, @@ -276,7 +276,7 @@ if not state.demo: ): _initialize_config() - user = request.user.object if request.user.is_authenticated else None + user = request.user.object content_object = map_config_to_object(content_type) await adapters.set_text_content_config(user, content_object, updated_config) @@ -292,7 +292,7 @@ if not state.demo: return {"status": "ok"} @api.post("/config/data/processor/conversation/openai", status_code=200) - @requires(["authenticated"], redirect="login_page") + @requires(["authenticated"]) async def set_processor_openai_config_data( request: Request, updated_config: Union[OpenAIProcessorConfig, None], @@ -315,6 +315,7 @@ if not state.demo: return {"status": "ok"} @api.post("/config/data/processor/conversation/offline_chat", status_code=200) + @requires(["authenticated"]) async def set_processor_enable_offline_chat_config_data( request: Request, enable_offline_chat: bool, @@ -323,24 +324,29 @@ if not state.demo: ): user = request.user.object - if enable_offline_chat: - conversation_config = ConversationProcessorConfig( - offline_chat=OfflineChatProcessorConfig( - enable_offline_chat=enable_offline_chat, - chat_model=offline_chat_model, + try: + if enable_offline_chat: + conversation_config = ConversationProcessorConfig( + offline_chat=OfflineChatProcessorConfig( + enable_offline_chat=enable_offline_chat, + chat_model=offline_chat_model, + ) ) - ) - await sync_to_async(ConversationAdapters.set_conversation_processor_config)(user, conversation_config) + await sync_to_async(ConversationAdapters.set_conversation_processor_config)(user, conversation_config) - offline_chat = await ConversationAdapters.get_offline_chat(user) - chat_model = offline_chat.chat_model - if state.gpt4all_processor_config is None: - state.gpt4all_processor_config = GPT4AllProcessorModel(chat_model=chat_model) + offline_chat = await ConversationAdapters.get_offline_chat(user) + chat_model = offline_chat.chat_model + if state.gpt4all_processor_config is None: + state.gpt4all_processor_config = GPT4AllProcessorModel(chat_model=chat_model) - else: - await sync_to_async(ConversationAdapters.clear_offline_chat_conversation_config)(user) - state.gpt4all_processor_config = None + else: + await sync_to_async(ConversationAdapters.clear_offline_chat_conversation_config)(user) + state.gpt4all_processor_config = None + + except Exception as e: + logger.error(f"Error updating offline chat config: {e}", exc_info=True) + return {"status": "error", "message": str(e)} update_telemetry_state( request=request, @@ -360,11 +366,11 @@ def get_default_config_data(): @api.get("/config/types", response_model=List[str]) -@requires(["authenticated"], redirect="login_page") +@requires(["authenticated"]) def get_config_types( request: Request, ): - user = request.user.object if request.user.is_authenticated else None + user = request.user.object enabled_file_types = EmbeddingsAdapters.get_unique_file_types(user) @@ -382,7 +388,7 @@ def get_config_types( @api.get("/search", response_model=List[SearchResponse]) -@requires(["authenticated"], redirect="login_page") +@requires(["authenticated"]) async def search( q: str, request: Request, @@ -396,7 +402,7 @@ async def search( referer: Optional[str] = Header(None), host: Optional[str] = Header(None), ): - user = request.user.object if request.user.is_authenticated else None + user = request.user.object start_time = time.time() # Run validation checks @@ -513,7 +519,7 @@ async def search( @api.get("/update") -@requires(["authenticated"], redirect="login_page") +@requires(["authenticated"]) def update( request: Request, t: Optional[SearchType] = None, @@ -523,7 +529,7 @@ def update( referer: Optional[str] = Header(None), host: Optional[str] = Header(None), ): - user = request.user.object if request.user.is_authenticated else None + user = request.user.object if not state.config: error_msg = f"๐Ÿšจ Khoj is not configured.\nConfigure it via http://localhost:42110/config, plugins or by editing {state.config_file}." logger.warning(error_msg) @@ -557,7 +563,7 @@ def update( @api.get("/chat/history") -@requires(["authenticated"], redirect="login_page") +@requires(["authenticated"]) def chat_history( request: Request, client: Optional[str] = None, @@ -585,7 +591,7 @@ def chat_history( @api.get("/chat/options", response_class=Response) -@requires(["authenticated"], redirect="login_page") +@requires(["authenticated"]) async def chat_options( request: Request, client: Optional[str] = None, @@ -610,7 +616,7 @@ async def chat_options( @api.get("/chat", response_class=Response) -@requires(["authenticated"], redirect="login_page") +@requires(["authenticated"]) async def chat( request: Request, q: str, diff --git a/src/khoj/routers/auth.py b/src/khoj/routers/auth.py index 8c767d8f..ab1964b8 100644 --- a/src/khoj/routers/auth.py +++ b/src/khoj/routers/auth.py @@ -1,22 +1,29 @@ +# Standard Packages import logging -import json import os +from typing import Optional + +# External Packages from fastapi import APIRouter from starlette.config import Config from starlette.requests import Request from starlette.responses import HTMLResponse, RedirectResponse, Response +from starlette.authentication import requires from authlib.integrations.starlette_client import OAuth, OAuthError from google.oauth2 import id_token from google.auth.transport import requests as google_requests -from database.adapters import get_or_create_user +# Internal Packages +from database.adapters import get_khoj_tokens, get_or_create_user, create_khoj_token, delete_khoj_token +from khoj.utils import state + logger = logging.getLogger(__name__) auth_router = APIRouter() -if not os.environ.get("GOOGLE_CLIENT_ID") or not os.environ.get("GOOGLE_CLIENT_SECRET"): +if not state.anonymous_mode and not (os.environ.get("GOOGLE_CLIENT_ID") and os.environ.get("GOOGLE_CLIENT_SECRET")): logger.info("Please set GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET environment variables to use Google OAuth") else: config = Config(environ=os.environ) @@ -39,6 +46,31 @@ async def login(request: Request): return await oauth.google.authorize_redirect(request, redirect_uri) +@auth_router.post("/token") +@requires(["authenticated"], redirect="login_page") +async def generate_token(request: Request, token_name: Optional[str] = None) -> str: + "Generate API token for given user" + if token_name: + return await create_khoj_token(user=request.user.object, name=token_name) + else: + return await create_khoj_token(user=request.user.object) + + +@auth_router.get("/token") +@requires(["authenticated"], redirect="login_page") +def get_tokens(request: Request): + "Get API tokens enabled for given user" + tokens = get_khoj_tokens(user=request.user.object) + return tokens + + +@auth_router.delete("/token") +@requires(["authenticated"], redirect="login_page") +async def delete_token(request: Request, token: str) -> str: + "Delete API token for given user" + return await delete_khoj_token(user=request.user.object, token=token) + + @auth_router.post("/redirect") async def auth(request: Request): form = await request.form() diff --git a/src/khoj/routers/indexer.py b/src/khoj/routers/indexer.py index 1125e653..e7df65a2 100644 --- a/src/khoj/routers/indexer.py +++ b/src/khoj/routers/indexer.py @@ -4,9 +4,9 @@ from typing import Optional, Union, Dict import asyncio # External Packages -from fastapi import APIRouter, HTTPException, Header, Request, Response, UploadFile +from fastapi import APIRouter, Header, Request, Response, UploadFile from pydantic import BaseModel -from khoj.routers.helpers import update_telemetry_state +from starlette.authentication import requires # Internal Packages from khoj.utils import state, constants @@ -17,6 +17,7 @@ from khoj.processor.github.github_to_jsonl import GithubToJsonl from khoj.processor.notion.notion_to_jsonl import NotionToJsonl from khoj.processor.plaintext.plaintext_to_jsonl import PlaintextToJsonl from khoj.search_type import text_search, image_search +from khoj.routers.helpers import update_telemetry_state from khoj.utils.yaml import save_config_to_file_updated_state from khoj.utils.config import SearchModels from khoj.utils.helpers import LRU, get_file_type @@ -57,10 +58,10 @@ class IndexerInput(BaseModel): @indexer.post("/update") +@requires(["authenticated"]) async def update( request: Request, files: list[UploadFile], - x_api_key: str = Header(None), force: bool = False, t: Optional[Union[state.SearchType, str]] = None, client: Optional[str] = None, @@ -68,9 +69,7 @@ async def update( referer: Optional[str] = Header(None), host: Optional[str] = Header(None), ): - user = request.user.object if request.user.is_authenticated else None - if x_api_key != "secret": - raise HTTPException(status_code=401, detail="Invalid API Key") + user = request.user.object try: logger.info(f"๐Ÿ“ฌ Updating content index via API call by {client} client") org_files: Dict[str, str] = {} diff --git a/src/khoj/routers/web_client.py b/src/khoj/routers/web_client.py index 333d89fa..06f43430 100644 --- a/src/khoj/routers/web_client.py +++ b/src/khoj/routers/web_client.py @@ -135,7 +135,7 @@ if not state.demo: @web_client.get("/config/content_type/github", response_class=HTMLResponse) @requires(["authenticated"], redirect="login_page") def github_config_page(request: Request): - user = request.user.object if request.user.is_authenticated else None + user = request.user.object current_github_config = get_user_github_config(user) if current_github_config: @@ -164,7 +164,7 @@ if not state.demo: @web_client.get("/config/content_type/notion", response_class=HTMLResponse) @requires(["authenticated"], redirect="login_page") def notion_config_page(request: Request): - user = request.user.object if request.user.is_authenticated else None + user = request.user.object current_notion_config = get_user_notion_config(user) current_config = NotionContentConfig( @@ -184,7 +184,7 @@ if not state.demo: return templates.TemplateResponse("config.html", context={"request": request}) object = map_config_to_object(content_type) - user = request.user.object if request.user.is_authenticated else None + user = request.user.object config = object.objects.filter(user=user).first() if config == None: config = object.objects.create(user=user) diff --git a/src/khoj/utils/config.py b/src/khoj/utils/config.py index 3c084c4f..7795d695 100644 --- a/src/khoj/utils/config.py +++ b/src/khoj/utils/config.py @@ -6,12 +6,13 @@ import logging from dataclasses import dataclass from typing import TYPE_CHECKING, List, Optional, Union, Any -from khoj.processor.conversation.gpt4all.utils import download_model # External Packages import torch -from khoj.utils.rawconfig import OfflineChatProcessorConfig +# Internal Packages +from khoj.processor.conversation.gpt4all.utils import download_model + logger = logging.getLogger(__name__) @@ -88,3 +89,4 @@ class GPT4AllProcessorModel: except ValueError as e: self.loaded_model = None logger.error(f"Error while loading offline chat model: {e}", exc_info=True) + raise e diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index 0269a9e9..f6418fbd 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -10,7 +10,7 @@ from os import path import os from pathlib import Path import platform -import sys +import random from time import perf_counter import torch from typing import Optional, Union, TYPE_CHECKING @@ -254,6 +254,18 @@ def log_telemetry( return request_body +def get_device() -> torch.device: + """Get device to run model on""" + if torch.cuda.is_available(): + # Use CUDA GPU + return torch.device("cuda:0") + elif torch.backends.mps.is_available(): + # Use Apple M1 Metal Acceleration + return torch.device("mps") + else: + return torch.device("cpu") + + class ConversationCommand(str, Enum): Default = "default" General = "general" @@ -267,3 +279,29 @@ command_descriptions = { ConversationCommand.Default: "The default command when no command specified. It intelligently auto-switches between general and notes mode.", ConversationCommand.Help: "Display a help message with all available commands and other metadata.", } + + +def generate_random_name(): + # List of adjectives and nouns to choose from + adjectives = [ + "happy", + "irritated", + "annoyed", + "calm", + "brave", + "scared", + "energetic", + "chivalrous", + "kind", + "grumpy", + ] + nouns = ["dog", "cat", "falcon", "whale", "turtle", "rabbit", "hamster", "snake", "spider", "elephant"] + + # Select two random words from the lists + adjective = random.choice(adjectives) + noun = random.choice(nouns) + + # Combine the words to form a name + name = f"{adjective} {noun}" + + return name diff --git a/src/khoj/utils/state.py b/src/khoj/utils/state.py index 40806c51..e92f19a7 100644 --- a/src/khoj/utils/state.py +++ b/src/khoj/utils/state.py @@ -1,7 +1,6 @@ # Standard Packages import threading from typing import List, Dict -from packaging import version from collections import defaultdict # External Packages @@ -11,7 +10,7 @@ from pathlib import Path # Internal Packages from khoj.utils import config as utils_config from khoj.utils.config import ContentIndex, SearchModels, GPT4AllProcessorModel -from khoj.utils.helpers import LRU +from khoj.utils.helpers import LRU, get_device from khoj.utils.rawconfig import FullConfig from khoj.processor.embeddings import EmbeddingsModel, CrossEncoderModel @@ -35,12 +34,4 @@ telemetry: List[Dict[str, str]] = [] demo: bool = False khoj_version: str = None anonymous_mode: bool = False - -if torch.cuda.is_available(): - # Use CUDA GPU - device = torch.device("cuda:0") -elif version.parse(torch.__version__) >= version.parse("1.13.0.dev") and torch.backends.mps.is_available(): - # Use Apple M1 Metal Acceleration - device = torch.device("mps") -else: - device = torch.device("cpu") +device = get_device() diff --git a/tests/conftest.py b/tests/conftest.py index 12ac4f7b..aad20274 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -28,6 +28,7 @@ from khoj.utils import state, fs_syncer from khoj.routers.indexer import configure_content from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl from database.models import ( + KhojApiUser, LocalOrgConfig, LocalMarkdownConfig, LocalPlaintextConfig, @@ -76,13 +77,26 @@ def default_user2(): if KhojUser.objects.filter(username="default").exists(): return KhojUser.objects.get(username="default") - return UserFactory( + return KhojUser.objects.create( username="default", email="default@example.com", password="default", ) +@pytest.mark.django_db +@pytest.fixture +def api_user(default_user): + if KhojApiUser.objects.filter(user=default_user).exists(): + return KhojApiUser.objects.get(user=default_user) + + return KhojApiUser.objects.create( + user=default_user, + name="api-key", + token="kk-secret", + ) + + @pytest.fixture(scope="session") def search_models(search_config: SearchConfig): search_models = SearchModels() @@ -176,7 +190,7 @@ def chat_client(search_config: SearchConfig, default_user2: KhojUser): if os.getenv("OPENAI_API_KEY"): OpenAIProcessorConversationConfigFactory(user=default_user2) - state.anonymous_mode = True + state.anonymous_mode = False app = FastAPI() @@ -219,7 +233,7 @@ def fastapi_app(): def client( content_config: ContentConfig, search_config: SearchConfig, - default_user: KhojUser, + api_user: KhojApiUser, ): state.config.content_type = content_config state.config.search_type = search_config @@ -231,7 +245,7 @@ def client( OrgToJsonl, get_sample_data("org"), regenerate=False, - user=default_user, + user=api_user.user, ) state.content_index.image = image_search.setup( content_config.image, state.search_models.image_search, regenerate=False @@ -240,11 +254,11 @@ def client( PlaintextToJsonl, get_sample_data("plaintext"), regenerate=False, - user=default_user, + user=api_user.user, ) - ConversationProcessorConfigFactory(user=default_user) - state.anonymous_mode = True + ConversationProcessorConfigFactory(user=api_user.user) + state.anonymous_mode = False configure_routes(app) configure_middleware(app) @@ -253,13 +267,8 @@ def client( @pytest.fixture(scope="function") -def client_offline_chat( - search_config: SearchConfig, - content_config: ContentConfig, - default_user2: KhojUser, -): +def client_offline_chat(search_config: SearchConfig, default_user2: KhojUser): # Initialize app state - state.config.content_type = md_content_config state.config.search_type = search_config state.SearchType = configure_search_types(state.config) @@ -269,9 +278,6 @@ def client_offline_chat( user=default_user2, ) - # Index Markdown Content for Search - state.search_models.image_search = image_search.initialize_model(search_config.image) - all_files = fs_syncer.collect_files(user=default_user2) configure_content( state.content_index, state.config.content_type, all_files, state.search_models, user=default_user2 @@ -283,6 +289,8 @@ def client_offline_chat( state.anonymous_mode = True + app = FastAPI() + configure_routes(app) configure_middleware(app) app.mount("/static", StaticFiles(directory=web_directory), name="static") diff --git a/tests/helpers.py b/tests/helpers.py index 655c4435..2f2feddf 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -3,6 +3,7 @@ import os from database.models import ( KhojUser, + KhojApiUser, ConversationProcessorConfig, OfflineChatProcessorConversationConfig, OpenAIProcessorConversationConfig, @@ -20,6 +21,15 @@ class UserFactory(factory.django.DjangoModelFactory): uuid = factory.Faker("uuid4") +class ApiUserFactory(factory.django.DjangoModelFactory): + class Meta: + model = KhojApiUser + + user = None + name = factory.Faker("name") + token = factory.Faker("password") + + class ConversationProcessorConfigFactory(factory.django.DjangoModelFactory): class Meta: model = ConversationProcessorConfig diff --git a/tests/test_client.py b/tests/test_client.py index 1a6b1346..6818c2ba 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -22,49 +22,115 @@ from database.adapters import EmbeddingsAdapters # Test # ---------------------------------------------------------------------------------------------------- -def test_search_with_invalid_content_type(client): +@pytest.mark.django_db(transaction=True) +def test_search_with_no_auth_key(client): # Arrange user_query = quote("How to call Khoj from Emacs?") # Act - response = client.get(f"/api/search?q={user_query}&t=invalid_content_type") + response = client.get(f"/api/search?q={user_query}") + + # Assert + assert response.status_code == 403 + + +@pytest.mark.django_db(transaction=True) +def test_search_with_invalid_auth_key(client): + # Arrange + headers = {"Authorization": "Bearer invalid-token"} + user_query = quote("How to call Khoj from Emacs?") + + # Act + response = client.get(f"/api/search?q={user_query}", headers=headers) + + # Assert + assert response.status_code == 403 + + +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.django_db(transaction=True) +def test_search_with_invalid_content_type(client): + # Arrange + headers = {"Authorization": "Bearer kk-secret"} + user_query = quote("How to call Khoj from Emacs?") + + # Act + response = client.get(f"/api/search?q={user_query}&t=invalid_content_type", headers=headers) # Assert assert response.status_code == 422 # ---------------------------------------------------------------------------------------------------- +@pytest.mark.django_db(transaction=True) def test_search_with_valid_content_type(client): - for content_type in ["all", "org", "markdown", "image", "pdf", "github", "notion"]: + headers = {"Authorization": "Bearer kk-secret"} + for content_type in ["all", "org", "markdown", "image", "pdf", "github", "notion", "plaintext"]: # Act - response = client.get(f"/api/search?q=random&t={content_type}") + response = client.get(f"/api/search?q=random&t={content_type}", headers=headers) # Assert assert response.status_code == 200, f"Returned status: {response.status_code} for content type: {content_type}" # ---------------------------------------------------------------------------------------------------- +@pytest.mark.django_db(transaction=True) +def test_index_update_with_no_auth_key(client): + # Arrange + files = get_sample_files_data() + + # Act + response = client.post("/api/v1/index/update", files=files) + + # Assert + assert response.status_code == 403 + + +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.django_db(transaction=True) +def test_index_update_with_invalid_auth_key(client): + # Arrange + files = get_sample_files_data() + headers = {"Authorization": "Bearer kk-invalid-token"} + + # Act + response = client.post("/api/v1/index/update", files=files, headers=headers) + + # Assert + assert response.status_code == 403 + + +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.django_db(transaction=True) def test_update_with_invalid_content_type(client): + # Arrange + headers = {"Authorization": "Bearer kk-secret"} + # Act - response = client.get(f"/api/update?t=invalid_content_type") + response = client.get(f"/api/update?t=invalid_content_type", headers=headers) # Assert assert response.status_code == 422 # ---------------------------------------------------------------------------------------------------- +@pytest.mark.django_db(transaction=True) def test_regenerate_with_invalid_content_type(client): + # Arrange + headers = {"Authorization": "Bearer kk-secret"} + # Act - response = client.get(f"/api/update?force=true&t=invalid_content_type") + response = client.get(f"/api/update?force=true&t=invalid_content_type", headers=headers) # Assert assert response.status_code == 422 # ---------------------------------------------------------------------------------------------------- +@pytest.mark.django_db(transaction=True) def test_index_update(client): # Arrange files = get_sample_files_data() - headers = {"x-api-key": "secret"} + headers = {"Authorization": "Bearer kk-secret"} # Act response = client.post("/api/v1/index/update", files=files, headers=headers) @@ -74,29 +140,33 @@ def test_index_update(client): # ---------------------------------------------------------------------------------------------------- +@pytest.mark.django_db(transaction=True) def test_regenerate_with_valid_content_type(client): for content_type in ["all", "org", "markdown", "image", "pdf", "notion"]: # Arrange files = get_sample_files_data() - headers = {"x-api-key": "secret"} + headers = {"Authorization": "Bearer kk-secret"} # Act response = client.post(f"/api/v1/index/update?t={content_type}", files=files, headers=headers) + # Assert assert response.status_code == 200, f"Returned status: {response.status_code} for content type: {content_type}" # ---------------------------------------------------------------------------------------------------- +@pytest.mark.django_db(transaction=True) def test_regenerate_with_github_fails_without_pat(client): # Act - response = client.get(f"/api/update?force=true&t=github") + headers = {"Authorization": "Bearer kk-secret"} + response = client.get(f"/api/update?force=true&t=github", headers=headers) # Arrange files = get_sample_files_data() - headers = {"x-api-key": "secret"} # Act response = client.post(f"/api/v1/index/update?t=github", files=files, headers=headers) + # Assert assert response.status_code == 200, f"Returned status: {response.status_code} for content type: github" @@ -116,16 +186,17 @@ def test_get_configured_types_via_api(client, sample_org_data): # ---------------------------------------------------------------------------------------------------- @pytest.mark.django_db(transaction=True) -def test_get_api_config_types(client, search_config: SearchConfig, sample_org_data, default_user2: KhojUser): +def test_get_api_config_types(client, sample_org_data, default_user: KhojUser): # Arrange - text_search.setup(OrgToJsonl, sample_org_data, regenerate=False, user=default_user2) + headers = {"Authorization": "Bearer kk-secret"} + text_search.setup(OrgToJsonl, sample_org_data, regenerate=False, user=default_user) # Act - response = client.get(f"/api/config/types") + response = client.get(f"/api/config/types", headers=headers) # Assert assert response.status_code == 200 - assert response.json() == ["all", "org", "image"] + assert response.json() == ["all", "org", "image", "plaintext"] # ---------------------------------------------------------------------------------------------------- @@ -135,6 +206,7 @@ def test_get_configured_types_with_no_content_config(fastapi_app: FastAPI): state.SearchType = configure_search_types(config) original_config = state.config.content_type state.config.content_type = None + state.anonymous_mode = True configure_routes(fastapi_app) client = TestClient(fastapi_app) @@ -154,6 +226,7 @@ def test_get_configured_types_with_no_content_config(fastapi_app: FastAPI): @pytest.mark.django_db(transaction=True) def test_image_search(client, content_config: ContentConfig, search_config: SearchConfig): # Arrange + headers = {"Authorization": "Bearer kk-secret"} search_models.image_search = image_search.initialize_model(search_config.image) content_index.image = image_search.setup( content_config.image, search_models.image_search.image_encoder, regenerate=False @@ -166,7 +239,7 @@ def test_image_search(client, content_config: ContentConfig, search_config: Sear for query, expected_image_name in query_expected_image_pairs: # Act - response = client.get(f"/api/search?q={query}&n=1&t=image") + response = client.get(f"/api/search?q={query}&n=1&t=image", headers=headers) # Assert assert response.status_code == 200 @@ -179,13 +252,14 @@ def test_image_search(client, content_config: ContentConfig, search_config: Sear # ---------------------------------------------------------------------------------------------------- @pytest.mark.django_db(transaction=True) -def test_notes_search(client, search_config: SearchConfig, sample_org_data, default_user2: KhojUser): +def test_notes_search(client, search_config: SearchConfig, sample_org_data, default_user: KhojUser): # Arrange - text_search.setup(OrgToJsonl, sample_org_data, regenerate=False, user=default_user2) + headers = {"Authorization": "Bearer kk-secret"} + text_search.setup(OrgToJsonl, sample_org_data, regenerate=False, user=default_user) user_query = quote("How to git install application?") # Act - response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true") + response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true", headers=headers) # Assert assert response.status_code == 200 @@ -197,19 +271,20 @@ def test_notes_search(client, search_config: SearchConfig, sample_org_data, defa # ---------------------------------------------------------------------------------------------------- @pytest.mark.django_db(transaction=True) def test_notes_search_with_only_filters( - client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data, default_user2: KhojUser + client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data, default_user: KhojUser ): # Arrange + headers = {"Authorization": "Bearer kk-secret"} text_search.setup( OrgToJsonl, sample_org_data, regenerate=False, - user=default_user2, + user=default_user, ) user_query = quote('+"Emacs" file:"*.org"') # Act - response = client.get(f"/api/search?q={user_query}&n=1&t=org") + response = client.get(f"/api/search?q={user_query}&n=1&t=org", headers=headers) # Assert assert response.status_code == 200 @@ -220,13 +295,14 @@ def test_notes_search_with_only_filters( # ---------------------------------------------------------------------------------------------------- @pytest.mark.django_db(transaction=True) -def test_notes_search_with_include_filter(client, sample_org_data, default_user2: KhojUser): +def test_notes_search_with_include_filter(client, sample_org_data, default_user: KhojUser): # Arrange - text_search.setup(OrgToJsonl, sample_org_data, regenerate=False, user=default_user2) + headers = {"Authorization": "Bearer kk-secret"} + text_search.setup(OrgToJsonl, sample_org_data, regenerate=False, user=default_user) user_query = quote('How to git install application? +"Emacs"') # Act - response = client.get(f"/api/search?q={user_query}&n=1&t=org") + response = client.get(f"/api/search?q={user_query}&n=1&t=org", headers=headers) # Assert assert response.status_code == 200 @@ -237,18 +313,19 @@ def test_notes_search_with_include_filter(client, sample_org_data, default_user2 # ---------------------------------------------------------------------------------------------------- @pytest.mark.django_db(transaction=True) -def test_notes_search_with_exclude_filter(client, sample_org_data, default_user2: KhojUser): +def test_notes_search_with_exclude_filter(client, sample_org_data, default_user: KhojUser): # Arrange + headers = {"Authorization": "Bearer kk-secret"} text_search.setup( OrgToJsonl, sample_org_data, regenerate=False, - user=default_user2, + user=default_user, ) user_query = quote('How to git install application? -"clone"') # Act - response = client.get(f"/api/search?q={user_query}&n=1&t=org") + response = client.get(f"/api/search?q={user_query}&n=1&t=org", headers=headers) # Assert assert response.status_code == 200 @@ -261,16 +338,17 @@ def test_notes_search_with_exclude_filter(client, sample_org_data, default_user2 @pytest.mark.django_db(transaction=True) def test_different_user_data_not_accessed(client, sample_org_data, default_user: KhojUser): # Arrange + headers = {"Authorization": "Bearer kk-token"} # Token for default_user2 text_search.setup(OrgToJsonl, sample_org_data, regenerate=False, user=default_user) user_query = quote("How to git install application?") # Act - response = client.get(f"/api/search?q={user_query}&n=1&t=org") + response = client.get(f"/api/search?q={user_query}&n=1&t=org", headers=headers) # Assert - assert response.status_code == 200 + assert response.status_code == 403 # assert actual response has no data as the default_user is different from the user making the query (anonymous) - assert len(response.json()) == 0 + assert len(response.json()) == 1 and response.json()["detail"] == "Forbidden" def get_sample_files_data():