mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-27 17:35:07 +01:00
Short-circuit API rate limiter for unauthenticated users (#607)
### Major - Short-circuit API rate limiter for unauthenticated user Calls by unauthenticated users were failing at API rate limiter as it failed to access user info object. This is a bug. API rate limiter should short-circuit for unauthenicated users so a proper Forbidden response can be returned by API Add regression test to verify that unauthenticated users get 403 response when calling the /chat API endpoint ### Minor - Remove trailing slash to normalize khoj url in obsidian plugin settings - Move used /api/config API controllers into separate module - Delete unused /api/beta API endpoint - Fix error message rendering in khoj.el, khoj obsidian chat - Handle deprecation warnings for subscribe renew date, langchain, pydantic & logger.warn
This commit is contained in:
commit
4d30f7d1d9
15 changed files with 377 additions and 410 deletions
|
@ -55,7 +55,7 @@ dependencies = [
|
||||||
"torch == 2.0.1",
|
"torch == 2.0.1",
|
||||||
"uvicorn == 0.17.6",
|
"uvicorn == 0.17.6",
|
||||||
"aiohttp ~= 3.9.0",
|
"aiohttp ~= 3.9.0",
|
||||||
"langchain >= 0.0.331",
|
"langchain <= 0.2.0",
|
||||||
"requests >= 2.26.0",
|
"requests >= 2.26.0",
|
||||||
"bs4 >= 0.0.1",
|
"bs4 >= 0.0.1",
|
||||||
"anyio == 3.7.1",
|
"anyio == 3.7.1",
|
||||||
|
|
|
@ -348,7 +348,7 @@ Auto invokes setup steps on calling main entrypoint."
|
||||||
t
|
t
|
||||||
;; else general check via ping to khoj-server-url
|
;; else general check via ping to khoj-server-url
|
||||||
(if (ignore-errors
|
(if (ignore-errors
|
||||||
(url-retrieve-synchronously (format "%s/api/config/data/default" khoj-server-url)))
|
(url-retrieve-synchronously (format "%s/api/health" khoj-server-url)))
|
||||||
;; Successful ping to non-emacs khoj server indicates it is started and ready.
|
;; Successful ping to non-emacs khoj server indicates it is started and ready.
|
||||||
;; So update ready state tracker variable (and implicitly return true for started)
|
;; So update ready state tracker variable (and implicitly return true for started)
|
||||||
(setq khoj--server-ready? t)
|
(setq khoj--server-ready? t)
|
||||||
|
@ -432,7 +432,7 @@ Auto invokes setup steps on calling main entrypoint."
|
||||||
(khoj--delete-open-network-connections-to-server)
|
(khoj--delete-open-network-connections-to-server)
|
||||||
(with-current-buffer (current-buffer)
|
(with-current-buffer (current-buffer)
|
||||||
(search-forward "\n\n" nil t)
|
(search-forward "\n\n" nil t)
|
||||||
(message "khoj.el: Failed to %supdate %s content index. Status: %s%s"
|
(message "khoj.el: Failed to %supdate %scontent index. Status: %s%s"
|
||||||
(if force "force " "")
|
(if force "force " "")
|
||||||
(if content-type (format "%s " content-type) "all")
|
(if content-type (format "%s " content-type) "all")
|
||||||
(string-trim (format "%s %s" (nth 1 (nth 1 status)) (nth 2 (nth 1 status))))
|
(string-trim (format "%s %s" (nth 1 (nth 1 status)) (nth 2 (nth 1 status))))
|
||||||
|
@ -603,22 +603,6 @@ Use `BOUNDARY' to separate files. This is sent to Khoj server as a POST request.
|
||||||
;; --------------
|
;; --------------
|
||||||
;; Query Khoj API
|
;; Query Khoj API
|
||||||
;; --------------
|
;; --------------
|
||||||
|
|
||||||
(defun khoj--post-new-config (config)
|
|
||||||
"Configure khoj server with provided CONFIG."
|
|
||||||
;; POST provided config to khoj server
|
|
||||||
(let ((url-request-method "POST")
|
|
||||||
(url-request-extra-headers `(("Content-Type" . "application/json")
|
|
||||||
("Authorization" . ,(format "Bearer %s" khoj-api-key))))
|
|
||||||
(url-request-data (encode-coding-string (json-encode-alist config) 'utf-8))
|
|
||||||
(config-url (format "%s/api/config/data" khoj-server-url)))
|
|
||||||
(with-current-buffer (url-retrieve-synchronously config-url)
|
|
||||||
(buffer-string)))
|
|
||||||
;; Update index on khoj server after configuration update
|
|
||||||
(let ((khoj--server-ready? nil)
|
|
||||||
(url-request-extra-headers `(("Authorization" . ,(format "\"Bearer %s\"" khoj-api-key)))))
|
|
||||||
(url-retrieve (format "%s/api/update?client=emacs" khoj-server-url) #'identity)))
|
|
||||||
|
|
||||||
(defun khoj--get-enabled-content-types ()
|
(defun khoj--get-enabled-content-types ()
|
||||||
"Get content types enabled for search from API."
|
"Get content types enabled for search from API."
|
||||||
(let ((config-url (format "%s/api/config/types" khoj-server-url))
|
(let ((config-url (format "%s/api/config/types" khoj-server-url))
|
||||||
|
|
|
@ -245,7 +245,7 @@ export class KhojChatModal extends Modal {
|
||||||
if (responseJson.detail) {
|
if (responseJson.detail) {
|
||||||
// If the server returns error details in response, render a setup hint.
|
// If the server returns error details in response, render a setup hint.
|
||||||
let setupMsg = "Hi 👋🏾, to start chatting add available chat models options via [the Django Admin panel](/server/admin) on the Server";
|
let setupMsg = "Hi 👋🏾, to start chatting add available chat models options via [the Django Admin panel](/server/admin) on the Server";
|
||||||
this.renderMessage(chatBodyEl, setupMsg, "khoj", undefined, true);
|
this.renderMessage(chatBodyEl, setupMsg, "khoj", undefined);
|
||||||
|
|
||||||
return false;
|
return false;
|
||||||
} else if (responseJson.response) {
|
} else if (responseJson.response) {
|
||||||
|
|
|
@ -42,7 +42,7 @@ export class KhojSettingTab extends PluginSettingTab {
|
||||||
.addText(text => text
|
.addText(text => text
|
||||||
.setValue(`${this.plugin.settings.khojUrl}`)
|
.setValue(`${this.plugin.settings.khojUrl}`)
|
||||||
.onChange(async (value) => {
|
.onChange(async (value) => {
|
||||||
this.plugin.settings.khojUrl = value.trim();
|
this.plugin.settings.khojUrl = value.trim().replace(/\/$/, '');
|
||||||
await this.plugin.saveSettings();
|
await this.plugin.saveSettings();
|
||||||
containerEl.firstElementChild?.setText(this.getBackendStatusMessage());
|
containerEl.firstElementChild?.setText(this.getBackendStatusMessage());
|
||||||
}));
|
}));
|
||||||
|
|
|
@ -1,12 +1,14 @@
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
import requests
|
import requests
|
||||||
import schedule
|
import schedule
|
||||||
|
from django.utils.timezone import make_aware
|
||||||
from starlette.authentication import (
|
from starlette.authentication import (
|
||||||
AuthCredentials,
|
AuthCredentials,
|
||||||
AuthenticationBackend,
|
AuthenticationBackend,
|
||||||
|
@ -59,7 +61,8 @@ class UserAuthenticationBackend(AuthenticationBackend):
|
||||||
email="default@example.com",
|
email="default@example.com",
|
||||||
password="default",
|
password="default",
|
||||||
)
|
)
|
||||||
Subscription.objects.create(user=default_user, type="standard", renewal_date="2100-04-01")
|
renewal_date = make_aware(datetime.strptime("2100-04-01", "%Y-%m-%d"))
|
||||||
|
Subscription.objects.create(user=default_user, type="standard", renewal_date=renewal_date)
|
||||||
|
|
||||||
async def authenticate(self, request: HTTPConnection):
|
async def authenticate(self, request: HTTPConnection):
|
||||||
current_user = request.session.get("user")
|
current_user = request.session.get("user")
|
||||||
|
@ -190,14 +193,14 @@ def initialize_content(regenerate: bool, search_type: Optional[SearchType] = Non
|
||||||
def configure_routes(app):
|
def configure_routes(app):
|
||||||
# Import APIs here to setup search types before while configuring server
|
# Import APIs here to setup search types before while configuring server
|
||||||
from khoj.routers.api import api
|
from khoj.routers.api import api
|
||||||
from khoj.routers.api_beta import api_beta
|
from khoj.routers.api_config import api_config
|
||||||
from khoj.routers.auth import auth_router
|
from khoj.routers.auth import auth_router
|
||||||
from khoj.routers.indexer import indexer
|
from khoj.routers.indexer import indexer
|
||||||
from khoj.routers.subscription import subscription_router
|
from khoj.routers.subscription import subscription_router
|
||||||
from khoj.routers.web_client import web_client
|
from khoj.routers.web_client import web_client
|
||||||
|
|
||||||
app.include_router(api, prefix="/api")
|
app.include_router(api, prefix="/api")
|
||||||
app.include_router(api_beta, prefix="/api/beta")
|
app.include_router(api_config, prefix="/api/config")
|
||||||
app.include_router(indexer, prefix="/api/v1/index")
|
app.include_router(indexer, prefix="/api/v1/index")
|
||||||
if state.billing_enabled:
|
if state.billing_enabled:
|
||||||
logger.info("💳 Enabled Billing")
|
logger.info("💳 Enabled Billing")
|
||||||
|
|
|
@ -15,24 +15,12 @@ from fastapi.responses import Response, StreamingResponse
|
||||||
from starlette.authentication import requires
|
from starlette.authentication import requires
|
||||||
|
|
||||||
from khoj.configure import configure_server
|
from khoj.configure import configure_server
|
||||||
from khoj.database import adapters
|
|
||||||
from khoj.database.adapters import (
|
from khoj.database.adapters import (
|
||||||
ConversationAdapters,
|
ConversationAdapters,
|
||||||
EntryAdapters,
|
EntryAdapters,
|
||||||
get_user_search_model_or_default,
|
get_user_search_model_or_default,
|
||||||
)
|
)
|
||||||
from khoj.database.models import ChatModelOptions
|
from khoj.database.models import ChatModelOptions, KhojUser, SpeechToTextModelOptions
|
||||||
from khoj.database.models import Entry as DbEntry
|
|
||||||
from khoj.database.models import (
|
|
||||||
GithubConfig,
|
|
||||||
KhojUser,
|
|
||||||
LocalMarkdownConfig,
|
|
||||||
LocalOrgConfig,
|
|
||||||
LocalPdfConfig,
|
|
||||||
LocalPlaintextConfig,
|
|
||||||
NotionConfig,
|
|
||||||
SpeechToTextModelOptions,
|
|
||||||
)
|
|
||||||
from khoj.processor.conversation.offline.chat_model import extract_questions_offline
|
from khoj.processor.conversation.offline.chat_model import extract_questions_offline
|
||||||
from khoj.processor.conversation.offline.whisper import transcribe_audio_offline
|
from khoj.processor.conversation.offline.whisper import transcribe_audio_offline
|
||||||
from khoj.processor.conversation.openai.gpt import extract_questions
|
from khoj.processor.conversation.openai.gpt import extract_questions
|
||||||
|
@ -56,7 +44,7 @@ from khoj.search_filter.file_filter import FileFilter
|
||||||
from khoj.search_filter.word_filter import WordFilter
|
from khoj.search_filter.word_filter import WordFilter
|
||||||
from khoj.search_type import image_search, text_search
|
from khoj.search_type import image_search, text_search
|
||||||
from khoj.utils import constants, state
|
from khoj.utils import constants, state
|
||||||
from khoj.utils.config import GPT4AllProcessorModel, TextSearchModel
|
from khoj.utils.config import GPT4AllProcessorModel
|
||||||
from khoj.utils.helpers import (
|
from khoj.utils.helpers import (
|
||||||
AsyncIteratorWrapper,
|
AsyncIteratorWrapper,
|
||||||
ConversationCommand,
|
ConversationCommand,
|
||||||
|
@ -65,13 +53,7 @@ from khoj.utils.helpers import (
|
||||||
is_none_or_empty,
|
is_none_or_empty,
|
||||||
timer,
|
timer,
|
||||||
)
|
)
|
||||||
from khoj.utils.rawconfig import (
|
from khoj.utils.rawconfig import SearchResponse
|
||||||
FullConfig,
|
|
||||||
GithubContentConfig,
|
|
||||||
NotionContentConfig,
|
|
||||||
SearchConfig,
|
|
||||||
SearchResponse,
|
|
||||||
)
|
|
||||||
from khoj.utils.state import SearchType
|
from khoj.utils.state import SearchType
|
||||||
|
|
||||||
# Initialize Router
|
# Initialize Router
|
||||||
|
@ -80,330 +62,6 @@ logger = logging.getLogger(__name__)
|
||||||
conversation_command_rate_limiter = ConversationCommandRateLimiter(trial_rate_limit=5, subscribed_rate_limit=100)
|
conversation_command_rate_limiter = ConversationCommandRateLimiter(trial_rate_limit=5, subscribed_rate_limit=100)
|
||||||
|
|
||||||
|
|
||||||
def map_config_to_object(content_source: str):
|
|
||||||
if content_source == DbEntry.EntrySource.GITHUB:
|
|
||||||
return GithubConfig
|
|
||||||
if content_source == DbEntry.EntrySource.GITHUB:
|
|
||||||
return NotionConfig
|
|
||||||
if content_source == DbEntry.EntrySource.COMPUTER:
|
|
||||||
return "Computer"
|
|
||||||
|
|
||||||
|
|
||||||
async def map_config_to_db(config: FullConfig, user: KhojUser):
|
|
||||||
if config.content_type:
|
|
||||||
if config.content_type.org:
|
|
||||||
await LocalOrgConfig.objects.filter(user=user).adelete()
|
|
||||||
await LocalOrgConfig.objects.acreate(
|
|
||||||
input_files=config.content_type.org.input_files,
|
|
||||||
input_filter=config.content_type.org.input_filter,
|
|
||||||
index_heading_entries=config.content_type.org.index_heading_entries,
|
|
||||||
user=user,
|
|
||||||
)
|
|
||||||
if config.content_type.markdown:
|
|
||||||
await LocalMarkdownConfig.objects.filter(user=user).adelete()
|
|
||||||
await LocalMarkdownConfig.objects.acreate(
|
|
||||||
input_files=config.content_type.markdown.input_files,
|
|
||||||
input_filter=config.content_type.markdown.input_filter,
|
|
||||||
index_heading_entries=config.content_type.markdown.index_heading_entries,
|
|
||||||
user=user,
|
|
||||||
)
|
|
||||||
if config.content_type.pdf:
|
|
||||||
await LocalPdfConfig.objects.filter(user=user).adelete()
|
|
||||||
await LocalPdfConfig.objects.acreate(
|
|
||||||
input_files=config.content_type.pdf.input_files,
|
|
||||||
input_filter=config.content_type.pdf.input_filter,
|
|
||||||
index_heading_entries=config.content_type.pdf.index_heading_entries,
|
|
||||||
user=user,
|
|
||||||
)
|
|
||||||
if config.content_type.plaintext:
|
|
||||||
await LocalPlaintextConfig.objects.filter(user=user).adelete()
|
|
||||||
await LocalPlaintextConfig.objects.acreate(
|
|
||||||
input_files=config.content_type.plaintext.input_files,
|
|
||||||
input_filter=config.content_type.plaintext.input_filter,
|
|
||||||
index_heading_entries=config.content_type.plaintext.index_heading_entries,
|
|
||||||
user=user,
|
|
||||||
)
|
|
||||||
if config.content_type.github:
|
|
||||||
await adapters.set_user_github_config(
|
|
||||||
user=user,
|
|
||||||
pat_token=config.content_type.github.pat_token,
|
|
||||||
repos=config.content_type.github.repos,
|
|
||||||
)
|
|
||||||
if config.content_type.notion:
|
|
||||||
await adapters.set_notion_config(
|
|
||||||
user=user,
|
|
||||||
token=config.content_type.notion.token,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _initialize_config():
|
|
||||||
if state.config is None:
|
|
||||||
state.config = FullConfig()
|
|
||||||
state.config.search_type = SearchConfig.model_validate(constants.default_config["search-type"])
|
|
||||||
|
|
||||||
|
|
||||||
@api.get("/config/data", response_model=FullConfig)
|
|
||||||
@requires(["authenticated"])
|
|
||||||
def get_config_data(request: Request):
|
|
||||||
user = request.user.object
|
|
||||||
EntryAdapters.get_unique_file_types(user)
|
|
||||||
|
|
||||||
return state.config
|
|
||||||
|
|
||||||
|
|
||||||
@api.post("/config/data")
|
|
||||||
@requires(["authenticated"])
|
|
||||||
async def set_config_data(
|
|
||||||
request: Request,
|
|
||||||
updated_config: FullConfig,
|
|
||||||
client: Optional[str] = None,
|
|
||||||
):
|
|
||||||
user = request.user.object
|
|
||||||
await map_config_to_db(updated_config, user)
|
|
||||||
|
|
||||||
configuration_update_metadata = {}
|
|
||||||
|
|
||||||
enabled_content = await sync_to_async(EntryAdapters.get_unique_file_types)(user)
|
|
||||||
|
|
||||||
if state.config.content_type is not None:
|
|
||||||
configuration_update_metadata["github"] = "github" in enabled_content
|
|
||||||
configuration_update_metadata["notion"] = "notion" in enabled_content
|
|
||||||
configuration_update_metadata["org"] = "org" in enabled_content
|
|
||||||
configuration_update_metadata["pdf"] = "pdf" in enabled_content
|
|
||||||
configuration_update_metadata["markdown"] = "markdown" in enabled_content
|
|
||||||
|
|
||||||
if state.config.processor is not None:
|
|
||||||
configuration_update_metadata["conversation_processor"] = state.config.processor.conversation is not None
|
|
||||||
|
|
||||||
update_telemetry_state(
|
|
||||||
request=request,
|
|
||||||
telemetry_type="api",
|
|
||||||
api="set_config",
|
|
||||||
client=client,
|
|
||||||
metadata=configuration_update_metadata,
|
|
||||||
)
|
|
||||||
return state.config
|
|
||||||
|
|
||||||
|
|
||||||
@api.post("/config/data/content-source/github", status_code=200)
|
|
||||||
@requires(["authenticated"])
|
|
||||||
async def set_content_config_github_data(
|
|
||||||
request: Request,
|
|
||||||
updated_config: Union[GithubContentConfig, None],
|
|
||||||
client: Optional[str] = None,
|
|
||||||
):
|
|
||||||
_initialize_config()
|
|
||||||
|
|
||||||
user = request.user.object
|
|
||||||
|
|
||||||
try:
|
|
||||||
await adapters.set_user_github_config(
|
|
||||||
user=user,
|
|
||||||
pat_token=updated_config.pat_token,
|
|
||||||
repos=updated_config.repos,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(e, exc_info=True)
|
|
||||||
raise HTTPException(status_code=500, detail="Failed to set Github config")
|
|
||||||
|
|
||||||
update_telemetry_state(
|
|
||||||
request=request,
|
|
||||||
telemetry_type="api",
|
|
||||||
api="set_content_config",
|
|
||||||
client=client,
|
|
||||||
metadata={"content_type": "github"},
|
|
||||||
)
|
|
||||||
|
|
||||||
return {"status": "ok"}
|
|
||||||
|
|
||||||
|
|
||||||
@api.post("/config/data/content-source/notion", status_code=200)
|
|
||||||
@requires(["authenticated"])
|
|
||||||
async def set_content_config_notion_data(
|
|
||||||
request: Request,
|
|
||||||
updated_config: Union[NotionContentConfig, None],
|
|
||||||
client: Optional[str] = None,
|
|
||||||
):
|
|
||||||
_initialize_config()
|
|
||||||
|
|
||||||
user = request.user.object
|
|
||||||
|
|
||||||
try:
|
|
||||||
await adapters.set_notion_config(
|
|
||||||
user=user,
|
|
||||||
token=updated_config.token,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(e, exc_info=True)
|
|
||||||
raise HTTPException(status_code=500, detail="Failed to set Github config")
|
|
||||||
|
|
||||||
update_telemetry_state(
|
|
||||||
request=request,
|
|
||||||
telemetry_type="api",
|
|
||||||
api="set_content_config",
|
|
||||||
client=client,
|
|
||||||
metadata={"content_type": "notion"},
|
|
||||||
)
|
|
||||||
|
|
||||||
return {"status": "ok"}
|
|
||||||
|
|
||||||
|
|
||||||
@api.delete("/config/data/content-source/{content_source}", status_code=200)
|
|
||||||
@requires(["authenticated"])
|
|
||||||
async def remove_content_source_data(
|
|
||||||
request: Request,
|
|
||||||
content_source: str,
|
|
||||||
client: Optional[str] = None,
|
|
||||||
):
|
|
||||||
user = request.user.object
|
|
||||||
|
|
||||||
update_telemetry_state(
|
|
||||||
request=request,
|
|
||||||
telemetry_type="api",
|
|
||||||
api="delete_content_config",
|
|
||||||
client=client,
|
|
||||||
metadata={"content_source": content_source},
|
|
||||||
)
|
|
||||||
|
|
||||||
content_object = map_config_to_object(content_source)
|
|
||||||
if content_object is None:
|
|
||||||
raise ValueError(f"Invalid content source: {content_source}")
|
|
||||||
elif content_object != "Computer":
|
|
||||||
await content_object.objects.filter(user=user).adelete()
|
|
||||||
await sync_to_async(EntryAdapters.delete_all_entries)(user, content_source)
|
|
||||||
|
|
||||||
enabled_content = await sync_to_async(EntryAdapters.get_unique_file_types)(user)
|
|
||||||
return {"status": "ok"}
|
|
||||||
|
|
||||||
|
|
||||||
@api.delete("/config/data/file", status_code=200)
|
|
||||||
@requires(["authenticated"])
|
|
||||||
async def remove_file_data(
|
|
||||||
request: Request,
|
|
||||||
filename: str,
|
|
||||||
client: Optional[str] = None,
|
|
||||||
):
|
|
||||||
user = request.user.object
|
|
||||||
|
|
||||||
update_telemetry_state(
|
|
||||||
request=request,
|
|
||||||
telemetry_type="api",
|
|
||||||
api="delete_file",
|
|
||||||
client=client,
|
|
||||||
)
|
|
||||||
|
|
||||||
await EntryAdapters.adelete_entry_by_file(user, filename)
|
|
||||||
|
|
||||||
return {"status": "ok"}
|
|
||||||
|
|
||||||
|
|
||||||
@api.get("/config/data/{content_source}", response_model=List[str])
|
|
||||||
@requires(["authenticated"])
|
|
||||||
async def get_all_filenames(
|
|
||||||
request: Request,
|
|
||||||
content_source: str,
|
|
||||||
client: Optional[str] = None,
|
|
||||||
):
|
|
||||||
user = request.user.object
|
|
||||||
|
|
||||||
update_telemetry_state(
|
|
||||||
request=request,
|
|
||||||
telemetry_type="api",
|
|
||||||
api="get_all_filenames",
|
|
||||||
client=client,
|
|
||||||
)
|
|
||||||
|
|
||||||
return await sync_to_async(list)(EntryAdapters.aget_all_filenames_by_source(user, content_source)) # type: ignore[call-arg]
|
|
||||||
|
|
||||||
|
|
||||||
@api.post("/config/data/conversation/model", status_code=200)
|
|
||||||
@requires(["authenticated"])
|
|
||||||
async def update_chat_model(
|
|
||||||
request: Request,
|
|
||||||
id: str,
|
|
||||||
client: Optional[str] = None,
|
|
||||||
):
|
|
||||||
user = request.user.object
|
|
||||||
|
|
||||||
new_config = await ConversationAdapters.aset_user_conversation_processor(user, int(id))
|
|
||||||
|
|
||||||
update_telemetry_state(
|
|
||||||
request=request,
|
|
||||||
telemetry_type="api",
|
|
||||||
api="set_conversation_chat_model",
|
|
||||||
client=client,
|
|
||||||
metadata={"processor_conversation_type": "conversation"},
|
|
||||||
)
|
|
||||||
|
|
||||||
if new_config is None:
|
|
||||||
return {"status": "error", "message": "Model not found"}
|
|
||||||
|
|
||||||
return {"status": "ok"}
|
|
||||||
|
|
||||||
|
|
||||||
@api.post("/config/data/search/model", status_code=200)
|
|
||||||
@requires(["authenticated"])
|
|
||||||
async def update_search_model(
|
|
||||||
request: Request,
|
|
||||||
id: str,
|
|
||||||
client: Optional[str] = None,
|
|
||||||
):
|
|
||||||
user = request.user.object
|
|
||||||
|
|
||||||
new_config = await adapters.aset_user_search_model(user, int(id))
|
|
||||||
|
|
||||||
if new_config is None:
|
|
||||||
return {"status": "error", "message": "Model not found"}
|
|
||||||
else:
|
|
||||||
update_telemetry_state(
|
|
||||||
request=request,
|
|
||||||
telemetry_type="api",
|
|
||||||
api="set_search_model",
|
|
||||||
client=client,
|
|
||||||
metadata={"search_model": new_config.setting.name},
|
|
||||||
)
|
|
||||||
|
|
||||||
return {"status": "ok"}
|
|
||||||
|
|
||||||
|
|
||||||
# Create Routes
|
|
||||||
@api.get("/config/data/default")
|
|
||||||
def get_default_config_data():
|
|
||||||
return constants.empty_config
|
|
||||||
|
|
||||||
|
|
||||||
@api.get("/config/index/size", response_model=Dict[str, int])
|
|
||||||
@requires(["authenticated"])
|
|
||||||
async def get_indexed_data_size(request: Request, common: CommonQueryParams):
|
|
||||||
user = request.user.object
|
|
||||||
indexed_data_size_in_mb = await sync_to_async(EntryAdapters.get_size_of_indexed_data_in_mb)(user)
|
|
||||||
return Response(
|
|
||||||
content=json.dumps({"indexed_data_size_in_mb": math.ceil(indexed_data_size_in_mb)}),
|
|
||||||
media_type="application/json",
|
|
||||||
status_code=200,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@api.get("/config/types", response_model=List[str])
|
|
||||||
@requires(["authenticated"])
|
|
||||||
def get_config_types(
|
|
||||||
request: Request,
|
|
||||||
):
|
|
||||||
user = request.user.object
|
|
||||||
enabled_file_types = EntryAdapters.get_unique_file_types(user)
|
|
||||||
configured_content_types = list(enabled_file_types)
|
|
||||||
|
|
||||||
if state.config and state.config.content_type:
|
|
||||||
for ctype in state.config.content_type.dict(exclude_none=True):
|
|
||||||
configured_content_types.append(ctype)
|
|
||||||
|
|
||||||
return [
|
|
||||||
search_type.value
|
|
||||||
for search_type in SearchType
|
|
||||||
if (search_type.value in configured_content_types) or search_type == SearchType.All
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@api.get("/search", response_model=List[SearchResponse])
|
@api.get("/search", response_model=List[SearchResponse])
|
||||||
@requires(["authenticated"])
|
@requires(["authenticated"])
|
||||||
async def search(
|
async def search(
|
||||||
|
|
|
@ -1,7 +0,0 @@
|
||||||
import logging
|
|
||||||
|
|
||||||
from fastapi import APIRouter
|
|
||||||
|
|
||||||
# Initialize Router
|
|
||||||
api_beta = APIRouter()
|
|
||||||
logger = logging.getLogger(__name__)
|
|
311
src/khoj/routers/api_config.py
Normal file
311
src/khoj/routers/api_config.py
Normal file
|
@ -0,0 +1,311 @@
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
|
from asgiref.sync import sync_to_async
|
||||||
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
|
from fastapi.requests import Request
|
||||||
|
from fastapi.responses import Response
|
||||||
|
from starlette.authentication import requires
|
||||||
|
|
||||||
|
from khoj.database import adapters
|
||||||
|
from khoj.database.adapters import ConversationAdapters, EntryAdapters
|
||||||
|
from khoj.database.models import Entry as DbEntry
|
||||||
|
from khoj.database.models import (
|
||||||
|
GithubConfig,
|
||||||
|
KhojUser,
|
||||||
|
LocalMarkdownConfig,
|
||||||
|
LocalOrgConfig,
|
||||||
|
LocalPdfConfig,
|
||||||
|
LocalPlaintextConfig,
|
||||||
|
NotionConfig,
|
||||||
|
)
|
||||||
|
from khoj.routers.helpers import CommonQueryParams, update_telemetry_state
|
||||||
|
from khoj.utils import constants, state
|
||||||
|
from khoj.utils.rawconfig import (
|
||||||
|
FullConfig,
|
||||||
|
GithubContentConfig,
|
||||||
|
NotionContentConfig,
|
||||||
|
SearchConfig,
|
||||||
|
)
|
||||||
|
from khoj.utils.state import SearchType
|
||||||
|
|
||||||
|
api_config = APIRouter()
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def map_config_to_object(content_source: str):
|
||||||
|
if content_source == DbEntry.EntrySource.GITHUB:
|
||||||
|
return GithubConfig
|
||||||
|
if content_source == DbEntry.EntrySource.GITHUB:
|
||||||
|
return NotionConfig
|
||||||
|
if content_source == DbEntry.EntrySource.COMPUTER:
|
||||||
|
return "Computer"
|
||||||
|
|
||||||
|
|
||||||
|
async def map_config_to_db(config: FullConfig, user: KhojUser):
|
||||||
|
if config.content_type:
|
||||||
|
if config.content_type.org:
|
||||||
|
await LocalOrgConfig.objects.filter(user=user).adelete()
|
||||||
|
await LocalOrgConfig.objects.acreate(
|
||||||
|
input_files=config.content_type.org.input_files,
|
||||||
|
input_filter=config.content_type.org.input_filter,
|
||||||
|
index_heading_entries=config.content_type.org.index_heading_entries,
|
||||||
|
user=user,
|
||||||
|
)
|
||||||
|
if config.content_type.markdown:
|
||||||
|
await LocalMarkdownConfig.objects.filter(user=user).adelete()
|
||||||
|
await LocalMarkdownConfig.objects.acreate(
|
||||||
|
input_files=config.content_type.markdown.input_files,
|
||||||
|
input_filter=config.content_type.markdown.input_filter,
|
||||||
|
index_heading_entries=config.content_type.markdown.index_heading_entries,
|
||||||
|
user=user,
|
||||||
|
)
|
||||||
|
if config.content_type.pdf:
|
||||||
|
await LocalPdfConfig.objects.filter(user=user).adelete()
|
||||||
|
await LocalPdfConfig.objects.acreate(
|
||||||
|
input_files=config.content_type.pdf.input_files,
|
||||||
|
input_filter=config.content_type.pdf.input_filter,
|
||||||
|
index_heading_entries=config.content_type.pdf.index_heading_entries,
|
||||||
|
user=user,
|
||||||
|
)
|
||||||
|
if config.content_type.plaintext:
|
||||||
|
await LocalPlaintextConfig.objects.filter(user=user).adelete()
|
||||||
|
await LocalPlaintextConfig.objects.acreate(
|
||||||
|
input_files=config.content_type.plaintext.input_files,
|
||||||
|
input_filter=config.content_type.plaintext.input_filter,
|
||||||
|
index_heading_entries=config.content_type.plaintext.index_heading_entries,
|
||||||
|
user=user,
|
||||||
|
)
|
||||||
|
if config.content_type.github:
|
||||||
|
await adapters.set_user_github_config(
|
||||||
|
user=user,
|
||||||
|
pat_token=config.content_type.github.pat_token,
|
||||||
|
repos=config.content_type.github.repos,
|
||||||
|
)
|
||||||
|
if config.content_type.notion:
|
||||||
|
await adapters.set_notion_config(
|
||||||
|
user=user,
|
||||||
|
token=config.content_type.notion.token,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _initialize_config():
|
||||||
|
if state.config is None:
|
||||||
|
state.config = FullConfig()
|
||||||
|
state.config.search_type = SearchConfig.model_validate(constants.default_config["search-type"])
|
||||||
|
|
||||||
|
|
||||||
|
@api_config.post("/data/content-source/github", status_code=200)
|
||||||
|
@requires(["authenticated"])
|
||||||
|
async def set_content_config_github_data(
|
||||||
|
request: Request,
|
||||||
|
updated_config: Union[GithubContentConfig, None],
|
||||||
|
client: Optional[str] = None,
|
||||||
|
):
|
||||||
|
_initialize_config()
|
||||||
|
|
||||||
|
user = request.user.object
|
||||||
|
|
||||||
|
try:
|
||||||
|
await adapters.set_user_github_config(
|
||||||
|
user=user,
|
||||||
|
pat_token=updated_config.pat_token,
|
||||||
|
repos=updated_config.repos,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(e, exc_info=True)
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to set Github config")
|
||||||
|
|
||||||
|
update_telemetry_state(
|
||||||
|
request=request,
|
||||||
|
telemetry_type="api",
|
||||||
|
api="set_content_config",
|
||||||
|
client=client,
|
||||||
|
metadata={"content_type": "github"},
|
||||||
|
)
|
||||||
|
|
||||||
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
|
@api_config.post("/data/content-source/notion", status_code=200)
|
||||||
|
@requires(["authenticated"])
|
||||||
|
async def set_content_config_notion_data(
|
||||||
|
request: Request,
|
||||||
|
updated_config: Union[NotionContentConfig, None],
|
||||||
|
client: Optional[str] = None,
|
||||||
|
):
|
||||||
|
_initialize_config()
|
||||||
|
|
||||||
|
user = request.user.object
|
||||||
|
|
||||||
|
try:
|
||||||
|
await adapters.set_notion_config(
|
||||||
|
user=user,
|
||||||
|
token=updated_config.token,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(e, exc_info=True)
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to set Github config")
|
||||||
|
|
||||||
|
update_telemetry_state(
|
||||||
|
request=request,
|
||||||
|
telemetry_type="api",
|
||||||
|
api="set_content_config",
|
||||||
|
client=client,
|
||||||
|
metadata={"content_type": "notion"},
|
||||||
|
)
|
||||||
|
|
||||||
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
|
@api_config.delete("/data/content-source/{content_source}", status_code=200)
|
||||||
|
@requires(["authenticated"])
|
||||||
|
async def remove_content_source_data(
|
||||||
|
request: Request,
|
||||||
|
content_source: str,
|
||||||
|
client: Optional[str] = None,
|
||||||
|
):
|
||||||
|
user = request.user.object
|
||||||
|
|
||||||
|
update_telemetry_state(
|
||||||
|
request=request,
|
||||||
|
telemetry_type="api",
|
||||||
|
api="delete_content_config",
|
||||||
|
client=client,
|
||||||
|
metadata={"content_source": content_source},
|
||||||
|
)
|
||||||
|
|
||||||
|
content_object = map_config_to_object(content_source)
|
||||||
|
if content_object is None:
|
||||||
|
raise ValueError(f"Invalid content source: {content_source}")
|
||||||
|
elif content_object != "Computer":
|
||||||
|
await content_object.objects.filter(user=user).adelete()
|
||||||
|
await sync_to_async(EntryAdapters.delete_all_entries)(user, content_source)
|
||||||
|
|
||||||
|
enabled_content = await sync_to_async(EntryAdapters.get_unique_file_types)(user)
|
||||||
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
|
@api_config.delete("/data/file", status_code=200)
|
||||||
|
@requires(["authenticated"])
|
||||||
|
async def remove_file_data(
|
||||||
|
request: Request,
|
||||||
|
filename: str,
|
||||||
|
client: Optional[str] = None,
|
||||||
|
):
|
||||||
|
user = request.user.object
|
||||||
|
|
||||||
|
update_telemetry_state(
|
||||||
|
request=request,
|
||||||
|
telemetry_type="api",
|
||||||
|
api="delete_file",
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
|
||||||
|
await EntryAdapters.adelete_entry_by_file(user, filename)
|
||||||
|
|
||||||
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
|
@api_config.get("/data/{content_source}", response_model=List[str])
|
||||||
|
@requires(["authenticated"])
|
||||||
|
async def get_all_filenames(
|
||||||
|
request: Request,
|
||||||
|
content_source: str,
|
||||||
|
client: Optional[str] = None,
|
||||||
|
):
|
||||||
|
user = request.user.object
|
||||||
|
|
||||||
|
update_telemetry_state(
|
||||||
|
request=request,
|
||||||
|
telemetry_type="api",
|
||||||
|
api="get_all_filenames",
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
|
||||||
|
return await sync_to_async(list)(EntryAdapters.aget_all_filenames_by_source(user, content_source)) # type: ignore[call-arg]
|
||||||
|
|
||||||
|
|
||||||
|
@api_config.post("/data/conversation/model", status_code=200)
|
||||||
|
@requires(["authenticated"])
|
||||||
|
async def update_chat_model(
|
||||||
|
request: Request,
|
||||||
|
id: str,
|
||||||
|
client: Optional[str] = None,
|
||||||
|
):
|
||||||
|
user = request.user.object
|
||||||
|
|
||||||
|
new_config = await ConversationAdapters.aset_user_conversation_processor(user, int(id))
|
||||||
|
|
||||||
|
update_telemetry_state(
|
||||||
|
request=request,
|
||||||
|
telemetry_type="api",
|
||||||
|
api="set_conversation_chat_model",
|
||||||
|
client=client,
|
||||||
|
metadata={"processor_conversation_type": "conversation"},
|
||||||
|
)
|
||||||
|
|
||||||
|
if new_config is None:
|
||||||
|
return {"status": "error", "message": "Model not found"}
|
||||||
|
|
||||||
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
|
@api_config.post("/data/search/model", status_code=200)
|
||||||
|
@requires(["authenticated"])
|
||||||
|
async def update_search_model(
|
||||||
|
request: Request,
|
||||||
|
id: str,
|
||||||
|
client: Optional[str] = None,
|
||||||
|
):
|
||||||
|
user = request.user.object
|
||||||
|
|
||||||
|
new_config = await adapters.aset_user_search_model(user, int(id))
|
||||||
|
|
||||||
|
if new_config is None:
|
||||||
|
return {"status": "error", "message": "Model not found"}
|
||||||
|
else:
|
||||||
|
update_telemetry_state(
|
||||||
|
request=request,
|
||||||
|
telemetry_type="api",
|
||||||
|
api="set_search_model",
|
||||||
|
client=client,
|
||||||
|
metadata={"search_model": new_config.setting.name},
|
||||||
|
)
|
||||||
|
|
||||||
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
|
# Create Routes
|
||||||
|
@api_config.get("/index/size", response_model=Dict[str, int])
|
||||||
|
@requires(["authenticated"])
|
||||||
|
async def get_indexed_data_size(request: Request, common: CommonQueryParams):
|
||||||
|
user = request.user.object
|
||||||
|
indexed_data_size_in_mb = await sync_to_async(EntryAdapters.get_size_of_indexed_data_in_mb)(user)
|
||||||
|
return Response(
|
||||||
|
content=json.dumps({"indexed_data_size_in_mb": math.ceil(indexed_data_size_in_mb)}),
|
||||||
|
media_type="application/json",
|
||||||
|
status_code=200,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@api_config.get("/types", response_model=List[str])
|
||||||
|
@requires(["authenticated"])
|
||||||
|
def get_config_types(
|
||||||
|
request: Request,
|
||||||
|
):
|
||||||
|
user = request.user.object
|
||||||
|
enabled_file_types = EntryAdapters.get_unique_file_types(user)
|
||||||
|
configured_content_types = list(enabled_file_types)
|
||||||
|
|
||||||
|
if state.config and state.config.content_type:
|
||||||
|
for ctype in state.config.content_type.model_dump(exclude_none=True):
|
||||||
|
configured_content_types.append(ctype)
|
||||||
|
|
||||||
|
return [
|
||||||
|
search_type.value
|
||||||
|
for search_type in SearchType
|
||||||
|
if (search_type.value in configured_content_types) or search_type == SearchType.All
|
||||||
|
]
|
|
@ -26,7 +26,7 @@ logger = logging.getLogger(__name__)
|
||||||
auth_router = APIRouter()
|
auth_router = APIRouter()
|
||||||
|
|
||||||
if not state.anonymous_mode and not (os.environ.get("GOOGLE_CLIENT_ID") and os.environ.get("GOOGLE_CLIENT_SECRET")):
|
if not state.anonymous_mode and not (os.environ.get("GOOGLE_CLIENT_ID") and os.environ.get("GOOGLE_CLIENT_SECRET")):
|
||||||
logger.warn(
|
logger.warning(
|
||||||
"🚨 Use --anonymous-mode flag to disable Google OAuth or set GOOGLE_CLIENT_ID, GOOGLE_CLIENT_SECRET environment variables to enable it"
|
"🚨 Use --anonymous-mode flag to disable Google OAuth or set GOOGLE_CLIENT_ID, GOOGLE_CLIENT_SECRET environment variables to enable it"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -299,6 +299,11 @@ class ApiUserRateLimiter:
|
||||||
self.cache: dict[str, list[float]] = defaultdict(list)
|
self.cache: dict[str, list[float]] = defaultdict(list)
|
||||||
|
|
||||||
def __call__(self, request: Request):
|
def __call__(self, request: Request):
|
||||||
|
# Rate limiting is disabled if user unauthenticated.
|
||||||
|
# Other systems handle authentication
|
||||||
|
if not request.user.is_authenticated:
|
||||||
|
return
|
||||||
|
|
||||||
user: KhojUser = request.user.object
|
user: KhojUser = request.user.object
|
||||||
subscribed = has_required_scope(request, ["premium"])
|
subscribed = has_required_scope(request, ["premium"])
|
||||||
user_requests = self.cache[user.uuid]
|
user_requests = self.cache[user.uuid]
|
||||||
|
|
|
@ -37,7 +37,7 @@ async def subscribe(request: Request):
|
||||||
"customer.subscription.updated",
|
"customer.subscription.updated",
|
||||||
"customer.subscription.deleted",
|
"customer.subscription.deleted",
|
||||||
}:
|
}:
|
||||||
logger.warn(f"Unhandled Stripe event type: {event['type']}")
|
logger.warning(f"Unhandled Stripe event type: {event['type']}")
|
||||||
return {"success": False}
|
return {"success": False}
|
||||||
|
|
||||||
# Retrieve the customer's details
|
# Retrieve the customer's details
|
||||||
|
|
|
@ -254,51 +254,45 @@ def md_content_config():
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
@pytest.fixture(scope="function")
|
||||||
def chat_client(search_config: SearchConfig, default_user2: KhojUser):
|
def chat_client(search_config: SearchConfig, default_user2: KhojUser):
|
||||||
# Initialize app state
|
return chat_client_builder(search_config, default_user2, require_auth=False)
|
||||||
state.config.search_type = search_config
|
|
||||||
state.SearchType = configure_search_types()
|
|
||||||
|
|
||||||
LocalMarkdownConfig.objects.create(
|
|
||||||
input_files=None,
|
|
||||||
input_filter=["tests/data/markdown/*.markdown"],
|
|
||||||
user=default_user2,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Index Markdown Content for Search
|
@pytest.fixture(scope="function")
|
||||||
all_files = fs_syncer.collect_files(user=default_user2)
|
def chat_client_with_auth(search_config: SearchConfig, default_user2: KhojUser):
|
||||||
state.content_index, _ = configure_content(
|
return chat_client_builder(search_config, default_user2, require_auth=True)
|
||||||
state.content_index, state.config.content_type, all_files, state.search_models, user=default_user2
|
|
||||||
)
|
|
||||||
|
|
||||||
# Initialize Processor from Config
|
|
||||||
if os.getenv("OPENAI_API_KEY"):
|
|
||||||
chat_model = ChatModelOptionsFactory(chat_model="gpt-3.5-turbo", model_type="openai")
|
|
||||||
OpenAIProcessorConversationConfigFactory()
|
|
||||||
UserConversationProcessorConfigFactory(user=default_user2, setting=chat_model)
|
|
||||||
|
|
||||||
state.anonymous_mode = True
|
|
||||||
|
|
||||||
app = FastAPI()
|
|
||||||
|
|
||||||
configure_routes(app)
|
|
||||||
configure_middleware(app)
|
|
||||||
app.mount("/static", StaticFiles(directory=web_directory), name="static")
|
|
||||||
return TestClient(app)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
@pytest.fixture(scope="function")
|
||||||
def chat_client_no_background(search_config: SearchConfig, default_user2: KhojUser):
|
def chat_client_no_background(search_config: SearchConfig, default_user2: KhojUser):
|
||||||
|
return chat_client_builder(search_config, default_user2, index_content=False, require_auth=False)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.django_db
|
||||||
|
def chat_client_builder(search_config, user, index_content=True, require_auth=False):
|
||||||
# Initialize app state
|
# Initialize app state
|
||||||
state.config.search_type = search_config
|
state.config.search_type = search_config
|
||||||
state.SearchType = configure_search_types()
|
state.SearchType = configure_search_types()
|
||||||
|
|
||||||
|
if index_content:
|
||||||
|
LocalMarkdownConfig.objects.create(
|
||||||
|
input_files=None,
|
||||||
|
input_filter=["tests/data/markdown/*.markdown"],
|
||||||
|
user=user,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Index Markdown Content for Search
|
||||||
|
all_files = fs_syncer.collect_files(user=user)
|
||||||
|
state.content_index, _ = configure_content(
|
||||||
|
state.content_index, state.config.content_type, all_files, state.search_models, user=user
|
||||||
|
)
|
||||||
|
|
||||||
# Initialize Processor from Config
|
# Initialize Processor from Config
|
||||||
if os.getenv("OPENAI_API_KEY"):
|
if os.getenv("OPENAI_API_KEY"):
|
||||||
chat_model = ChatModelOptionsFactory(chat_model="gpt-3.5-turbo", model_type="openai")
|
chat_model = ChatModelOptionsFactory(chat_model="gpt-3.5-turbo", model_type="openai")
|
||||||
OpenAIProcessorConversationConfigFactory()
|
OpenAIProcessorConversationConfigFactory()
|
||||||
UserConversationProcessorConfigFactory(user=default_user2, setting=chat_model)
|
UserConversationProcessorConfigFactory(user=user, setting=chat_model)
|
||||||
|
|
||||||
state.anonymous_mode = True
|
state.anonymous_mode = not require_auth
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
import os
|
import os
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
import factory
|
import factory
|
||||||
|
from django.utils.timezone import make_aware
|
||||||
|
|
||||||
from khoj.database.models import (
|
from khoj.database.models import (
|
||||||
ChatModelOptions,
|
ChatModelOptions,
|
||||||
|
@ -90,4 +92,4 @@ class SubscriptionFactory(factory.django.DjangoModelFactory):
|
||||||
user = factory.SubFactory(UserFactory)
|
user = factory.SubFactory(UserFactory)
|
||||||
type = "standard"
|
type = "standard"
|
||||||
is_recurring = False
|
is_recurring = False
|
||||||
renewal_date = "2100-04-01"
|
renewal_date = make_aware(datetime.strptime("2100-04-01", "%Y-%m-%d"))
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
# Standard Modules
|
# Standard Modules
|
||||||
|
import os
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
|
|
||||||
|
@ -482,6 +483,21 @@ def test_user_no_data_returns_empty(client, sample_org_data, api_user3: KhojApiU
|
||||||
assert response.json() == []
|
assert response.json() == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(os.getenv("OPENAI_API_KEY") is None, reason="requires OPENAI_API_KEY")
|
||||||
|
@pytest.mark.django_db(transaction=True)
|
||||||
|
def test_chat_with_unauthenticated_user(chat_client_with_auth, api_user2: KhojApiUser):
|
||||||
|
# Arrange
|
||||||
|
headers = {"Authorization": f"Bearer {api_user2.token}"}
|
||||||
|
|
||||||
|
# Act
|
||||||
|
auth_response = chat_client_with_auth.get(f'/api/chat?q="Hello!"&stream=true', headers=headers)
|
||||||
|
no_auth_response = chat_client_with_auth.get(f'/api/chat?q="Hello!"&stream=true')
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert auth_response.status_code == 200
|
||||||
|
assert no_auth_response.status_code == 403
|
||||||
|
|
||||||
|
|
||||||
def get_sample_files_data():
|
def get_sample_files_data():
|
||||||
return [
|
return [
|
||||||
("files", ("path/to/filename.org", "* practicing piano", "text/org")),
|
("files", ("path/to/filename.org", "* practicing piano", "text/org")),
|
||||||
|
|
|
@ -53,6 +53,7 @@ def test_chat_with_no_chat_history_or_retrieved_content(chat_client):
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
|
@pytest.mark.skipif(os.getenv("SERPER_DEV_API_KEY") is None, reason="requires SERPER_DEV_API_KEY")
|
||||||
@pytest.mark.chatquality
|
@pytest.mark.chatquality
|
||||||
@pytest.mark.django_db(transaction=True)
|
@pytest.mark.django_db(transaction=True)
|
||||||
def test_chat_with_online_content(chat_client):
|
def test_chat_with_online_content(chat_client):
|
||||||
|
@ -65,7 +66,7 @@ def test_chat_with_online_content(chat_client):
|
||||||
response_message = response_message.split("### compiled references")[0]
|
response_message = response_message.split("### compiled references")[0]
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
expected_responses = ["http://www.paulgraham.com/greatwork.html", "Please set your SERPER_DEV_API_KEY"]
|
expected_responses = ["http://www.paulgraham.com/greatwork.html"]
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert any([expected_response in response_message for expected_response in expected_responses]), (
|
assert any([expected_response in response_message for expected_response in expected_responses]), (
|
||||||
"Expected links or serper not setup in response but got: " + response_message
|
"Expected links or serper not setup in response but got: " + response_message
|
||||||
|
|
Loading…
Reference in a new issue