Short-circuit API rate limiter for unauthenticated users (#607)

### Major
- Short-circuit API rate limiter for unauthenticated users
  Calls by unauthenticated users were failing at the API rate limiter
  because it could not access the user info object. This is a bug.
  
  The API rate limiter should short-circuit for unauthenticated users so
  that a proper Forbidden response can be returned by the API
  
  Add regression test to verify that unauthenticated users get 403
  response when calling the /chat API endpoint
  
### Minor
- Remove trailing slash to normalize khoj url in obsidian plugin settings
- Move used /api/config API controllers into separate module
- Delete unused /api/beta API endpoint
- Fix error message rendering in khoj.el, khoj obsidian chat
- Handle deprecation warnings for subscribe renew date, langchain, pydantic & logger.warn
This commit is contained in:
Debanjum 2024-01-17 00:59:52 +05:30 committed by GitHub
commit 4d30f7d1d9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 377 additions and 410 deletions

View file

@ -55,7 +55,7 @@ dependencies = [
"torch == 2.0.1", "torch == 2.0.1",
"uvicorn == 0.17.6", "uvicorn == 0.17.6",
"aiohttp ~= 3.9.0", "aiohttp ~= 3.9.0",
"langchain >= 0.0.331", "langchain <= 0.2.0",
"requests >= 2.26.0", "requests >= 2.26.0",
"bs4 >= 0.0.1", "bs4 >= 0.0.1",
"anyio == 3.7.1", "anyio == 3.7.1",

View file

@ -348,7 +348,7 @@ Auto invokes setup steps on calling main entrypoint."
t t
;; else general check via ping to khoj-server-url ;; else general check via ping to khoj-server-url
(if (ignore-errors (if (ignore-errors
(url-retrieve-synchronously (format "%s/api/config/data/default" khoj-server-url))) (url-retrieve-synchronously (format "%s/api/health" khoj-server-url)))
;; Successful ping to non-emacs khoj server indicates it is started and ready. ;; Successful ping to non-emacs khoj server indicates it is started and ready.
;; So update ready state tracker variable (and implicitly return true for started) ;; So update ready state tracker variable (and implicitly return true for started)
(setq khoj--server-ready? t) (setq khoj--server-ready? t)
@ -432,7 +432,7 @@ Auto invokes setup steps on calling main entrypoint."
(khoj--delete-open-network-connections-to-server) (khoj--delete-open-network-connections-to-server)
(with-current-buffer (current-buffer) (with-current-buffer (current-buffer)
(search-forward "\n\n" nil t) (search-forward "\n\n" nil t)
(message "khoj.el: Failed to %supdate %s content index. Status: %s%s" (message "khoj.el: Failed to %supdate %scontent index. Status: %s%s"
(if force "force " "") (if force "force " "")
(if content-type (format "%s " content-type) "all") (if content-type (format "%s " content-type) "all")
(string-trim (format "%s %s" (nth 1 (nth 1 status)) (nth 2 (nth 1 status)))) (string-trim (format "%s %s" (nth 1 (nth 1 status)) (nth 2 (nth 1 status))))
@ -603,22 +603,6 @@ Use `BOUNDARY' to separate files. This is sent to Khoj server as a POST request.
;; -------------- ;; --------------
;; Query Khoj API ;; Query Khoj API
;; -------------- ;; --------------
(defun khoj--post-new-config (config)
  "Configure khoj server with provided CONFIG.
CONFIG is an alist.  It is JSON-encoded and sent as a POST request to
the `/api/config/data' endpoint of `khoj-server-url'.  Afterwards an
asynchronous request to `/api/update' is issued so the server refreshes
its index."
  ;; POST provided config to khoj server
  (let ((url-request-method "POST")
        (url-request-extra-headers `(("Content-Type" . "application/json")
                                     ("Authorization" . ,(format "Bearer %s" khoj-api-key))))
        (url-request-data (encode-coding-string (json-encode-alist config) 'utf-8))
        (config-url (format "%s/api/config/data" khoj-server-url)))
    (with-current-buffer (url-retrieve-synchronously config-url)
      (buffer-string)))
  ;; Update index on khoj server after configuration update.
  ;; The token is sent as a bare `Bearer <key>' value, identical to the
  ;; header of the POST above; wrapping it in literal double quotes would
  ;; produce a different Authorization header for the same key.
  (let ((khoj--server-ready? nil)
        (url-request-extra-headers `(("Authorization" . ,(format "Bearer %s" khoj-api-key)))))
    (url-retrieve (format "%s/api/update?client=emacs" khoj-server-url) #'identity)))
(defun khoj--get-enabled-content-types () (defun khoj--get-enabled-content-types ()
"Get content types enabled for search from API." "Get content types enabled for search from API."
(let ((config-url (format "%s/api/config/types" khoj-server-url)) (let ((config-url (format "%s/api/config/types" khoj-server-url))

View file

@ -245,7 +245,7 @@ export class KhojChatModal extends Modal {
if (responseJson.detail) { if (responseJson.detail) {
// If the server returns error details in response, render a setup hint. // If the server returns error details in response, render a setup hint.
let setupMsg = "Hi 👋🏾, to start chatting add available chat models options via [the Django Admin panel](/server/admin) on the Server"; let setupMsg = "Hi 👋🏾, to start chatting add available chat models options via [the Django Admin panel](/server/admin) on the Server";
this.renderMessage(chatBodyEl, setupMsg, "khoj", undefined, true); this.renderMessage(chatBodyEl, setupMsg, "khoj", undefined);
return false; return false;
} else if (responseJson.response) { } else if (responseJson.response) {

View file

@ -42,7 +42,7 @@ export class KhojSettingTab extends PluginSettingTab {
.addText(text => text .addText(text => text
.setValue(`${this.plugin.settings.khojUrl}`) .setValue(`${this.plugin.settings.khojUrl}`)
.onChange(async (value) => { .onChange(async (value) => {
this.plugin.settings.khojUrl = value.trim(); this.plugin.settings.khojUrl = value.trim().replace(/\/$/, '');
await this.plugin.saveSettings(); await this.plugin.saveSettings();
containerEl.firstElementChild?.setText(this.getBackendStatusMessage()); containerEl.firstElementChild?.setText(this.getBackendStatusMessage());
})); }));

View file

@ -1,12 +1,14 @@
import json import json
import logging import logging
import os import os
from datetime import datetime
from enum import Enum from enum import Enum
from typing import Optional from typing import Optional
import openai import openai
import requests import requests
import schedule import schedule
from django.utils.timezone import make_aware
from starlette.authentication import ( from starlette.authentication import (
AuthCredentials, AuthCredentials,
AuthenticationBackend, AuthenticationBackend,
@ -59,7 +61,8 @@ class UserAuthenticationBackend(AuthenticationBackend):
email="default@example.com", email="default@example.com",
password="default", password="default",
) )
Subscription.objects.create(user=default_user, type="standard", renewal_date="2100-04-01") renewal_date = make_aware(datetime.strptime("2100-04-01", "%Y-%m-%d"))
Subscription.objects.create(user=default_user, type="standard", renewal_date=renewal_date)
async def authenticate(self, request: HTTPConnection): async def authenticate(self, request: HTTPConnection):
current_user = request.session.get("user") current_user = request.session.get("user")
@ -190,14 +193,14 @@ def initialize_content(regenerate: bool, search_type: Optional[SearchType] = Non
def configure_routes(app): def configure_routes(app):
# Import APIs here to setup search types before while configuring server # Import APIs here to setup search types before while configuring server
from khoj.routers.api import api from khoj.routers.api import api
from khoj.routers.api_beta import api_beta from khoj.routers.api_config import api_config
from khoj.routers.auth import auth_router from khoj.routers.auth import auth_router
from khoj.routers.indexer import indexer from khoj.routers.indexer import indexer
from khoj.routers.subscription import subscription_router from khoj.routers.subscription import subscription_router
from khoj.routers.web_client import web_client from khoj.routers.web_client import web_client
app.include_router(api, prefix="/api") app.include_router(api, prefix="/api")
app.include_router(api_beta, prefix="/api/beta") app.include_router(api_config, prefix="/api/config")
app.include_router(indexer, prefix="/api/v1/index") app.include_router(indexer, prefix="/api/v1/index")
if state.billing_enabled: if state.billing_enabled:
logger.info("💳 Enabled Billing") logger.info("💳 Enabled Billing")

View file

@ -15,24 +15,12 @@ from fastapi.responses import Response, StreamingResponse
from starlette.authentication import requires from starlette.authentication import requires
from khoj.configure import configure_server from khoj.configure import configure_server
from khoj.database import adapters
from khoj.database.adapters import ( from khoj.database.adapters import (
ConversationAdapters, ConversationAdapters,
EntryAdapters, EntryAdapters,
get_user_search_model_or_default, get_user_search_model_or_default,
) )
from khoj.database.models import ChatModelOptions from khoj.database.models import ChatModelOptions, KhojUser, SpeechToTextModelOptions
from khoj.database.models import Entry as DbEntry
from khoj.database.models import (
GithubConfig,
KhojUser,
LocalMarkdownConfig,
LocalOrgConfig,
LocalPdfConfig,
LocalPlaintextConfig,
NotionConfig,
SpeechToTextModelOptions,
)
from khoj.processor.conversation.offline.chat_model import extract_questions_offline from khoj.processor.conversation.offline.chat_model import extract_questions_offline
from khoj.processor.conversation.offline.whisper import transcribe_audio_offline from khoj.processor.conversation.offline.whisper import transcribe_audio_offline
from khoj.processor.conversation.openai.gpt import extract_questions from khoj.processor.conversation.openai.gpt import extract_questions
@ -56,7 +44,7 @@ from khoj.search_filter.file_filter import FileFilter
from khoj.search_filter.word_filter import WordFilter from khoj.search_filter.word_filter import WordFilter
from khoj.search_type import image_search, text_search from khoj.search_type import image_search, text_search
from khoj.utils import constants, state from khoj.utils import constants, state
from khoj.utils.config import GPT4AllProcessorModel, TextSearchModel from khoj.utils.config import GPT4AllProcessorModel
from khoj.utils.helpers import ( from khoj.utils.helpers import (
AsyncIteratorWrapper, AsyncIteratorWrapper,
ConversationCommand, ConversationCommand,
@ -65,13 +53,7 @@ from khoj.utils.helpers import (
is_none_or_empty, is_none_or_empty,
timer, timer,
) )
from khoj.utils.rawconfig import ( from khoj.utils.rawconfig import SearchResponse
FullConfig,
GithubContentConfig,
NotionContentConfig,
SearchConfig,
SearchResponse,
)
from khoj.utils.state import SearchType from khoj.utils.state import SearchType
# Initialize Router # Initialize Router
@ -80,330 +62,6 @@ logger = logging.getLogger(__name__)
conversation_command_rate_limiter = ConversationCommandRateLimiter(trial_rate_limit=5, subscribed_rate_limit=100) conversation_command_rate_limiter = ConversationCommandRateLimiter(trial_rate_limit=5, subscribed_rate_limit=100)
def map_config_to_object(content_source: str):
    """Map a content source identifier to the object its settings live in.

    Returns ``GithubConfig`` for the GitHub source, ``NotionConfig`` for the
    Notion source, the marker string ``"Computer"`` for the computer source
    and ``None`` for any other value.
    """
    if content_source == DbEntry.EntrySource.GITHUB:
        return GithubConfig
    # Notion has to be matched against its own source value; testing GITHUB a
    # second time here would leave the NotionConfig branch unreachable.
    if content_source == DbEntry.EntrySource.NOTION:
        return NotionConfig
    if content_source == DbEntry.EntrySource.COMPUTER:
        return "Computer"
    return None
async def map_config_to_db(config: FullConfig, user: KhojUser):
    """Write the content settings found in ``config`` to the database for ``user``.

    Nothing happens when ``config.content_type`` is empty. For each local
    section that is set (org, markdown, pdf, plaintext, in that order) the
    user's existing row is deleted and a new one is created. GitHub and
    Notion settings are handed to the matching ``adapters`` helpers.
    """
    content = config.content_type
    if not content:
        return

    local_sections = (
        (LocalOrgConfig, content.org),
        (LocalMarkdownConfig, content.markdown),
        (LocalPdfConfig, content.pdf),
        (LocalPlaintextConfig, content.plaintext),
    )
    for model, section in local_sections:
        if not section:
            continue
        # Replace rather than update: drop the user's old row, then insert.
        await model.objects.filter(user=user).adelete()
        await model.objects.acreate(
            input_files=section.input_files,
            input_filter=section.input_filter,
            index_heading_entries=section.index_heading_entries,
            user=user,
        )

    github = content.github
    if github:
        await adapters.set_user_github_config(user=user, pat_token=github.pat_token, repos=github.repos)

    notion = content.notion
    if notion:
        await adapters.set_notion_config(user=user, token=notion.token)
def _initialize_config():
    """Create ``state.config`` with the default search settings if it is unset."""
    if state.config is not None:
        return
    fresh_config = FullConfig()
    state.config = fresh_config
    fresh_config.search_type = SearchConfig.model_validate(constants.default_config["search-type"])
@api.get("/config/data", response_model=FullConfig)
@requires(["authenticated"])
def get_config_data(request: Request):
    """Return the configuration object held in ``state.config``."""
    requesting_user = request.user.object
    # The lookup result is not used; the call is kept as-is.
    EntryAdapters.get_unique_file_types(requesting_user)

    return state.config
@api.post("/config/data")
@requires(["authenticated"])
async def set_config_data(
    request: Request,
    updated_config: FullConfig,
    client: Optional[str] = None,
):
    """Save ``updated_config`` for the requesting user and return ``state.config``.

    After saving, a telemetry record is sent whose metadata flags which
    content types are enabled for the user and whether a conversation
    processor is configured.
    """
    user = request.user.object
    await map_config_to_db(updated_config, user)

    metadata = {}
    enabled_types = await sync_to_async(EntryAdapters.get_unique_file_types)(user)

    if state.config.content_type is not None:
        for content_name in ("github", "notion", "org", "pdf", "markdown"):
            metadata[content_name] = content_name in enabled_types

    processor = state.config.processor
    if processor is not None:
        metadata["conversation_processor"] = processor.conversation is not None

    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="set_config",
        client=client,
        metadata=metadata,
    )
    return state.config
@api.post("/config/data/content-source/github", status_code=200)
@requires(["authenticated"])
async def set_content_config_github_data(
    request: Request,
    updated_config: Union[GithubContentConfig, None],
    client: Optional[str] = None,
):
    """Store the GitHub content settings of the requesting user.

    Any failure while saving is logged and reported as HTTP 500. On success
    a telemetry record is sent and ``{"status": "ok"}`` is returned.
    """
    _initialize_config()

    khoj_user = request.user.object

    try:
        await adapters.set_user_github_config(
            user=khoj_user,
            pat_token=updated_config.pat_token,
            repos=updated_config.repos,
        )
    except Exception as error:
        logger.error(error, exc_info=True)
        raise HTTPException(status_code=500, detail="Failed to set Github config")

    telemetry = {
        "request": request,
        "telemetry_type": "api",
        "api": "set_content_config",
        "client": client,
        "metadata": {"content_type": "github"},
    }
    update_telemetry_state(**telemetry)

    return {"status": "ok"}
@api.post("/config/data/content-source/notion", status_code=200)
@requires(["authenticated"])
async def set_content_config_notion_data(
    request: Request,
    updated_config: Union[NotionContentConfig, None],
    client: Optional[str] = None,
):
    """Store the Notion token of the requesting user.

    Any failure while saving is logged and reported as HTTP 500. On success
    a telemetry record is sent and ``{"status": "ok"}`` is returned.
    """
    _initialize_config()

    user = request.user.object

    try:
        await adapters.set_notion_config(
            user=user,
            token=updated_config.token,
        )
    except Exception as e:
        logger.error(e, exc_info=True)
        # Name the integration that actually failed so clients and logs are
        # not pointed at the GitHub setup.
        raise HTTPException(status_code=500, detail="Failed to set Notion config") from e

    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="set_content_config",
        client=client,
        metadata={"content_type": "notion"},
    )

    return {"status": "ok"}
@api.delete("/config/data/content-source/{content_source}", status_code=200)
@requires(["authenticated"])
async def remove_content_source_data(
    request: Request,
    content_source: str,
    client: Optional[str] = None,
):
    """Delete the requesting user's settings and entries for ``content_source``.

    Raises ``ValueError`` when ``content_source`` is not a known source.
    """
    user = request.user.object

    telemetry = {
        "request": request,
        "telemetry_type": "api",
        "api": "delete_content_config",
        "client": client,
        "metadata": {"content_source": content_source},
    }
    update_telemetry_state(**telemetry)

    config_model = map_config_to_object(content_source)
    if config_model is None:
        raise ValueError(f"Invalid content source: {content_source}")
    if config_model != "Computer":
        # Only sources backed by a config model have a settings row to remove.
        await config_model.objects.filter(user=user).adelete()

    await sync_to_async(EntryAdapters.delete_all_entries)(user, content_source)

    # The lookup result is not used; the call is kept as-is.
    await sync_to_async(EntryAdapters.get_unique_file_types)(user)
    return {"status": "ok"}
@api.delete("/config/data/file", status_code=200)
@requires(["authenticated"])
async def remove_file_data(
    request: Request,
    filename: str,
    client: Optional[str] = None,
):
    """Delete the requesting user's entries for ``filename``."""
    owner = request.user.object

    telemetry = {
        "request": request,
        "telemetry_type": "api",
        "api": "delete_file",
        "client": client,
    }
    update_telemetry_state(**telemetry)

    await EntryAdapters.adelete_entry_by_file(owner, filename)

    return {"status": "ok"}
@api.get("/config/data/{content_source}", response_model=List[str])
@requires(["authenticated"])
async def get_all_filenames(
    request: Request,
    content_source: str,
    client: Optional[str] = None,
):
    """List the requesting user's filenames for ``content_source``."""
    owner = request.user.object

    telemetry = {
        "request": request,
        "telemetry_type": "api",
        "api": "get_all_filenames",
        "client": client,
    }
    update_telemetry_state(**telemetry)

    filenames = EntryAdapters.aget_all_filenames_by_source(owner, content_source)
    return await sync_to_async(list)(filenames)  # type: ignore[call-arg]
@api.post("/config/data/conversation/model", status_code=200)
@requires(["authenticated"])
async def update_chat_model(
    request: Request,
    id: str,
    client: Optional[str] = None,
):
    """Select the chat model with the given ``id`` for the requesting user.

    Telemetry is sent before the outcome is checked. The response is an
    error status with "Model not found" when the adapter returns ``None``.
    """
    owner = request.user.object
    selected = await ConversationAdapters.aset_user_conversation_processor(owner, int(id))

    telemetry = {
        "request": request,
        "telemetry_type": "api",
        "api": "set_conversation_chat_model",
        "client": client,
        "metadata": {"processor_conversation_type": "conversation"},
    }
    update_telemetry_state(**telemetry)

    if selected is not None:
        return {"status": "ok"}
    return {"status": "error", "message": "Model not found"}
@api.post("/config/data/search/model", status_code=200)
@requires(["authenticated"])
async def update_search_model(
    request: Request,
    id: str,
    client: Optional[str] = None,
):
    """Select the search model with the given ``id`` for the requesting user.

    Returns an error status with "Model not found" when the adapter returns
    ``None``; telemetry is only sent for a successful change.
    """
    owner = request.user.object
    selected = await adapters.aset_user_search_model(owner, int(id))

    if selected is None:
        return {"status": "error", "message": "Model not found"}

    telemetry = {
        "request": request,
        "telemetry_type": "api",
        "api": "set_search_model",
        "client": client,
        "metadata": {"search_model": selected.setting.name},
    }
    update_telemetry_state(**telemetry)

    return {"status": "ok"}
# Create Routes
@api.get("/config/data/default")
def get_default_config_data():
    """Return ``constants.empty_config``; this route has no ``requires`` guard."""
    default_config = constants.empty_config
    return default_config
@api.get("/config/index/size", response_model=Dict[str, int])
@requires(["authenticated"])
async def get_indexed_data_size(request: Request, common: CommonQueryParams):
    """Report the size of the requesting user's indexed data, in MB rounded up."""
    owner = request.user.object
    size_in_mb = await sync_to_async(EntryAdapters.get_size_of_indexed_data_in_mb)(owner)

    payload = json.dumps({"indexed_data_size_in_mb": math.ceil(size_in_mb)})
    return Response(content=payload, media_type="application/json", status_code=200)
@api.get("/config/types", response_model=List[str])
@requires(["authenticated"])
def get_config_types(
    request: Request,
):
    """List the search types available to the requesting user.

    A search type is included when its value is among the user's indexed
    file types or among the content types set in ``state.config``.
    ``SearchType.All`` is always included.
    """
    user = request.user.object
    enabled_file_types = EntryAdapters.get_unique_file_types(user)
    configured_content_types = list(enabled_file_types)

    if state.config and state.config.content_type:
        # model_dump() is the pydantic v2 spelling of the deprecated .dict();
        # iterating the result yields the names of the configured sections.
        for ctype in state.config.content_type.model_dump(exclude_none=True):
            configured_content_types.append(ctype)

    return [
        search_type.value
        for search_type in SearchType
        if (search_type.value in configured_content_types) or search_type == SearchType.All
    ]
@api.get("/search", response_model=List[SearchResponse]) @api.get("/search", response_model=List[SearchResponse])
@requires(["authenticated"]) @requires(["authenticated"])
async def search( async def search(

View file

@ -1,7 +0,0 @@
import logging
from fastapi import APIRouter
# Initialize Router
api_beta = APIRouter()
logger = logging.getLogger(__name__)

View file

@ -0,0 +1,311 @@
import json
import logging
import math
from typing import Dict, List, Optional, Union
from asgiref.sync import sync_to_async
from fastapi import APIRouter, HTTPException, Request
from fastapi.requests import Request
from fastapi.responses import Response
from starlette.authentication import requires
from khoj.database import adapters
from khoj.database.adapters import ConversationAdapters, EntryAdapters
from khoj.database.models import Entry as DbEntry
from khoj.database.models import (
GithubConfig,
KhojUser,
LocalMarkdownConfig,
LocalOrgConfig,
LocalPdfConfig,
LocalPlaintextConfig,
NotionConfig,
)
from khoj.routers.helpers import CommonQueryParams, update_telemetry_state
from khoj.utils import constants, state
from khoj.utils.rawconfig import (
FullConfig,
GithubContentConfig,
NotionContentConfig,
SearchConfig,
)
from khoj.utils.state import SearchType
api_config = APIRouter()
logger = logging.getLogger(__name__)
def map_config_to_object(content_source: str):
    """Map a content source identifier to the object its settings live in.

    Returns ``GithubConfig`` for the GitHub source, ``NotionConfig`` for the
    Notion source, the marker string ``"Computer"`` for the computer source
    and ``None`` for any other value.
    """
    if content_source == DbEntry.EntrySource.GITHUB:
        return GithubConfig
    # Notion has to be matched against its own source value; testing GITHUB a
    # second time here would leave the NotionConfig branch unreachable.
    if content_source == DbEntry.EntrySource.NOTION:
        return NotionConfig
    if content_source == DbEntry.EntrySource.COMPUTER:
        return "Computer"
    return None
async def map_config_to_db(config: FullConfig, user: KhojUser):
    """Write the content settings found in ``config`` to the database for ``user``.

    Nothing happens when ``config.content_type`` is empty. For each local
    section that is set (org, markdown, pdf, plaintext, in that order) the
    user's existing row is deleted and a new one is created. GitHub and
    Notion settings are handed to the matching ``adapters`` helpers.
    """
    content = config.content_type
    if not content:
        return

    local_sections = (
        (LocalOrgConfig, content.org),
        (LocalMarkdownConfig, content.markdown),
        (LocalPdfConfig, content.pdf),
        (LocalPlaintextConfig, content.plaintext),
    )
    for model, section in local_sections:
        if not section:
            continue
        # Replace rather than update: drop the user's old row, then insert.
        await model.objects.filter(user=user).adelete()
        await model.objects.acreate(
            input_files=section.input_files,
            input_filter=section.input_filter,
            index_heading_entries=section.index_heading_entries,
            user=user,
        )

    github = content.github
    if github:
        await adapters.set_user_github_config(user=user, pat_token=github.pat_token, repos=github.repos)

    notion = content.notion
    if notion:
        await adapters.set_notion_config(user=user, token=notion.token)
def _initialize_config():
    """Create ``state.config`` with the default search settings if it is unset."""
    if state.config is not None:
        return
    fresh_config = FullConfig()
    state.config = fresh_config
    fresh_config.search_type = SearchConfig.model_validate(constants.default_config["search-type"])
@api_config.post("/data/content-source/github", status_code=200)
@requires(["authenticated"])
async def set_content_config_github_data(
    request: Request,
    updated_config: Union[GithubContentConfig, None],
    client: Optional[str] = None,
):
    """Store the GitHub content settings of the requesting user.

    Any failure while saving is logged and reported as HTTP 500. On success
    a telemetry record is sent and ``{"status": "ok"}`` is returned.
    """
    _initialize_config()

    khoj_user = request.user.object

    try:
        await adapters.set_user_github_config(
            user=khoj_user,
            pat_token=updated_config.pat_token,
            repos=updated_config.repos,
        )
    except Exception as error:
        logger.error(error, exc_info=True)
        raise HTTPException(status_code=500, detail="Failed to set Github config")

    telemetry = {
        "request": request,
        "telemetry_type": "api",
        "api": "set_content_config",
        "client": client,
        "metadata": {"content_type": "github"},
    }
    update_telemetry_state(**telemetry)

    return {"status": "ok"}
@api_config.post("/data/content-source/notion", status_code=200)
@requires(["authenticated"])
async def set_content_config_notion_data(
    request: Request,
    updated_config: Union[NotionContentConfig, None],
    client: Optional[str] = None,
):
    """Store the Notion token of the requesting user.

    Any failure while saving is logged and reported as HTTP 500. On success
    a telemetry record is sent and ``{"status": "ok"}`` is returned.
    """
    _initialize_config()

    user = request.user.object

    try:
        await adapters.set_notion_config(
            user=user,
            token=updated_config.token,
        )
    except Exception as e:
        logger.error(e, exc_info=True)
        # Name the integration that actually failed so clients and logs are
        # not pointed at the GitHub setup.
        raise HTTPException(status_code=500, detail="Failed to set Notion config") from e

    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="set_content_config",
        client=client,
        metadata={"content_type": "notion"},
    )

    return {"status": "ok"}
@api_config.delete("/data/content-source/{content_source}", status_code=200)
@requires(["authenticated"])
async def remove_content_source_data(
    request: Request,
    content_source: str,
    client: Optional[str] = None,
):
    """Delete the requesting user's settings and entries for ``content_source``.

    Raises ``ValueError`` when ``content_source`` is not a known source.
    """
    user = request.user.object

    telemetry = {
        "request": request,
        "telemetry_type": "api",
        "api": "delete_content_config",
        "client": client,
        "metadata": {"content_source": content_source},
    }
    update_telemetry_state(**telemetry)

    config_model = map_config_to_object(content_source)
    if config_model is None:
        raise ValueError(f"Invalid content source: {content_source}")
    if config_model != "Computer":
        # Only sources backed by a config model have a settings row to remove.
        await config_model.objects.filter(user=user).adelete()

    await sync_to_async(EntryAdapters.delete_all_entries)(user, content_source)

    # The lookup result is not used; the call is kept as-is.
    await sync_to_async(EntryAdapters.get_unique_file_types)(user)
    return {"status": "ok"}
@api_config.delete("/data/file", status_code=200)
@requires(["authenticated"])
async def remove_file_data(
    request: Request,
    filename: str,
    client: Optional[str] = None,
):
    """Delete the requesting user's entries for ``filename``."""
    owner = request.user.object

    telemetry = {
        "request": request,
        "telemetry_type": "api",
        "api": "delete_file",
        "client": client,
    }
    update_telemetry_state(**telemetry)

    await EntryAdapters.adelete_entry_by_file(owner, filename)

    return {"status": "ok"}
@api_config.get("/data/{content_source}", response_model=List[str])
@requires(["authenticated"])
async def get_all_filenames(
    request: Request,
    content_source: str,
    client: Optional[str] = None,
):
    """List the requesting user's filenames for ``content_source``."""
    owner = request.user.object

    telemetry = {
        "request": request,
        "telemetry_type": "api",
        "api": "get_all_filenames",
        "client": client,
    }
    update_telemetry_state(**telemetry)

    filenames = EntryAdapters.aget_all_filenames_by_source(owner, content_source)
    return await sync_to_async(list)(filenames)  # type: ignore[call-arg]
@api_config.post("/data/conversation/model", status_code=200)
@requires(["authenticated"])
async def update_chat_model(
    request: Request,
    id: str,
    client: Optional[str] = None,
):
    """Select the chat model with the given ``id`` for the requesting user.

    Telemetry is sent before the outcome is checked. The response is an
    error status with "Model not found" when the adapter returns ``None``.
    """
    owner = request.user.object
    selected = await ConversationAdapters.aset_user_conversation_processor(owner, int(id))

    telemetry = {
        "request": request,
        "telemetry_type": "api",
        "api": "set_conversation_chat_model",
        "client": client,
        "metadata": {"processor_conversation_type": "conversation"},
    }
    update_telemetry_state(**telemetry)

    if selected is not None:
        return {"status": "ok"}
    return {"status": "error", "message": "Model not found"}
@api_config.post("/data/search/model", status_code=200)
@requires(["authenticated"])
async def update_search_model(
    request: Request,
    id: str,
    client: Optional[str] = None,
):
    """Select the search model with the given ``id`` for the requesting user.

    Returns an error status with "Model not found" when the adapter returns
    ``None``; telemetry is only sent for a successful change.
    """
    owner = request.user.object
    selected = await adapters.aset_user_search_model(owner, int(id))

    if selected is None:
        return {"status": "error", "message": "Model not found"}

    telemetry = {
        "request": request,
        "telemetry_type": "api",
        "api": "set_search_model",
        "client": client,
        "metadata": {"search_model": selected.setting.name},
    }
    update_telemetry_state(**telemetry)

    return {"status": "ok"}
# Create Routes
@api_config.get("/index/size", response_model=Dict[str, int])
@requires(["authenticated"])
async def get_indexed_data_size(request: Request, common: CommonQueryParams):
    """Report the size of the requesting user's indexed data, in MB rounded up."""
    owner = request.user.object
    size_in_mb = await sync_to_async(EntryAdapters.get_size_of_indexed_data_in_mb)(owner)

    payload = json.dumps({"indexed_data_size_in_mb": math.ceil(size_in_mb)})
    return Response(content=payload, media_type="application/json", status_code=200)
@api_config.get("/types", response_model=List[str])
@requires(["authenticated"])
def get_config_types(
    request: Request,
):
    """List the search types available to the requesting user.

    A search type is included when its value is among the user's indexed
    file types or among the content types set in ``state.config``.
    ``SearchType.All`` is always included.
    """
    owner = request.user.object
    available_types = list(EntryAdapters.get_unique_file_types(owner))

    if state.config and state.config.content_type:
        # Extending with the dumped mapping adds its keys, i.e. the section names.
        available_types.extend(state.config.content_type.model_dump(exclude_none=True))

    matching_values = []
    for candidate in SearchType:
        if candidate.value in available_types or candidate == SearchType.All:
            matching_values.append(candidate.value)
    return matching_values

View file

@ -26,7 +26,7 @@ logger = logging.getLogger(__name__)
auth_router = APIRouter() auth_router = APIRouter()
if not state.anonymous_mode and not (os.environ.get("GOOGLE_CLIENT_ID") and os.environ.get("GOOGLE_CLIENT_SECRET")): if not state.anonymous_mode and not (os.environ.get("GOOGLE_CLIENT_ID") and os.environ.get("GOOGLE_CLIENT_SECRET")):
logger.warn( logger.warning(
"🚨 Use --anonymous-mode flag to disable Google OAuth or set GOOGLE_CLIENT_ID, GOOGLE_CLIENT_SECRET environment variables to enable it" "🚨 Use --anonymous-mode flag to disable Google OAuth or set GOOGLE_CLIENT_ID, GOOGLE_CLIENT_SECRET environment variables to enable it"
) )
else: else:

View file

@ -299,6 +299,11 @@ class ApiUserRateLimiter:
self.cache: dict[str, list[float]] = defaultdict(list) self.cache: dict[str, list[float]] = defaultdict(list)
def __call__(self, request: Request): def __call__(self, request: Request):
# Rate limiting is disabled if user unauthenticated.
# Other systems handle authentication
if not request.user.is_authenticated:
return
user: KhojUser = request.user.object user: KhojUser = request.user.object
subscribed = has_required_scope(request, ["premium"]) subscribed = has_required_scope(request, ["premium"])
user_requests = self.cache[user.uuid] user_requests = self.cache[user.uuid]

View file

@ -37,7 +37,7 @@ async def subscribe(request: Request):
"customer.subscription.updated", "customer.subscription.updated",
"customer.subscription.deleted", "customer.subscription.deleted",
}: }:
logger.warn(f"Unhandled Stripe event type: {event['type']}") logger.warning(f"Unhandled Stripe event type: {event['type']}")
return {"success": False} return {"success": False}
# Retrieve the customer's details # Retrieve the customer's details

View file

@ -254,51 +254,45 @@ def md_content_config():
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def chat_client(search_config: SearchConfig, default_user2: KhojUser): def chat_client(search_config: SearchConfig, default_user2: KhojUser):
# Initialize app state return chat_client_builder(search_config, default_user2, require_auth=False)
state.config.search_type = search_config
state.SearchType = configure_search_types()
LocalMarkdownConfig.objects.create(
input_files=None,
input_filter=["tests/data/markdown/*.markdown"],
user=default_user2,
)
# Index Markdown Content for Search @pytest.fixture(scope="function")
all_files = fs_syncer.collect_files(user=default_user2) def chat_client_with_auth(search_config: SearchConfig, default_user2: KhojUser):
state.content_index, _ = configure_content( return chat_client_builder(search_config, default_user2, require_auth=True)
state.content_index, state.config.content_type, all_files, state.search_models, user=default_user2
)
# Initialize Processor from Config
if os.getenv("OPENAI_API_KEY"):
chat_model = ChatModelOptionsFactory(chat_model="gpt-3.5-turbo", model_type="openai")
OpenAIProcessorConversationConfigFactory()
UserConversationProcessorConfigFactory(user=default_user2, setting=chat_model)
state.anonymous_mode = True
app = FastAPI()
configure_routes(app)
configure_middleware(app)
app.mount("/static", StaticFiles(directory=web_directory), name="static")
return TestClient(app)
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def chat_client_no_background(search_config: SearchConfig, default_user2: KhojUser): def chat_client_no_background(search_config: SearchConfig, default_user2: KhojUser):
return chat_client_builder(search_config, default_user2, index_content=False, require_auth=False)
@pytest.mark.django_db
def chat_client_builder(search_config, user, index_content=True, require_auth=False):
# Initialize app state # Initialize app state
state.config.search_type = search_config state.config.search_type = search_config
state.SearchType = configure_search_types() state.SearchType = configure_search_types()
if index_content:
LocalMarkdownConfig.objects.create(
input_files=None,
input_filter=["tests/data/markdown/*.markdown"],
user=user,
)
# Index Markdown Content for Search
all_files = fs_syncer.collect_files(user=user)
state.content_index, _ = configure_content(
state.content_index, state.config.content_type, all_files, state.search_models, user=user
)
# Initialize Processor from Config # Initialize Processor from Config
if os.getenv("OPENAI_API_KEY"): if os.getenv("OPENAI_API_KEY"):
chat_model = ChatModelOptionsFactory(chat_model="gpt-3.5-turbo", model_type="openai") chat_model = ChatModelOptionsFactory(chat_model="gpt-3.5-turbo", model_type="openai")
OpenAIProcessorConversationConfigFactory() OpenAIProcessorConversationConfigFactory()
UserConversationProcessorConfigFactory(user=default_user2, setting=chat_model) UserConversationProcessorConfigFactory(user=user, setting=chat_model)
state.anonymous_mode = True state.anonymous_mode = not require_auth
app = FastAPI() app = FastAPI()

View file

@ -1,6 +1,8 @@
import os import os
from datetime import datetime
import factory import factory
from django.utils.timezone import make_aware
from khoj.database.models import ( from khoj.database.models import (
ChatModelOptions, ChatModelOptions,
@ -90,4 +92,4 @@ class SubscriptionFactory(factory.django.DjangoModelFactory):
user = factory.SubFactory(UserFactory) user = factory.SubFactory(UserFactory)
type = "standard" type = "standard"
is_recurring = False is_recurring = False
renewal_date = "2100-04-01" renewal_date = make_aware(datetime.strptime("2100-04-01", "%Y-%m-%d"))

View file

@ -1,4 +1,5 @@
# Standard Modules # Standard Modules
import os
from io import BytesIO from io import BytesIO
from urllib.parse import quote from urllib.parse import quote
@ -482,6 +483,21 @@ def test_user_no_data_returns_empty(client, sample_org_data, api_user3: KhojApiU
assert response.json() == [] assert response.json() == []
@pytest.mark.skipif(os.getenv("OPENAI_API_KEY") is None, reason="requires OPENAI_API_KEY")
@pytest.mark.django_db(transaction=True)
def test_chat_with_unauthenticated_user(chat_client_with_auth, api_user2: KhojApiUser):
# Arrange
headers = {"Authorization": f"Bearer {api_user2.token}"}
# Act
auth_response = chat_client_with_auth.get(f'/api/chat?q="Hello!"&stream=true', headers=headers)
no_auth_response = chat_client_with_auth.get(f'/api/chat?q="Hello!"&stream=true')
# Assert
assert auth_response.status_code == 200
assert no_auth_response.status_code == 403
def get_sample_files_data(): def get_sample_files_data():
return [ return [
("files", ("path/to/filename.org", "* practicing piano", "text/org")), ("files", ("path/to/filename.org", "* practicing piano", "text/org")),

View file

@ -53,6 +53,7 @@ def test_chat_with_no_chat_history_or_retrieved_content(chat_client):
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@pytest.mark.skipif(os.getenv("SERPER_DEV_API_KEY") is None, reason="requires SERPER_DEV_API_KEY")
@pytest.mark.chatquality @pytest.mark.chatquality
@pytest.mark.django_db(transaction=True) @pytest.mark.django_db(transaction=True)
def test_chat_with_online_content(chat_client): def test_chat_with_online_content(chat_client):
@ -65,7 +66,7 @@ def test_chat_with_online_content(chat_client):
response_message = response_message.split("### compiled references")[0] response_message = response_message.split("### compiled references")[0]
# Assert # Assert
expected_responses = ["http://www.paulgraham.com/greatwork.html", "Please set your SERPER_DEV_API_KEY"] expected_responses = ["http://www.paulgraham.com/greatwork.html"]
assert response.status_code == 200 assert response.status_code == 200
assert any([expected_response in response_message for expected_response in expected_responses]), ( assert any([expected_response in response_message for expected_response in expected_responses]), (
"Expected links or serper not setup in response but got: " + response_message "Expected links or serper not setup in response but got: " + response_message