-
-
- Host
-
+
+
+
+
+
+ Server URL
+
+
+
+
+
+
+
+
+ Access Key
+
+
+
+
+
-
-
-
-
-
-
- Files
-
+
+
+
+
+
+ Files
+
+
+
+
+
+
-
+
-
-
-
-
-
-
-
- Folders
-
+
+
+
+
+
+ Folders
+
+
+
+
+
+
-
-
-
-
-
+
+
+
-
-
-
+
+
+
@@ -93,7 +111,7 @@
body {
display: grid;
grid-template-columns: 1fr;
- grid-template-rows: 1fr auto auto auto minmax(80px, 100%);
+ grid-template-rows: 1fr auto;
font-size: small!important;
}
body > * {
@@ -104,8 +122,7 @@
body {
display: grid;
grid-template-columns: 1fr min(70vw, 100%) 1fr;
- grid-template-rows: 1fr auto auto auto minmax(80px, 100%);
- padding-top: 60vw;
+ grid-template-rows: 80px auto;
}
body > * {
grid-column: 2;
@@ -126,11 +143,6 @@
margin: 10px;
}
- div.page {
- padding: 0px;
- margin: 0px;
- }
-
svg {
transition: transform 0.3s ease-in-out;
}
@@ -167,18 +179,18 @@
}
}
- #khoj-host-url {
+ .card-input {
padding: 4px;
box-shadow: 0 0 2px 1px rgba(0, 0, 0, 0.2);
border: none;
+ width: 450px;
}
.card {
display: grid;
- /* grid-template-rows: repeat(3, 1fr); */
gap: 8px;
padding: 24px 16px;
- width: 100%;
+ width: 450px;
background: white;
border: 1px solid rgb(229, 229, 229);
border-radius: 4px;
@@ -188,15 +200,15 @@
.section-cards {
display: grid;
- grid-template-columns: repeat(1, 1fr);
gap: 16px;
- justify-items: start;
+ justify-items: center;
margin: 0;
- width: auto;
}
-
- div.configuration {
- width: auto;
+ .section-action-row {
+ display: grid;
+ grid-auto-flow: column;
+ gap: 16px;
+ height: fit-content;
}
.card-title-row {
@@ -302,7 +314,6 @@
}
div.content-name {
- width: 500px;
overflow-wrap: break-word;
}
@@ -347,6 +358,12 @@
background-color: #ffcc00;
box-shadow: 0px 3px 0px #f9f5de;
}
+ .sync-force-toggle {
+ align-content: center;
+ display: grid;
+ grid-auto-flow: column;
+ gap: 4px;
+ }
{% endblock %}
diff --git a/src/khoj/main.py b/src/khoj/main.py
index 804e71e5..8fe40e76 100644
--- a/src/khoj/main.py
+++ b/src/khoj/main.py
@@ -98,9 +98,10 @@ def run():
# Mount Django and Static Files
app.mount("/django", django_app, name="django")
- if not os.path.exists("static"):
- os.mkdir("static")
- app.mount("/static", StaticFiles(directory="static"), name="static")
+ static_dir = "static"
+ if not os.path.exists(static_dir):
+ os.mkdir(static_dir)
+ app.mount(f"/{static_dir}", StaticFiles(directory=static_dir), name=static_dir)
# Configure Middleware
configure_middleware(app)
diff --git a/src/khoj/processor/conversation/gpt4all/utils.py b/src/khoj/processor/conversation/gpt4all/utils.py
index d5201780..cd9bc9e2 100644
--- a/src/khoj/processor/conversation/gpt4all/utils.py
+++ b/src/khoj/processor/conversation/gpt4all/utils.py
@@ -6,17 +6,23 @@ logger = logging.getLogger(__name__)
def download_model(model_name: str):
try:
- from gpt4all import GPT4All
+ import gpt4all
except ModuleNotFoundError as e:
logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
raise e
- # Use GPU for Chat Model, if available
- try:
- model = GPT4All(model_name=model_name, device="gpu")
- logger.debug("Loaded chat model to GPU.")
- except ValueError:
- model = GPT4All(model_name=model_name)
- logger.debug("Loaded chat model to CPU.")
+ # Retrieve the chat model's config (downloading is allowed here); its "path" entry is used below
+ chat_model_config = gpt4all.GPT4All.retrieve_model(model_name=model_name, allow_download=True)
- return model
+ # Decide whether to load model to GPU or CPU before constructing it
+ try:
+ # Check if machine has GPU and GPU has enough free memory to load the chat model; a ValueError here falls back to CPU
+ device = "gpu" if gpt4all.pyllmodel.LLModel().list_gpu(chat_model_config["path"]) else "cpu"
+ except ValueError:
+ device = "cpu"
+
+ # Now load the previously retrieved chat model onto the chosen device, with further downloads disabled
+ chat_model = gpt4all.GPT4All(model_name=model_name, device=device, allow_download=False)
+ logger.debug(f"Loaded chat model to {device.upper()}.")
+
+ return chat_model
diff --git a/src/khoj/processor/embeddings.py b/src/khoj/processor/embeddings.py
index f0e2df77..fbcddb67 100644
--- a/src/khoj/processor/embeddings.py
+++ b/src/khoj/processor/embeddings.py
@@ -1,28 +1,17 @@
from typing import List
-import torch
from langchain.embeddings import HuggingFaceEmbeddings
from sentence_transformers import CrossEncoder
+from khoj.utils.helpers import get_device
from khoj.utils.rawconfig import SearchResponse
class EmbeddingsModel:
def __init__(self):
- self.model_name = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1"
- encode_kwargs = {"normalize_embeddings": True}
- # encode_kwargs = {}
-
- if torch.cuda.is_available():
- # Use CUDA GPU
- device = torch.device("cuda:0")
- elif torch.backends.mps.is_available():
- # Use Apple M1 Metal Acceleration
- device = torch.device("mps")
- else:
- device = torch.device("cpu")
-
- model_kwargs = {"device": device}
+ self.model_name = "thenlper/gte-small"
+ encode_kwargs = {"normalize_embeddings": True, "show_progress_bar": True}
+ model_kwargs = {"device": get_device()}
self.embeddings_model = HuggingFaceEmbeddings(
model_name=self.model_name, encode_kwargs=encode_kwargs, model_kwargs=model_kwargs
)
@@ -37,19 +26,7 @@ class EmbeddingsModel:
class CrossEncoderModel:
def __init__(self):
self.model_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"
-
- if torch.cuda.is_available():
- # Use CUDA GPU
- device = torch.device("cuda:0")
-
- elif torch.backends.mps.is_available():
- # Use Apple M1 Metal Acceleration
- device = torch.device("mps")
-
- else:
- device = torch.device("cpu")
-
- self.cross_encoder_model = CrossEncoder(model_name=self.model_name, device=device)
+ self.cross_encoder_model = CrossEncoder(model_name=self.model_name, device=get_device())
def predict(self, query, hits: List[SearchResponse]):
cross__inp = [[query, hit.additional["compiled"]] for hit in hits]
diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index b7ba66b6..4984ea4c 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -126,21 +126,21 @@ if not state.demo:
state.config.search_type = SearchConfig.parse_obj(constants.default_config["search-type"])
@api.get("/config/data", response_model=FullConfig)
- @requires(["authenticated"], redirect="login_page")
+ @requires(["authenticated"])
def get_config_data(request: Request):
- user = request.user.object if request.user.is_authenticated else None
- enabled_content = EmbeddingsAdapters.get_unique_file_types(user)
+ user = request.user.object
+ EmbeddingsAdapters.get_unique_file_types(user)
return state.config
@api.post("/config/data")
- @requires(["authenticated"], redirect="login_page")
+ @requires(["authenticated"])
async def set_config_data(
request: Request,
updated_config: FullConfig,
client: Optional[str] = None,
):
- user = request.user.object if request.user.is_authenticated else None
+ user = request.user.object
await map_config_to_db(updated_config, user)
configuration_update_metadata = {}
@@ -167,7 +167,7 @@ if not state.demo:
return state.config
@api.post("/config/data/content_type/github", status_code=200)
- @requires(["authenticated"], redirect="login_page")
+ @requires(["authenticated"])
async def set_content_config_github_data(
request: Request,
updated_config: Union[GithubContentConfig, None],
@@ -175,7 +175,7 @@ if not state.demo:
):
_initialize_config()
- user = request.user.object if request.user.is_authenticated else None
+ user = request.user.object
await adapters.set_user_github_config(
user=user,
@@ -194,7 +194,7 @@ if not state.demo:
return {"status": "ok"}
@api.post("/config/data/content_type/notion", status_code=200)
- @requires(["authenticated"], redirect="login_page")
+ @requires(["authenticated"])
async def set_content_config_notion_data(
request: Request,
updated_config: Union[NotionContentConfig, None],
@@ -202,7 +202,7 @@ if not state.demo:
):
_initialize_config()
- user = request.user.object if request.user.is_authenticated else None
+ user = request.user.object
await adapters.set_notion_config(
user=user,
@@ -220,13 +220,13 @@ if not state.demo:
return {"status": "ok"}
@api.post("/delete/config/data/content_type/{content_type}", status_code=200)
- @requires(["authenticated"], redirect="login_page")
+ @requires(["authenticated"])
async def remove_content_config_data(
request: Request,
content_type: str,
client: Optional[str] = None,
):
- user = request.user.object if request.user.is_authenticated else None
+ user = request.user.object
update_telemetry_state(
request=request,
@@ -247,7 +247,7 @@ if not state.demo:
return {"status": "ok"}
@api.post("/delete/config/data/processor/conversation/openai", status_code=200)
- @requires(["authenticated"], redirect="login_page")
+ @requires(["authenticated"])
async def remove_processor_conversation_config_data(
request: Request,
client: Optional[str] = None,
@@ -267,7 +267,7 @@ if not state.demo:
return {"status": "ok"}
@api.post("/config/data/content_type/{content_type}", status_code=200)
- @requires(["authenticated"], redirect="login_page")
+ @requires(["authenticated"])
async def set_content_config_data(
request: Request,
content_type: str,
@@ -276,7 +276,7 @@ if not state.demo:
):
_initialize_config()
- user = request.user.object if request.user.is_authenticated else None
+ user = request.user.object
content_object = map_config_to_object(content_type)
await adapters.set_text_content_config(user, content_object, updated_config)
@@ -292,7 +292,7 @@ if not state.demo:
return {"status": "ok"}
@api.post("/config/data/processor/conversation/openai", status_code=200)
- @requires(["authenticated"], redirect="login_page")
+ @requires(["authenticated"])
async def set_processor_openai_config_data(
request: Request,
updated_config: Union[OpenAIProcessorConfig, None],
@@ -315,6 +315,7 @@ if not state.demo:
return {"status": "ok"}
@api.post("/config/data/processor/conversation/offline_chat", status_code=200)
+ @requires(["authenticated"])
async def set_processor_enable_offline_chat_config_data(
request: Request,
enable_offline_chat: bool,
@@ -323,24 +324,29 @@ if not state.demo:
):
user = request.user.object
- if enable_offline_chat:
- conversation_config = ConversationProcessorConfig(
- offline_chat=OfflineChatProcessorConfig(
- enable_offline_chat=enable_offline_chat,
- chat_model=offline_chat_model,
+ try:
+ if enable_offline_chat:
+ conversation_config = ConversationProcessorConfig(
+ offline_chat=OfflineChatProcessorConfig(
+ enable_offline_chat=enable_offline_chat,
+ chat_model=offline_chat_model,
+ )
)
- )
- await sync_to_async(ConversationAdapters.set_conversation_processor_config)(user, conversation_config)
+ await sync_to_async(ConversationAdapters.set_conversation_processor_config)(user, conversation_config)
- offline_chat = await ConversationAdapters.get_offline_chat(user)
- chat_model = offline_chat.chat_model
- if state.gpt4all_processor_config is None:
- state.gpt4all_processor_config = GPT4AllProcessorModel(chat_model=chat_model)
+ offline_chat = await ConversationAdapters.get_offline_chat(user)
+ chat_model = offline_chat.chat_model
+ if state.gpt4all_processor_config is None:
+ state.gpt4all_processor_config = GPT4AllProcessorModel(chat_model=chat_model)
- else:
- await sync_to_async(ConversationAdapters.clear_offline_chat_conversation_config)(user)
- state.gpt4all_processor_config = None
+ else:
+ await sync_to_async(ConversationAdapters.clear_offline_chat_conversation_config)(user)
+ state.gpt4all_processor_config = None
+
+ except Exception as e:
+ logger.error(f"Error updating offline chat config: {e}", exc_info=True)
+ return {"status": "error", "message": str(e)}
update_telemetry_state(
request=request,
@@ -360,11 +366,11 @@ def get_default_config_data():
@api.get("/config/types", response_model=List[str])
-@requires(["authenticated"], redirect="login_page")
+@requires(["authenticated"])
def get_config_types(
request: Request,
):
- user = request.user.object if request.user.is_authenticated else None
+ user = request.user.object
enabled_file_types = EmbeddingsAdapters.get_unique_file_types(user)
@@ -382,7 +388,7 @@ def get_config_types(
@api.get("/search", response_model=List[SearchResponse])
-@requires(["authenticated"], redirect="login_page")
+@requires(["authenticated"])
async def search(
q: str,
request: Request,
@@ -396,7 +402,7 @@ async def search(
referer: Optional[str] = Header(None),
host: Optional[str] = Header(None),
):
- user = request.user.object if request.user.is_authenticated else None
+ user = request.user.object
start_time = time.time()
# Run validation checks
@@ -513,7 +519,7 @@ async def search(
@api.get("/update")
-@requires(["authenticated"], redirect="login_page")
+@requires(["authenticated"])
def update(
request: Request,
t: Optional[SearchType] = None,
@@ -523,7 +529,7 @@ def update(
referer: Optional[str] = Header(None),
host: Optional[str] = Header(None),
):
- user = request.user.object if request.user.is_authenticated else None
+ user = request.user.object
if not state.config:
error_msg = f"🚨 Khoj is not configured.\nConfigure it via http://localhost:42110/config, plugins or by editing {state.config_file}."
logger.warning(error_msg)
@@ -557,7 +563,7 @@ def update(
@api.get("/chat/history")
-@requires(["authenticated"], redirect="login_page")
+@requires(["authenticated"])
def chat_history(
request: Request,
client: Optional[str] = None,
@@ -585,7 +591,7 @@ def chat_history(
@api.get("/chat/options", response_class=Response)
-@requires(["authenticated"], redirect="login_page")
+@requires(["authenticated"])
async def chat_options(
request: Request,
client: Optional[str] = None,
@@ -610,7 +616,7 @@ async def chat_options(
@api.get("/chat", response_class=Response)
-@requires(["authenticated"], redirect="login_page")
+@requires(["authenticated"])
async def chat(
request: Request,
q: str,
diff --git a/src/khoj/routers/auth.py b/src/khoj/routers/auth.py
index 8c767d8f..ab1964b8 100644
--- a/src/khoj/routers/auth.py
+++ b/src/khoj/routers/auth.py
@@ -1,22 +1,29 @@
+# Standard Packages
import logging
-import json
import os
+from typing import Optional
+
+# External Packages
from fastapi import APIRouter
from starlette.config import Config
from starlette.requests import Request
from starlette.responses import HTMLResponse, RedirectResponse, Response
+from starlette.authentication import requires
from authlib.integrations.starlette_client import OAuth, OAuthError
from google.oauth2 import id_token
from google.auth.transport import requests as google_requests
-from database.adapters import get_or_create_user
+# Internal Packages
+from database.adapters import get_khoj_tokens, get_or_create_user, create_khoj_token, delete_khoj_token
+from khoj.utils import state
+
logger = logging.getLogger(__name__)
auth_router = APIRouter()
-if not os.environ.get("GOOGLE_CLIENT_ID") or not os.environ.get("GOOGLE_CLIENT_SECRET"):
+if not state.anonymous_mode and not (os.environ.get("GOOGLE_CLIENT_ID") and os.environ.get("GOOGLE_CLIENT_SECRET")):
logger.info("Please set GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET environment variables to use Google OAuth")
else:
config = Config(environ=os.environ)
@@ -39,6 +46,31 @@ async def login(request: Request):
return await oauth.google.authorize_redirect(request, redirect_uri)
+@auth_router.post("/token")
+@requires(["authenticated"], redirect="login_page")
+async def generate_token(request: Request, token_name: Optional[str] = None) -> str:
+ """Create an API token for the authenticated user; pass `token_name` to name it, else create_khoj_token's default name applies."""
+ if token_name:
+ return await create_khoj_token(user=request.user.object, name=token_name)
+ else:
+ return await create_khoj_token(user=request.user.object)
+
+
+@auth_router.get("/token")
+@requires(["authenticated"], redirect="login_page")
+def get_tokens(request: Request):
+ """Return the API tokens that get_khoj_tokens reports for the authenticated user."""
+ tokens = get_khoj_tokens(user=request.user.object)
+ return tokens
+
+
+@auth_router.delete("/token")
+@requires(["authenticated"], redirect="login_page")
+async def delete_token(request: Request, token: str) -> str:
+ """Delete the API token `token` belonging to the authenticated user and return delete_khoj_token's result."""
+ return await delete_khoj_token(user=request.user.object, token=token)
+
+
@auth_router.post("/redirect")
async def auth(request: Request):
form = await request.form()
diff --git a/src/khoj/routers/indexer.py b/src/khoj/routers/indexer.py
index 1125e653..e7df65a2 100644
--- a/src/khoj/routers/indexer.py
+++ b/src/khoj/routers/indexer.py
@@ -4,9 +4,9 @@ from typing import Optional, Union, Dict
import asyncio
# External Packages
-from fastapi import APIRouter, HTTPException, Header, Request, Response, UploadFile
+from fastapi import APIRouter, Header, Request, Response, UploadFile
from pydantic import BaseModel
-from khoj.routers.helpers import update_telemetry_state
+from starlette.authentication import requires
# Internal Packages
from khoj.utils import state, constants
@@ -17,6 +17,7 @@ from khoj.processor.github.github_to_jsonl import GithubToJsonl
from khoj.processor.notion.notion_to_jsonl import NotionToJsonl
from khoj.processor.plaintext.plaintext_to_jsonl import PlaintextToJsonl
from khoj.search_type import text_search, image_search
+from khoj.routers.helpers import update_telemetry_state
from khoj.utils.yaml import save_config_to_file_updated_state
from khoj.utils.config import SearchModels
from khoj.utils.helpers import LRU, get_file_type
@@ -57,10 +58,10 @@ class IndexerInput(BaseModel):
@indexer.post("/update")
+@requires(["authenticated"])
async def update(
request: Request,
files: list[UploadFile],
- x_api_key: str = Header(None),
force: bool = False,
t: Optional[Union[state.SearchType, str]] = None,
client: Optional[str] = None,
@@ -68,9 +69,7 @@ async def update(
referer: Optional[str] = Header(None),
host: Optional[str] = Header(None),
):
- user = request.user.object if request.user.is_authenticated else None
- if x_api_key != "secret":
- raise HTTPException(status_code=401, detail="Invalid API Key")
+ user = request.user.object
try:
logger.info(f"📬 Updating content index via API call by {client} client")
org_files: Dict[str, str] = {}
diff --git a/src/khoj/routers/web_client.py b/src/khoj/routers/web_client.py
index 333d89fa..06f43430 100644
--- a/src/khoj/routers/web_client.py
+++ b/src/khoj/routers/web_client.py
@@ -135,7 +135,7 @@ if not state.demo:
@web_client.get("/config/content_type/github", response_class=HTMLResponse)
@requires(["authenticated"], redirect="login_page")
def github_config_page(request: Request):
- user = request.user.object if request.user.is_authenticated else None
+ user = request.user.object
current_github_config = get_user_github_config(user)
if current_github_config:
@@ -164,7 +164,7 @@ if not state.demo:
@web_client.get("/config/content_type/notion", response_class=HTMLResponse)
@requires(["authenticated"], redirect="login_page")
def notion_config_page(request: Request):
- user = request.user.object if request.user.is_authenticated else None
+ user = request.user.object
current_notion_config = get_user_notion_config(user)
current_config = NotionContentConfig(
@@ -184,7 +184,7 @@ if not state.demo:
return templates.TemplateResponse("config.html", context={"request": request})
object = map_config_to_object(content_type)
- user = request.user.object if request.user.is_authenticated else None
+ user = request.user.object
config = object.objects.filter(user=user).first()
if config == None:
config = object.objects.create(user=user)
diff --git a/src/khoj/utils/config.py b/src/khoj/utils/config.py
index 3c084c4f..7795d695 100644
--- a/src/khoj/utils/config.py
+++ b/src/khoj/utils/config.py
@@ -6,12 +6,13 @@ import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Union, Any
-from khoj.processor.conversation.gpt4all.utils import download_model
# External Packages
import torch
-from khoj.utils.rawconfig import OfflineChatProcessorConfig
+# Internal Packages
+from khoj.processor.conversation.gpt4all.utils import download_model
+
logger = logging.getLogger(__name__)
@@ -88,3 +89,4 @@ class GPT4AllProcessorModel:
except ValueError as e:
self.loaded_model = None
logger.error(f"Error while loading offline chat model: {e}", exc_info=True)
+ raise e
diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py
index 0269a9e9..f6418fbd 100644
--- a/src/khoj/utils/helpers.py
+++ b/src/khoj/utils/helpers.py
@@ -10,7 +10,7 @@ from os import path
import os
from pathlib import Path
import platform
-import sys
+import random
from time import perf_counter
import torch
from typing import Optional, Union, TYPE_CHECKING
@@ -254,6 +254,18 @@ def log_telemetry(
return request_body
+def get_device() -> torch.device:
+ """Return the torch device to run models on: CUDA GPU 0 if available, else Apple Metal (MPS) if available, else CPU."""
+ if torch.cuda.is_available():
+ # Use the first CUDA GPU
+ return torch.device("cuda:0")
+ elif torch.backends.mps.is_available():
+ # Use Apple M1 Metal Acceleration
+ return torch.device("mps")
+ else:
+ return torch.device("cpu")
+
+
class ConversationCommand(str, Enum):
Default = "default"
General = "general"
@@ -267,3 +279,29 @@ command_descriptions = {
ConversationCommand.Default: "The default command when no command specified. It intelligently auto-switches between general and notes mode.",
ConversationCommand.Help: "Display a help message with all available commands and other metadata.",
}
+
+
+def generate_random_name():
+ """Return a random two-word name of the form "<adjective> <noun>", drawn from the fixed word lists below."""
+ adjectives = [
+ "happy",
+ "irritated",
+ "annoyed",
+ "calm",
+ "brave",
+ "scared",
+ "energetic",
+ "chivalrous",
+ "kind",
+ "grumpy",
+ ]
+ nouns = ["dog", "cat", "falcon", "whale", "turtle", "rabbit", "hamster", "snake", "spider", "elephant"]
+
+ # Pick one random adjective and one random noun
+ adjective = random.choice(adjectives)
+ noun = random.choice(nouns)
+
+ # Join the two words with a single space to form the name
+ name = f"{adjective} {noun}"
+
+ return name
diff --git a/src/khoj/utils/state.py b/src/khoj/utils/state.py
index 40806c51..e92f19a7 100644
--- a/src/khoj/utils/state.py
+++ b/src/khoj/utils/state.py
@@ -1,7 +1,6 @@
# Standard Packages
import threading
from typing import List, Dict
-from packaging import version
from collections import defaultdict
# External Packages
@@ -11,7 +10,7 @@ from pathlib import Path
# Internal Packages
from khoj.utils import config as utils_config
from khoj.utils.config import ContentIndex, SearchModels, GPT4AllProcessorModel
-from khoj.utils.helpers import LRU
+from khoj.utils.helpers import LRU, get_device
from khoj.utils.rawconfig import FullConfig
from khoj.processor.embeddings import EmbeddingsModel, CrossEncoderModel
@@ -35,12 +34,4 @@ telemetry: List[Dict[str, str]] = []
demo: bool = False
khoj_version: str = None
anonymous_mode: bool = False
-
-if torch.cuda.is_available():
- # Use CUDA GPU
- device = torch.device("cuda:0")
-elif version.parse(torch.__version__) >= version.parse("1.13.0.dev") and torch.backends.mps.is_available():
- # Use Apple M1 Metal Acceleration
- device = torch.device("mps")
-else:
- device = torch.device("cpu")
+device = get_device()
diff --git a/tests/conftest.py b/tests/conftest.py
index 12ac4f7b..aad20274 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -28,6 +28,7 @@ from khoj.utils import state, fs_syncer
from khoj.routers.indexer import configure_content
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
from database.models import (
+ KhojApiUser,
LocalOrgConfig,
LocalMarkdownConfig,
LocalPlaintextConfig,
@@ -76,13 +77,26 @@ def default_user2():
if KhojUser.objects.filter(username="default").exists():
return KhojUser.objects.get(username="default")
- return UserFactory(
+ return KhojUser.objects.create(
username="default",
email="default@example.com",
password="default",
)
+# NOTE: no django_db mark here; pytest ignores marks applied to fixtures. DB access is enabled by the requesting test's mark.
+@pytest.fixture
+def api_user(default_user):
+ """Return the KhojApiUser linked to default_user, creating it with name "api-key" and token "kk-secret" if none exists yet."""
+ if KhojApiUser.objects.filter(user=default_user).exists():
+ return KhojApiUser.objects.get(user=default_user)
+ return KhojApiUser.objects.create(
+ user=default_user,
+ name="api-key",
+ token="kk-secret",
+ )
+
+
@pytest.fixture(scope="session")
def search_models(search_config: SearchConfig):
search_models = SearchModels()
@@ -176,7 +190,7 @@ def chat_client(search_config: SearchConfig, default_user2: KhojUser):
if os.getenv("OPENAI_API_KEY"):
OpenAIProcessorConversationConfigFactory(user=default_user2)
- state.anonymous_mode = True
+ state.anonymous_mode = False
app = FastAPI()
@@ -219,7 +233,7 @@ def fastapi_app():
def client(
content_config: ContentConfig,
search_config: SearchConfig,
- default_user: KhojUser,
+ api_user: KhojApiUser,
):
state.config.content_type = content_config
state.config.search_type = search_config
@@ -231,7 +245,7 @@ def client(
OrgToJsonl,
get_sample_data("org"),
regenerate=False,
- user=default_user,
+ user=api_user.user,
)
state.content_index.image = image_search.setup(
content_config.image, state.search_models.image_search, regenerate=False
@@ -240,11 +254,11 @@ def client(
PlaintextToJsonl,
get_sample_data("plaintext"),
regenerate=False,
- user=default_user,
+ user=api_user.user,
)
- ConversationProcessorConfigFactory(user=default_user)
- state.anonymous_mode = True
+ ConversationProcessorConfigFactory(user=api_user.user)
+ state.anonymous_mode = False
configure_routes(app)
configure_middleware(app)
@@ -253,13 +267,8 @@ def client(
@pytest.fixture(scope="function")
-def client_offline_chat(
- search_config: SearchConfig,
- content_config: ContentConfig,
- default_user2: KhojUser,
-):
+def client_offline_chat(search_config: SearchConfig, default_user2: KhojUser):
# Initialize app state
- state.config.content_type = md_content_config
state.config.search_type = search_config
state.SearchType = configure_search_types(state.config)
@@ -269,9 +278,6 @@ def client_offline_chat(
user=default_user2,
)
- # Index Markdown Content for Search
- state.search_models.image_search = image_search.initialize_model(search_config.image)
-
all_files = fs_syncer.collect_files(user=default_user2)
configure_content(
state.content_index, state.config.content_type, all_files, state.search_models, user=default_user2
@@ -283,6 +289,8 @@ def client_offline_chat(
state.anonymous_mode = True
+ app = FastAPI()
+
configure_routes(app)
configure_middleware(app)
app.mount("/static", StaticFiles(directory=web_directory), name="static")
diff --git a/tests/helpers.py b/tests/helpers.py
index 655c4435..2f2feddf 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -3,6 +3,7 @@ import os
from database.models import (
KhojUser,
+ KhojApiUser,
ConversationProcessorConfig,
OfflineChatProcessorConversationConfig,
OpenAIProcessorConversationConfig,
@@ -20,6 +21,15 @@ class UserFactory(factory.django.DjangoModelFactory):
uuid = factory.Faker("uuid4")
+class ApiUserFactory(factory.django.DjangoModelFactory):
+ class Meta:
+ model = KhojApiUser
+
+ user = None
+ name = factory.Faker("name")
+ token = factory.Faker("password")
+
+
class ConversationProcessorConfigFactory(factory.django.DjangoModelFactory):
class Meta:
model = ConversationProcessorConfig
diff --git a/tests/test_client.py b/tests/test_client.py
index 1a6b1346..6818c2ba 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -22,49 +22,115 @@ from database.adapters import EmbeddingsAdapters
# Test
# ----------------------------------------------------------------------------------------------------
-def test_search_with_invalid_content_type(client):
+@pytest.mark.django_db(transaction=True)
+def test_search_with_no_auth_key(client):
# Arrange
user_query = quote("How to call Khoj from Emacs?")
# Act
- response = client.get(f"/api/search?q={user_query}&t=invalid_content_type")
+ response = client.get(f"/api/search?q={user_query}")
+
+ # Assert
+ assert response.status_code == 403
+
+
+@pytest.mark.django_db(transaction=True)
+def test_search_with_invalid_auth_key(client):
+ # Arrange
+ headers = {"Authorization": "Bearer invalid-token"}
+ user_query = quote("How to call Khoj from Emacs?")
+
+ # Act
+ response = client.get(f"/api/search?q={user_query}", headers=headers)
+
+ # Assert
+ assert response.status_code == 403
+
+
+# ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db(transaction=True)
+def test_search_with_invalid_content_type(client):
+ # Arrange
+ headers = {"Authorization": "Bearer kk-secret"}
+ user_query = quote("How to call Khoj from Emacs?")
+
+ # Act
+ response = client.get(f"/api/search?q={user_query}&t=invalid_content_type", headers=headers)
# Assert
assert response.status_code == 422
# ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db(transaction=True)
def test_search_with_valid_content_type(client):
- for content_type in ["all", "org", "markdown", "image", "pdf", "github", "notion"]:
+ headers = {"Authorization": "Bearer kk-secret"}
+ for content_type in ["all", "org", "markdown", "image", "pdf", "github", "notion", "plaintext"]:
# Act
- response = client.get(f"/api/search?q=random&t={content_type}")
+ response = client.get(f"/api/search?q=random&t={content_type}", headers=headers)
# Assert
assert response.status_code == 200, f"Returned status: {response.status_code} for content type: {content_type}"
# ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db(transaction=True)
+def test_index_update_with_no_auth_key(client):
+ # Arrange
+ files = get_sample_files_data()
+
+ # Act
+ response = client.post("/api/v1/index/update", files=files)
+
+ # Assert
+ assert response.status_code == 403
+
+
+# ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db(transaction=True)
+def test_index_update_with_invalid_auth_key(client):
+ # Arrange
+ files = get_sample_files_data()
+ headers = {"Authorization": "Bearer kk-invalid-token"}
+
+ # Act
+ response = client.post("/api/v1/index/update", files=files, headers=headers)
+
+ # Assert
+ assert response.status_code == 403
+
+
+# ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db(transaction=True)
def test_update_with_invalid_content_type(client):
+ # Arrange
+ headers = {"Authorization": "Bearer kk-secret"}
+
# Act
- response = client.get(f"/api/update?t=invalid_content_type")
+ response = client.get(f"/api/update?t=invalid_content_type", headers=headers)
# Assert
assert response.status_code == 422
# ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db(transaction=True)
def test_regenerate_with_invalid_content_type(client):
+ # Arrange
+ headers = {"Authorization": "Bearer kk-secret"}
+
# Act
- response = client.get(f"/api/update?force=true&t=invalid_content_type")
+ response = client.get(f"/api/update?force=true&t=invalid_content_type", headers=headers)
# Assert
assert response.status_code == 422
# ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db(transaction=True)
def test_index_update(client):
# Arrange
files = get_sample_files_data()
- headers = {"x-api-key": "secret"}
+ headers = {"Authorization": "Bearer kk-secret"}
# Act
response = client.post("/api/v1/index/update", files=files, headers=headers)
@@ -74,29 +140,33 @@ def test_index_update(client):
# ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db(transaction=True)
def test_regenerate_with_valid_content_type(client):
for content_type in ["all", "org", "markdown", "image", "pdf", "notion"]:
# Arrange
files = get_sample_files_data()
- headers = {"x-api-key": "secret"}
+ headers = {"Authorization": "Bearer kk-secret"}
# Act
response = client.post(f"/api/v1/index/update?t={content_type}", files=files, headers=headers)
+
# Assert
assert response.status_code == 200, f"Returned status: {response.status_code} for content type: {content_type}"
# ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db(transaction=True)
def test_regenerate_with_github_fails_without_pat(client):
# Act
- response = client.get(f"/api/update?force=true&t=github")
+ headers = {"Authorization": "Bearer kk-secret"}
+ response = client.get(f"/api/update?force=true&t=github", headers=headers)
# Arrange
files = get_sample_files_data()
- headers = {"x-api-key": "secret"}
# Act
response = client.post(f"/api/v1/index/update?t=github", files=files, headers=headers)
+
# Assert
assert response.status_code == 200, f"Returned status: {response.status_code} for content type: github"
@@ -116,16 +186,17 @@ def test_get_configured_types_via_api(client, sample_org_data):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
-def test_get_api_config_types(client, search_config: SearchConfig, sample_org_data, default_user2: KhojUser):
+def test_get_api_config_types(client, sample_org_data, default_user: KhojUser):
# Arrange
- text_search.setup(OrgToJsonl, sample_org_data, regenerate=False, user=default_user2)
+ headers = {"Authorization": "Bearer kk-secret"}
+ text_search.setup(OrgToJsonl, sample_org_data, regenerate=False, user=default_user)
# Act
- response = client.get(f"/api/config/types")
+ response = client.get(f"/api/config/types", headers=headers)
# Assert
assert response.status_code == 200
- assert response.json() == ["all", "org", "image"]
+ assert response.json() == ["all", "org", "image", "plaintext"]
# ----------------------------------------------------------------------------------------------------
@@ -135,6 +206,7 @@ def test_get_configured_types_with_no_content_config(fastapi_app: FastAPI):
state.SearchType = configure_search_types(config)
original_config = state.config.content_type
state.config.content_type = None
+ state.anonymous_mode = True
configure_routes(fastapi_app)
client = TestClient(fastapi_app)
@@ -154,6 +226,7 @@ def test_get_configured_types_with_no_content_config(fastapi_app: FastAPI):
@pytest.mark.django_db(transaction=True)
def test_image_search(client, content_config: ContentConfig, search_config: SearchConfig):
# Arrange
+ headers = {"Authorization": "Bearer kk-secret"}
search_models.image_search = image_search.initialize_model(search_config.image)
content_index.image = image_search.setup(
content_config.image, search_models.image_search.image_encoder, regenerate=False
@@ -166,7 +239,7 @@ def test_image_search(client, content_config: ContentConfig, search_config: Sear
for query, expected_image_name in query_expected_image_pairs:
# Act
- response = client.get(f"/api/search?q={query}&n=1&t=image")
+ response = client.get(f"/api/search?q={query}&n=1&t=image", headers=headers)
# Assert
assert response.status_code == 200
@@ -179,13 +252,14 @@ def test_image_search(client, content_config: ContentConfig, search_config: Sear
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
-def test_notes_search(client, search_config: SearchConfig, sample_org_data, default_user2: KhojUser):
+def test_notes_search(client, search_config: SearchConfig, sample_org_data, default_user: KhojUser):
# Arrange
- text_search.setup(OrgToJsonl, sample_org_data, regenerate=False, user=default_user2)
+ headers = {"Authorization": "Bearer kk-secret"}
+ text_search.setup(OrgToJsonl, sample_org_data, regenerate=False, user=default_user)
user_query = quote("How to git install application?")
# Act
- response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true")
+ response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true", headers=headers)
# Assert
assert response.status_code == 200
@@ -197,19 +271,20 @@ def test_notes_search(client, search_config: SearchConfig, sample_org_data, defa
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_notes_search_with_only_filters(
- client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data, default_user2: KhojUser
+ client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data, default_user: KhojUser
):
# Arrange
+ headers = {"Authorization": "Bearer kk-secret"}
text_search.setup(
OrgToJsonl,
sample_org_data,
regenerate=False,
- user=default_user2,
+ user=default_user,
)
user_query = quote('+"Emacs" file:"*.org"')
# Act
- response = client.get(f"/api/search?q={user_query}&n=1&t=org")
+ response = client.get(f"/api/search?q={user_query}&n=1&t=org", headers=headers)
# Assert
assert response.status_code == 200
@@ -220,13 +295,14 @@ def test_notes_search_with_only_filters(
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
-def test_notes_search_with_include_filter(client, sample_org_data, default_user2: KhojUser):
+def test_notes_search_with_include_filter(client, sample_org_data, default_user: KhojUser):
# Arrange
- text_search.setup(OrgToJsonl, sample_org_data, regenerate=False, user=default_user2)
+ headers = {"Authorization": "Bearer kk-secret"}
+ text_search.setup(OrgToJsonl, sample_org_data, regenerate=False, user=default_user)
user_query = quote('How to git install application? +"Emacs"')
# Act
- response = client.get(f"/api/search?q={user_query}&n=1&t=org")
+ response = client.get(f"/api/search?q={user_query}&n=1&t=org", headers=headers)
# Assert
assert response.status_code == 200
@@ -237,18 +313,19 @@ def test_notes_search_with_include_filter(client, sample_org_data, default_user2
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
-def test_notes_search_with_exclude_filter(client, sample_org_data, default_user2: KhojUser):
+def test_notes_search_with_exclude_filter(client, sample_org_data, default_user: KhojUser):
# Arrange
+ headers = {"Authorization": "Bearer kk-secret"}
text_search.setup(
OrgToJsonl,
sample_org_data,
regenerate=False,
- user=default_user2,
+ user=default_user,
)
user_query = quote('How to git install application? -"clone"')
# Act
- response = client.get(f"/api/search?q={user_query}&n=1&t=org")
+ response = client.get(f"/api/search?q={user_query}&n=1&t=org", headers=headers)
# Assert
assert response.status_code == 200
@@ -261,16 +338,17 @@ def test_notes_search_with_exclude_filter(client, sample_org_data, default_user2
@pytest.mark.django_db(transaction=True)
def test_different_user_data_not_accessed(client, sample_org_data, default_user: KhojUser):
# Arrange
+ headers = {"Authorization": "Bearer kk-token"} # Token for default_user2
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False, user=default_user)
user_query = quote("How to git install application?")
# Act
- response = client.get(f"/api/search?q={user_query}&n=1&t=org")
+ response = client.get(f"/api/search?q={user_query}&n=1&t=org", headers=headers)
# Assert
- assert response.status_code == 200
+ assert response.status_code == 403
# assert actual response has no data as the default_user is different from the user making the query (anonymous)
- assert len(response.json()) == 0
+ assert len(response.json()) == 1 and response.json()["detail"] == "Forbidden"
def get_sample_files_data():