[Multi-User Part 4]: Authenticate using API Tokens (#513)

###  New
- Use API keys to authenticate from Desktop, Obsidian, Emacs clients
- Create API, UI on web app config page to CRUD API Keys
- Create user API keys table and functions to CRUD them in Database

### 🧪 Improve
- Default to better search model, [gte-small](https://huggingface.co/thenlper/gte-small), to improve search quality
- Only load chat model to GPU if enough space, throw error on load failure
- Show encoding progress, truncate headings to max chars supported
- Add instruction to create db in Django DB setup Readme

### ⚙️ Fix
- Fix error handling when configuring offline chat via Web UI
- Do not warn in anon mode about Google OAuth env vars not being set
- Fix path to load static files when server started from project root
This commit is contained in:
Debanjum 2023-10-26 12:33:03 -07:00 committed by GitHub
parent 4b6ec248a6
commit 9acc722f7f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
36 changed files with 692 additions and 564 deletions

View file

@ -37,6 +37,12 @@ make install # may need sudo
```
3. Create a database
### Create the khoj database
```bash
createdb khoj -U postgres
```
### Make migrations
This command will create the migrations for the database app. This command should be run whenever a new model is added to the database app or an existing model is modified (updated or deleted).

View file

@ -14,7 +14,7 @@ from pathlib import Path
import os
# Build paths inside the project like this: BASE_DIR / 'subdir'.
BASE_DIR = Path(__file__).resolve().parent.parent
BASE_DIR = Path(__file__).resolve().parent.parent.parent
# Quick-start development settings - unsuitable for production
@ -123,8 +123,8 @@ USE_TZ = True
# Static files (CSS, JavaScript, Images)
# https://docs.djangoproject.com/en/4.2/howto/static-files/
STATIC_ROOT = os.path.join(BASE_DIR, "static")
STATICFILES_DIRS = [os.path.join(BASE_DIR, "khoj/interface/web")]
STATIC_ROOT = BASE_DIR / "static"
STATICFILES_DIRS = [BASE_DIR / "src/khoj/interface/web"]
STATIC_URL = "/static/"
# Default primary key field type

View file

@ -15,7 +15,7 @@ Including another URLconf
2. Add a URL to urlpatterns: path('blog/', include('blog.urls'))
"""
from django.contrib import admin
from django.urls import path, include
from django.urls import path
from django.contrib.staticfiles.urls import staticfiles_urlpatterns
urlpatterns = [

View file

@ -1,3 +1,4 @@
import secrets
from typing import Type, TypeVar, List
from datetime import date
@ -16,6 +17,7 @@ from fastapi import HTTPException
from database.models import (
KhojUser,
GoogleUser,
KhojApiUser,
NotionConfig,
GithubConfig,
Embeddings,
@ -25,6 +27,7 @@ from database.models import (
OpenAIProcessorConversationConfig,
OfflineChatProcessorConversationConfig,
)
from khoj.utils.helpers import generate_random_name
from khoj.utils.rawconfig import (
ConversationProcessorConfig as UserConversationProcessorConfig,
)
@ -52,6 +55,25 @@ async def set_notion_config(token: str, user: KhojUser):
return notion_config
async def create_khoj_token(user: KhojUser, name=None):
    """Create and persist a new Khoj API key for the given user.

    Args:
        user: The user the API key is issued to.
        name: Optional display name for the key. When falsy, a name is
            built from `generate_random_name()`.

    Returns:
        The newly created `KhojApiUser` row.
    """
    token = f"kk-{secrets.token_urlsafe(32)}"
    name = name or f"{generate_random_name().title()}'s Secret Key"
    # `acreate` already saves the new row, so a follow-up `asave()` would
    # only issue a second, redundant database write.
    return await KhojApiUser.objects.acreate(token=token, user=user, name=name)
def get_khoj_tokens(user: KhojUser):
    """Return every Khoj API key issued to the given user as a list."""
    user_tokens = KhojApiUser.objects.filter(user=user)
    return list(user_tokens)
async def delete_khoj_token(user: KhojUser, token: str):
    """Delete the given user's Khoj API key that matches the token."""
    matching_keys = KhojApiUser.objects.filter(user=user, token=token)
    await matching_keys.adelete()
async def get_or_create_user(token: dict) -> KhojUser:
user = await get_user_by_token(token)
if not user:

View file

@ -0,0 +1,24 @@
# Generated by Django 4.2.5 on 2023-10-26 17:02
from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion
# Schema migration that creates the table backing the KhojApiUser model.
class Migration(migrations.Migration):
# Must run after the previous migration of the "database" app.
dependencies = [
("database", "0008_alter_conversation_conversation_log"),
]
operations = [
migrations.CreateModel(
name="KhojApiUser",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
# Token values are unique across all rows and at most 50 characters long.
("token", models.CharField(max_length=50, unique=True)),
("name", models.CharField(max_length=50)),
# Nullable timestamp; unset (NULL) by default.
("accessed_at", models.DateTimeField(default=None, null=True)),
# Rows are deleted together with the user they reference.
("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
],
),
]

View file

@ -37,6 +37,15 @@ class GoogleUser(models.Model):
return self.name
class KhojApiUser(models.Model):
"""User issued API tokens to authenticate Khoj clients"""
# Owner of the token; the token row is deleted when the user is deleted.
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
# The token string itself; unique across all users, at most 50 characters.
token = models.CharField(max_length=50, unique=True)
# Display name of the token, at most 50 characters.
name = models.CharField(max_length=50)
# Nullable timestamp, NULL by default.
# NOTE(review): presumably the last time the token was used - confirm against the auth code.
accessed_at = models.DateTimeField(null=True, default=None)
class NotionConfig(BaseModel):
token = models.CharField(max_length=200)
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.6 KiB

View file

@ -0,0 +1,4 @@
<?xml version="1.0" encoding="utf-8"?>
<svg width="800px" height="800px" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path fill-rule="evenodd" clip-rule="evenodd" d="M22 8.29344C22 11.7692 19.1708 14.5869 15.6807 14.5869C15.0439 14.5869 13.5939 14.4405 12.8885 13.8551L12.0067 14.7333C11.4883 15.2496 11.6283 15.4016 11.8589 15.652C11.9551 15.7565 12.0672 15.8781 12.1537 16.0505C12.1537 16.0505 12.8885 17.075 12.1537 18.0995C11.7128 18.6849 10.4783 19.5045 9.06754 18.0995L8.77362 18.3922C8.77362 18.3922 9.65538 19.4167 8.92058 20.4412C8.4797 21.0267 7.30403 21.6121 6.27531 20.5876L5.2466 21.6121C4.54119 22.3146 3.67905 21.9048 3.33616 21.6121L2.45441 20.7339C1.63143 19.9143 2.1115 19.0264 2.45441 18.6849L10.0963 11.0743C10.0963 11.0743 9.3615 9.90338 9.3615 8.29344C9.3615 4.81767 12.1907 2 15.6807 2C19.1708 2 22 4.81767 22 8.29344ZM15.681 10.4889C16.8984 10.4889 17.8853 9.50601 17.8853 8.29353C17.8853 7.08105 16.8984 6.09814 15.681 6.09814C14.4635 6.09814 13.4766 7.08105 13.4766 8.29353C13.4766 9.50601 14.4635 10.4889 15.681 10.4889Z" fill="#1C274C"/>
</svg>

After

Width:  |  Height:  |  Size: 1.1 KiB

View file

@ -1,5 +1,4 @@
<?xml version="1.0" encoding="utf-8"?><!-- Uploaded to: SVG Repo, www.svgrepo.com, Generator: SVG Repo Mixer Tools -->
<?xml version="1.0" encoding="utf-8"?>
<svg width="800px" height="800px" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M15.197 3.35462C16.8703 1.67483 19.4476 1.53865 20.9536 3.05046C22.4596 4.56228 22.3239 7.14956 20.6506 8.82935L18.2268 11.2626M10.0464 14C8.54044 12.4882 8.67609 9.90087 10.3494 8.22108L12.5 6.06212" stroke="#1C274C" stroke-width="1.5" stroke-linecap="round"/>
<path d="M13.9536 10C15.4596 11.5118 15.3239 14.0991 13.6506 15.7789L11.2268 18.2121L8.80299 20.6454C7.12969 22.3252 4.55237 22.4613 3.0464 20.9495C1.54043 19.4377 1.67609 16.8504 3.34939 15.1706L5.77323 12.7373" stroke="#1C274C" stroke-width="1.5" stroke-linecap="round"/>
<path d="M9.16488 17.6505C8.92513 17.8743 8.73958 18.0241 8.54996 18.1336C7.62175 18.6695 6.47816 18.6695 5.54996 18.1336C5.20791 17.9361 4.87912 17.6073 4.22153 16.9498C3.56394 16.2922 3.23514 15.9634 3.03767 15.6213C2.50177 14.6931 2.50177 13.5495 3.03767 12.6213C3.23514 12.2793 3.56394 11.9505 4.22153 11.2929L7.04996 8.46448C7.70755 7.80689 8.03634 7.47809 8.37838 7.28062C9.30659 6.74472 10.4502 6.74472 11.3784 7.28061C11.7204 7.47809 12.0492 7.80689 12.7068 8.46448C13.3644 9.12207 13.6932 9.45086 13.8907 9.7929C14.4266 10.7211 14.4266 11.8647 13.8907 12.7929C13.7812 12.9825 13.6314 13.1681 13.4075 13.4078M10.5919 10.5922C10.368 10.8319 10.2182 11.0175 10.1087 11.2071C9.57284 12.1353 9.57284 13.2789 10.1087 14.2071C10.3062 14.5492 10.635 14.878 11.2926 15.5355C11.9502 16.1931 12.279 16.5219 12.621 16.7194C13.5492 17.2553 14.6928 17.2553 15.621 16.7194C15.9631 16.5219 16.2919 16.1931 16.9495 15.5355L19.7779 12.7071C20.4355 12.0495 20.7643 11.7207 20.9617 11.3787C21.4976 10.4505 21.4976 9.30689 20.9617 8.37869C20.7643 8.03665 20.4355 7.70785 19.7779 7.05026C19.1203 6.39267 18.7915 6.06388 18.4495 5.8664C17.5212 5.3305 16.3777 5.3305 15.4495 5.8664C15.2598 5.97588 15.0743 6.12571 14.8345 6.34955" stroke="#000000" stroke-width="2" stroke-linecap="round"/>
</svg>

Before

Width:  |  Height:  |  Size: 777 B

After

Width:  |  Height:  |  Size: 1.4 KiB

View file

@ -89,6 +89,8 @@
// Generate backend API URL to execute query
let url = `${hostURL}/api/chat?q=${encodeURIComponent(query)}&n=${resultsCount}&client=web&stream=true`;
const khojToken = await window.tokenAPI.getToken();
const headers = { 'Authorization': `Bearer ${khojToken}` };
let chat_body = document.getElementById("chat-body");
let new_response = document.createElement("div");
@ -113,7 +115,7 @@
chatInput.classList.remove("option-enabled");
// Call specified Khoj API which returns a streamed response of type text/plain
fetch(url)
fetch(url, { headers })
.then(response => {
const reader = response.body.getReader();
const decoder = new TextDecoder();
@ -217,7 +219,10 @@
async function loadChat() {
const hostURL = await window.hostURLAPI.getURL();
fetch(`${hostURL}/api/chat/history?client=web`)
const khojToken = await window.tokenAPI.getToken();
const headers = { 'Authorization': `Bearer ${khojToken}` };
fetch(`${hostURL}/api/chat/history?client=web`, { headers })
.then(response => response.json())
.then(data => {
if (data.detail) {
@ -243,7 +248,7 @@
return;
});
fetch(`${hostURL}/api/chat/options`)
fetch(`${hostURL}/api/chat/options`, { headers })
.then(response => response.json())
.then(data => {
// Render chat options, if any
@ -272,9 +277,9 @@
<img class="khoj-logo" src="./assets/icons/khoj-logo-sideways-500.png" alt="Khoj"></img>
</a>
<nav class="khoj-nav">
<a class="khoj-nav khoj-nav-selected" href="./chat.html">Chat</a>
<a class="khoj-nav" href="./index.html">Search</a>
<a class="khoj-nav" href="./config.html">⚙️</a>
<a class="khoj-nav khoj-nav-selected" href="./chat.html">💬 Chat</a>
<a class="khoj-nav" href="./index.html">🔎 Search</a>
<a class="khoj-nav" href="./config.html">⚙️ Settings</a>
</nav>
</div>

View file

@ -12,66 +12,85 @@
<script type="text/javascript" src="./assets/markdown-it.min.js"></script>
<body>
<div class="page">
<!--Add Header Logo and Nav Pane-->
<div class="khoj-header">
<a class="khoj-logo" href="./index.html">
<img class="khoj-logo" src="./assets/icons/khoj-logo-sideways-500.png" alt="Khoj"></img>
</a>
<nav class="khoj-nav">
<a class="khoj-nav" href="./chat.html">Chat</a>
<a class="khoj-nav" href="./index.html">Search</a>
<a class="khoj-nav khoj-nav-selected" href="./config.html">⚙️</a>
</nav>
</div>
<!--Add Header Logo and Nav Pane-->
<div class="khoj-header">
<a class="khoj-logo" href="./index.html">
<img class="khoj-logo" src="./assets/icons/khoj-logo-sideways-500.png" alt="Khoj"></img>
</a>
<nav class="khoj-nav">
<a class="khoj-nav" href="./chat.html">💬 Chat</a>
<a class="khoj-nav" href="./index.html">🔎 Search</a>
<a class="khoj-nav khoj-nav-selected" href="./config.html">⚙️ Settings</a>
</nav>
</div>
<div class="section-cards">
<div class="card configuration">
<div class="card-title-row">
<img class="card-icon" src="./assets/icons/link.svg" alt="File">
<h3 class="card-title">
Host
</h3>
<div class="card-description-row">
<div class="card configuration">
<div class="card-title-row">
<img class="card-icon" src="./assets/icons/link.svg" alt="Khoj Server URL">
<h3 class="card-title">
Server URL
</h3>
</div>
<div class="card-description-row">
<input id="khoj-host-url" class="card-input" type="text">
</div>
<div class="card-title-row">
<img class="card-icon" src="./assets/icons/key.svg" alt="Khoj Access Key">
<h3 class="card-title">
Access Key
</h3>
</div>
<div class="card-description-row">
<input id="khoj-access-key" class="card-input" type="text" placeholder="Enter key to access your Khoj">
</div>
</div>
<div class="card-description-row">
<input id="khoj-host-url" class="card-input" type="text">
</div>
<div class="card-title-row">
<img class="card-icon" src="./assets/icons/plaintext.svg" alt="File">
<h3 class="card-title">
Files
<button id="toggle-files" class="card-button">
<svg id="toggle-files-svg" xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M12 5v14M5 12l7 7 7-7"></path></svg>
</div>
<div class="card-description-row">
<div class="card configuration">
<div class="card-title-row">
<img class="card-icon" src="./assets/icons/plaintext.svg" alt="File">
<h3 class="card-title">
Files
<button id="toggle-files" class="card-button">
<svg id="toggle-files-svg" xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M12 5v14M5 12l7 7 7-7"></path></svg>
</button>
</h3>
</div>
<div class="card-description-row">
<div id="current-files"></div>
</div>
<div class="card-action-row">
<button id="update-file" class="card-button">
Add
<img class="add-files-icon" src="./assets/icons/circular-add.svg" alt="Add">
</button>
</h3>
</div>
</div>
<div class="card-description-row">
<div id="current-files"></div>
</div>
<div class="card-action-row">
<button id="update-file" class="card-button">
Add
<img class="add-files-icon" src="./assets/icons/circular-add.svg" alt="Add">
</button>
</div>
<div class="card-title-row">
<img class="card-icon" src="./assets/icons/folder.svg" alt="Folder">
<h3 class="card-title">
Folders
<button id="toggle-folders" class="card-button">
<svg id="toggle-folders-svg" xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M12 5v14M5 12l7 7 7-7"></path></svg>
</div>
<div class="card-description-row">
<div class="card configuration">
<div class="card-title-row">
<img class="card-icon" src="./assets/icons/folder.svg" alt="Folder">
<h3 class="card-title">
Folders
<button id="toggle-folders" class="card-button">
<svg id="toggle-folders-svg" xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M12 5v14M5 12l7 7 7-7"></path></svg>
</button>
</h3>
</div>
<div class="card-description-row">
<div id="current-folders"></div>
</div>
<div class="card-action-row">
<button id="update-folder" class="card-button">
Add
<img class="add-files-icon" src="./assets/icons/circular-add.svg" alt="Add">
</button>
</h3>
</div>
<div class="card-description-row">
<div id="current-folders"></div>
</div>
<div class="card-action-row">
<button id="update-folder" class="card-button">
Add
<img class="add-files-icon" src="./assets/icons/circular-add.svg" alt="Add">
</button>
</div>
</div>
</div>
<div class="section-action-row">
<div class="card-description-row">
<button id="sync-data">Sync</button>
</div>
@ -79,11 +98,10 @@
<input id="sync-force" type="checkbox" name="sync-force" value="force">
<label for="sync-force">Force Sync</label>
</div>
<div id="loading-bar" style="display: none;">
</div>
<div class="card-description-row">
<div id="sync-status"></div>
</div>
</div>
<div id="loading-bar" style="display: none;"></div>
<div class="card-description-row">
<div id="sync-status"></div>
</div>
</div>
</body>
@ -93,7 +111,7 @@
body {
display: grid;
grid-template-columns: 1fr;
grid-template-rows: 1fr auto auto auto minmax(80px, 100%);
grid-template-rows: 1fr auto;
font-size: small!important;
}
body > * {
@ -104,8 +122,7 @@
body {
display: grid;
grid-template-columns: 1fr min(70vw, 100%) 1fr;
grid-template-rows: 1fr auto auto auto minmax(80px, 100%);
padding-top: 60vw;
grid-template-rows: 80px auto;
}
body > * {
grid-column: 2;
@ -126,11 +143,6 @@
margin: 10px;
}
div.page {
padding: 0px;
margin: 0px;
}
svg {
transition: transform 0.3s ease-in-out;
}
@ -167,18 +179,18 @@
}
}
#khoj-host-url {
.card-input {
padding: 4px;
box-shadow: 0 0 2px 1px rgba(0, 0, 0, 0.2);
border: none;
width: 450px;
}
.card {
display: grid;
/* grid-template-rows: repeat(3, 1fr); */
gap: 8px;
padding: 24px 16px;
width: 100%;
width: 450px;
background: white;
border: 1px solid rgb(229, 229, 229);
border-radius: 4px;
@ -188,15 +200,15 @@
.section-cards {
display: grid;
grid-template-columns: repeat(1, 1fr);
gap: 16px;
justify-items: start;
justify-items: center;
margin: 0;
width: auto;
}
div.configuration {
width: auto;
.section-action-row {
display: grid;
grid-auto-flow: column;
gap: 16px;
height: fit-content;
}
.card-title-row {
@ -302,7 +314,6 @@
}
div.content-name {
width: 500px;
overflow-wrap: break-word;
}
@ -347,6 +358,12 @@
background-color: #ffcc00;
box-shadow: 0px 3px 0px #f9f5de;
}
.sync-force-toggle {
align-content: center;
display: grid;
grid-auto-flow: column;
gap: 4px;
}
</style>
<script>
var khojBannerSubmit = document.getElementById("khoj-banner-submit");

View file

@ -170,7 +170,10 @@
// Execute Search and Render Results
url = await createRequestUrl(query, type, results_count || 5, rerank);
fetch(url)
const khojToken = await window.tokenAPI.getToken();
const headers = { 'Authorization': `Bearer ${khojToken}` };
fetch(url, { headers })
.then(response => response.json())
.then(data => {
console.log(data);
@ -192,9 +195,11 @@
async function populate_type_dropdown() {
const hostURL = await window.hostURLAPI.getURL();
const khojToken = await window.tokenAPI.getToken();
const headers = { 'Authorization': `Bearer ${khojToken}` };
// Populate type dropdown field with enabled content types only
fetch(`${hostURL}/api/config/types`)
fetch(`${hostURL}/api/config/types`, { headers })
.then(response => response.json())
.then(enabled_types => {
// Show warning if no content types are enabled
@ -247,9 +252,9 @@
<img class="khoj-logo" src="./assets/icons/khoj-logo-sideways-500.png" alt="Khoj"></img>
</a>
<nav class="khoj-nav">
<a class="khoj-nav" href="./chat.html">Chat</a>
<a class="khoj-nav khoj-nav-selected" href="./index.html">Search</a>
<a class="khoj-nav" href="./config.html">⚙️</a>
<a class="khoj-nav" href="./chat.html">💬 Chat</a>
<a class="khoj-nav khoj-nav-selected" href="./index.html">🔎 Search</a>
<a class="khoj-nav" href="./config.html">⚙️ Settings</a>
</nav>
</div>

View file

@ -1,4 +1,4 @@
const { app, BrowserWindow, ipcMain } = require('electron');
const { app, BrowserWindow, ipcMain, Tray, Menu, nativeImage } = require('electron');
const todesktop = require("@todesktop/runtime");
todesktop.init();
@ -42,6 +42,10 @@ const schema = {
},
default: []
},
khojToken: {
type: 'string',
default: ''
},
hostURL: {
type: 'string',
default: KHOJ_URL
@ -63,7 +67,6 @@ const schema = {
};
var state = {}
const store = new Store({ schema });
console.log(store);
@ -168,7 +171,7 @@ function pushDataToKhoj (regenerate = false) {
if (!!formData?.entries()?.next().value) {
const hostURL = store.get('hostURL') || KHOJ_URL;
const headers = {
'x-api-key': 'secret'
'Authorization': `Bearer ${store.get("khojToken")}`
};
axios.post(`${hostURL}/api/v1/index/update?force=${regenerate}&client=desktop`, formData, { headers })
.then(response => {
@ -246,6 +249,15 @@ async function handleFileOpen (type) {
}
}
async function getToken () {
    // Read the saved Khoj API token from the persistent store
    const savedToken = store.get('khojToken');
    return savedToken;
}
async function setToken (event, token) {
    // Persist the token, then read it back so the caller gets the stored value
    store.set('khojToken', token);
    const savedToken = store.get('khojToken');
    return savedToken;
}
async function getFiles () {
return store.get('files');
}
@ -287,8 +299,9 @@ async function syncData (regenerate = false) {
}
}
const createWindow = () => {
const win = new BrowserWindow({
let win = null;
const createWindow = (tab = 'index.html') => {
win = new BrowserWindow({
width: 800,
height: 800,
// titleBarStyle: 'hidden',
@ -316,7 +329,7 @@ const createWindow = () => {
job.start();
win.loadFile('index.html')
win.loadFile(tab)
}
app.whenReady().then(() => {
@ -340,6 +353,9 @@ app.whenReady().then(() => {
ipcMain.handle('setURL', setURL);
ipcMain.handle('getURL', getURL);
ipcMain.handle('setToken', setToken);
ipcMain.handle('getToken', getToken);
ipcMain.handle('syncData', (event, regenerate) => {
syncData(regenerate);
});
@ -375,3 +391,33 @@ app.whenReady().then(() => {
app.on('window-all-closed', () => {
if (process.platform !== 'darwin') app.quit()
})
/*
** System Tray Icon
*/
let tray
// Show the given page in the app window, creating the window if none is open.
// Declared with `const` so it is not leaked as an implicit global
// (an undeclared assignment also throws in strict mode).
const openWindow = (page) => {
    if (BrowserWindow.getAllWindows().length === 0) {
        // No window exists yet: create one that loads the requested page
        createWindow(page);
    } else {
        // Reuse the existing window: navigate to the page and bring it up
        win.loadFile(page);
        win.show();
    }
}
app.whenReady().then(() => {
    // Resolve the icon relative to this file rather than the process working
    // directory, so the tray icon is found wherever the app is launched from
    const iconPath = require('path').join(__dirname, 'assets/icons/favicon-20x20.png');
    const icon = nativeImage.createFromPath(iconPath);
    tray = new Tray(icon);

    // Tray menu: one entry per app page, plus an entry to quit the app
    const contextMenu = Menu.buildFromTemplate([
        { label: 'Chat', type: 'normal', click: () => { openWindow('chat.html'); }},
        { label: 'Search', type: 'normal', click: () => { openWindow('index.html') }},
        { label: 'Configure', type: 'normal', click: () => { openWindow('config.html') }},
        { type: 'separator' },
        { label: 'Quit', type: 'normal', click: () => { app.quit() } }
    ]);

    tray.setToolTip('Khoj');
    tray.setContextMenu(contextMenu);
})

View file

@ -47,3 +47,8 @@ contextBridge.exposeInMainWorld('hostURLAPI', {
contextBridge.exposeInMainWorld('syncDataAPI', {
syncData: (regenerate) => ipcRenderer.invoke('syncData', regenerate)
})
// Expose `window.tokenAPI` to renderer pages. Each method forwards to the
// IPC channel of the same name and returns the promise from `ipcRenderer.invoke`.
contextBridge.exposeInMainWorld('tokenAPI', {
// Send the given token over the 'setToken' channel
setToken: (token) => ipcRenderer.invoke('setToken', token),
// Request the current token over the 'getToken' channel
getToken: () => ipcRenderer.invoke('getToken')
})

View file

@ -181,6 +181,17 @@ urlInput.addEventListener('blur', async () => {
urlInput.value = url;
});
const khojKeyInput = document.getElementById('khoj-access-key');
(async function() {
const token = await window.tokenAPI.getToken();
khojKeyInput.value = token;
})();
khojKeyInput.addEventListener('blur', async () => {
const token = await window.tokenAPI.setToken(khojKeyInput.value.trim());
khojKeyInput.value = token;
});
const syncButton = document.getElementById('sync-data');
const syncForceToggle = document.getElementById('sync-force');
syncButton.addEventListener('click', async () => {

View file

@ -93,7 +93,7 @@
:group 'khoj
:type 'number)
(defcustom khoj-server-api-key "secret"
(defcustom khoj-api-key nil
"API Key to Khoj server."
:group 'khoj
:type 'string)
@ -246,26 +246,6 @@ for example), set this to the full interpreter path."
:type '(repeat string)
:group 'khoj)
(defcustom khoj-chat-model "gpt-3.5-turbo"
"Specify chat model to use for chat with khoj."
:type 'string
:group 'khoj)
(defcustom khoj-openai-api-key nil
"OpenAI API key used to configure chat on khoj server."
:type 'string
:group 'khoj)
(defcustom khoj-chat-offline nil
"Use offline model to chat with khoj."
:type 'boolean
:group 'khoj)
(defcustom khoj-offline-chat-model nil
"Specify chat model to use for offline chat with khoj."
:type 'string
:group 'khoj)
(defcustom khoj-auto-setup t
"Automate install, configure and start of khoj server.
Auto invokes setup steps on calling main entrypoint."
@ -319,8 +299,7 @@ Auto invokes setup steps on calling main entrypoint."
:filter (lambda (process msg)
(cond ((string-match (format "Uvicorn running on %s" khoj-server-url) msg)
(progn
(setq khoj--server-ready? t)
(khoj--server-configure)))
(setq khoj--server-ready? t)))
((string-match "Batches: " msg)
(when (string-match "\\([0-9]+\\.[0-9]+\\|\\([0-9]+\\)\\)%?" msg)
(message "khoj.el: %s updating index %s"
@ -383,106 +362,13 @@ Auto invokes setup steps on calling main entrypoint."
(when (not (khoj--server-started?))
(khoj--server-start)))
(defun khoj--get-directory-from-config (config keys &optional level)
"Extract directory under specified KEYS in CONFIG and trim it to LEVEL.
CONFIG is json obtained from Khoj config API."
(let ((item config))
(dolist (key keys)
(setq item (cdr (assoc key item))))
(-> item
(split-string "/")
(butlast (or level nil))
(string-join "/"))))
(defun khoj--server-configure ()
"Configure the Khoj server for search and chat."
(interactive)
(let* ((url-request-method "GET")
(current-config
(with-temp-buffer
(url-insert-file-contents (format "%s/api/config/data" khoj-server-url))
(ignore-error json-end-of-file (json-parse-buffer :object-type 'alist :array-type 'list :null-object json-null :false-object json-false))))
(default-config
(with-temp-buffer
(url-insert-file-contents (format "%s/api/config/data/default" khoj-server-url))
(ignore-error json-end-of-file (json-parse-buffer :object-type 'alist :array-type 'list :null-object json-null :false-object json-false))))
(default-chat-dir (khoj--get-directory-from-config default-config '(processor conversation conversation-logfile)))
(chat-model (or khoj-chat-model (alist-get 'chat-model (alist-get 'openai (alist-get 'conversation (alist-get 'processor default-config))))))
(enable-offline-chat (or khoj-chat-offline (alist-get 'enable-offline-chat (alist-get 'offline-chat (alist-get 'conversation (alist-get 'processor default-config))))))
(offline-chat-model (or khoj-offline-chat-model (alist-get 'chat-model (alist-get 'offline-chat (alist-get 'conversation (alist-get 'processor default-config))))))
(config (or current-config default-config)))
;; Configure processors
(cond
((not khoj-openai-api-key)
(let* ((processor (assoc 'processor config))
(conversation (assoc 'conversation processor))
(openai (assoc 'openai conversation)))
(when openai
;; Unset the `openai' field in the khoj conversation processor config
(message "khoj.el: Disable Chat using OpenAI as your OpenAI API key got removed from config")
(setcdr conversation (delq openai (cdr conversation)))
(push conversation (cdr processor))
(push processor config))))
;; If khoj backend isn't configured yet
((not current-config)
(message "khoj.el: Khoj not configured yet.")
(setq config (delq (assoc 'processor config) config))
(cl-pushnew `(processor . ((conversation . ((conversation-logfile . ,(format "%s/conversation.json" default-chat-dir))
(offline-chat . ((enable-offline-chat . ,enable-offline-chat)
(chat-model . ,offline-chat-model)))
(openai . ((chat-model . ,chat-model)
(api-key . ,khoj-openai-api-key)))))))
config))
;; Else if chat isn't configured in khoj backend
((not (alist-get 'conversation (alist-get 'processor config)))
(message "khoj.el: Chat not configured yet.")
(let ((new-processor-type (alist-get 'processor config)))
(setq new-processor-type (delq (assoc 'conversation new-processor-type) new-processor-type))
(cl-pushnew `(conversation . ((conversation-logfile . ,(format "%s/conversation.json" default-chat-dir))
(offline-chat . ((enable-offline-chat . ,enable-offline-chat)
(chat-model . ,offline-chat-model)))
(openai . ((chat-model . ,chat-model)
(api-key . ,khoj-openai-api-key)))))
new-processor-type)
(setq config (delq (assoc 'processor config) config))
(cl-pushnew `(processor . ,new-processor-type) config)))
;; Else if chat configuration in khoj backend has gone stale
((not (and (equal (alist-get 'api-key (alist-get 'openai (alist-get 'conversation (alist-get 'processor config)))) khoj-openai-api-key)
(equal (alist-get 'chat-model (alist-get 'openai (alist-get 'conversation (alist-get 'processor config)))) khoj-chat-model)
(equal (alist-get 'enable-offline-chat (alist-get 'offline-chat (alist-get 'conversation (alist-get 'processor config)))) enable-offline-chat)
(equal (alist-get 'chat-model (alist-get 'offline-chat (alist-get 'conversation (alist-get 'processor config)))) offline-chat-model)))
(message "khoj.el: Chat configuration has gone stale.")
(let* ((chat-directory (khoj--get-directory-from-config config '(processor conversation conversation-logfile)))
(new-processor-type (alist-get 'processor config)))
(setq new-processor-type (delq (assoc 'conversation new-processor-type) new-processor-type))
(cl-pushnew `(conversation . ((conversation-logfile . ,(format "%s/conversation.json" chat-directory))
(offline-chat . ((enable-offline-chat . ,enable-offline-chat)
(chat-model . ,offline-chat-model)))
(openai . ((chat-model . ,khoj-chat-model)
(api-key . ,khoj-openai-api-key)))))
new-processor-type)
(setq config (delq (assoc 'processor config) config))
(cl-pushnew `(processor . ,new-processor-type) config))))
;; Update server with latest configuration, if required
(cond ((not current-config)
(khoj--post-new-config config)
(message "khoj.el: ⚙️ Generated new khoj server configuration."))
((not (equal config current-config))
(khoj--post-new-config config)
(message "khoj.el: ⚙️ Updated khoj server configuration.")))))
(defun khoj-setup (&optional interact)
"Install, start and configure Khoj server. Get permission if INTERACT is non-nil."
"Install and start Khoj server. Get permission if INTERACT is non-nil."
(interactive "p")
;; Setup khoj server if not running
(let* ((not-started (not (khoj--server-started?)))
(permitted (if (and not-started interact)
(y-or-n-p "Could not connect to Khoj server. Should I install, start and configure it for you?")
(y-or-n-p "Could not connect to Khoj server. Should I install, start it for you?")
t)))
;; If user permits setup of khoj server from khoj.el
(when permitted
@ -491,12 +377,9 @@ CONFIG is json obtained from Khoj config API."
(khoj--server-setup))
;; Wait until server is ready
;; As server can be started but not ready to use/configure
;; As server can be started but not ready to use
(while (not khoj--server-ready?)
(sit-for 0.5))
;; Configure server once it's ready
(khoj--server-configure))))
(sit-for 0.5)))))
;; -------------------
@ -516,7 +399,7 @@ CONFIG is json obtained from Khoj config API."
(let ((url-request-method "POST")
(url-request-data (khoj--render-files-as-request-body files-to-index khoj--indexed-files boundary))
(url-request-extra-headers `(("content-type" . ,(format "multipart/form-data; boundary=%s" boundary))
("x-api-key" . ,khoj-server-api-key))))
("Authorization" . ,(format "Bearer %s" khoj-api-key)))))
(with-current-buffer
(url-retrieve (format "%s/api/v1/index/update?%s&force=%s&client=emacs" khoj-server-url type-query (or force "false"))
;; render response from indexing API endpoint on server
@ -690,19 +573,22 @@ Use `BOUNDARY' to separate files. This is sent to Khoj server as a POST request.
"Configure khoj server with provided CONFIG."
;; POST provided config to khoj server
(let ((url-request-method "POST")
(url-request-extra-headers '(("Content-Type" . "application/json")))
(url-request-extra-headers `(("Content-Type" . "application/json")
("Authorization" . ,(format "Bearer %s" khoj-api-key))))
(url-request-data (encode-coding-string (json-encode-alist config) 'utf-8))
(config-url (format "%s/api/config/data" khoj-server-url)))
(with-current-buffer (url-retrieve-synchronously config-url)
(buffer-string)))
;; Update index on khoj server after configuration update
(let ((khoj--server-ready? nil))
(let ((khoj--server-ready? nil)
      ;; Send the bare `Bearer <key>' value. Wrapping it in literal double
      ;; quotes makes this header differ from the one every other request
      ;; in this file sends.
      (url-request-extra-headers `(("Authorization" . ,(format "Bearer %s" khoj-api-key)))))
  ;; Trigger an index update on the server; the response is ignored
  (url-retrieve (format "%s/api/update?client=emacs" khoj-server-url) #'identity)))
(defun khoj--get-enabled-content-types ()
"Get content types enabled for search from API."
(let ((config-url (format "%s/api/config/types" khoj-server-url))
(url-request-method "GET"))
(url-request-method "GET")
(url-request-extra-headers `(("Authorization" . ,(format "Bearer %s" khoj-api-key)))))
(with-temp-buffer
(url-insert-file-contents config-url)
(thread-last
@ -722,7 +608,8 @@ Render results in BUFFER-NAME using QUERY, CONTENT-TYPE."
;; get json response from api
(with-current-buffer buffer-name
(let ((inhibit-read-only t)
(url-request-method "GET"))
(url-request-method "GET")
(url-request-extra-headers `(("Authorization" . ,(format "Bearer %s" khoj-api-key)))))
(erase-buffer)
(url-insert-file-contents query-url)))
;; render json response into formatted entries
@ -848,6 +735,7 @@ Render results in BUFFER-NAME using QUERY, CONTENT-TYPE."
"Send QUERY to Khoj Chat API."
(let* ((url-request-method "GET")
(encoded-query (url-hexify-string query))
(url-request-extra-headers `(("Authorization" . ,(format "Bearer %s" khoj-api-key))))
(query-url (format "%s/api/chat?q=%s&n=%s&client=emacs" khoj-server-url encoded-query khoj-results-count)))
(with-temp-buffer
(condition-case ex
@ -862,6 +750,7 @@ Render results in BUFFER-NAME using QUERY, CONTENT-TYPE."
(defun khoj--get-chat-history-api ()
"Send QUERY to Khoj Chat History API."
(let* ((url-request-method "GET")
(url-request-extra-headers `(("Authorization" . ,(format "Bearer %s" khoj-api-key))))
(query-url (format "%s/api/chat/history?client=emacs" khoj-server-url)))
(with-temp-buffer
(condition-case ex

View file

@ -141,7 +141,8 @@ export class KhojChatModal extends Modal {
async getChatHistory(): Promise<void> {
// Get chat history from Khoj backend
let chatUrl = `${this.setting.khojUrl}/api/chat/history?client=obsidian`;
let response = await request(chatUrl);
let headers = { "Authorization": `Bearer ${this.setting.khojApiKey}` };
let response = await request({ url: chatUrl, headers: headers });
let chatLogs = JSON.parse(response).response;
chatLogs.forEach((chatLog: any) => {
this.renderMessageWithReferences(chatLog.message, chatLog.by, chatLog.context, new Date(chatLog.created));
@ -167,7 +168,8 @@ export class KhojChatModal extends Modal {
method: "GET",
headers: {
"Access-Control-Allow-Origin": "*",
"Content-Type": "text/event-stream"
"Content-Type": "text/event-stream",
"Authorization": `Bearer ${this.setting.khojApiKey}`,
},
})

View file

@ -1,8 +1,8 @@
import { Notice, Plugin, TFile } from 'obsidian';
import { Notice, Plugin } from 'obsidian';
import { KhojSetting, KhojSettingTab, DEFAULT_SETTINGS } from 'src/settings'
import { KhojSearchModal } from 'src/search_modal'
import { KhojChatModal } from 'src/chat_modal'
import { configureKhojBackend, updateContentIndex } from './utils';
import { updateContentIndex } from './utils';
export default class Khoj extends Plugin {
@ -39,9 +39,9 @@ export default class Khoj extends Plugin {
id: 'chat',
name: 'Chat',
checkCallback: (checking) => {
if (!checking && this.settings.connectedToBackend && (!!this.settings.openaiApiKey || this.settings.enableOfflineChat))
if (!checking && this.settings.connectedToBackend)
new KhojChatModal(this.app, this.settings).open();
return !!this.settings.openaiApiKey || this.settings.enableOfflineChat;
return this.settings.connectedToBackend;
}
});
@ -69,17 +69,9 @@ export default class Khoj extends Plugin {
async loadSettings() {
// Load khoj obsidian plugin settings
this.settings = Object.assign({}, DEFAULT_SETTINGS, await this.loadData());
if (this.settings.autoConfigure) {
// Load, configure khoj server settings
await configureKhojBackend(this.app.vault, this.settings);
}
}
async saveSettings() {
if (this.settings.autoConfigure) {
await configureKhojBackend(this.app.vault, this.settings, false);
}
this.saveData(this.settings);
}

View file

@ -90,10 +90,11 @@ export class KhojSearchModal extends SuggestModal<SearchResult> {
// Query Khoj backend for search results
let encodedQuery = encodeURIComponent(query);
let searchUrl = `${this.setting.khojUrl}/api/search?q=${encodedQuery}&n=${this.setting.resultsCount}&r=${this.rerank}&client=obsidian`;
let headers = { 'Authorization': `Bearer ${this.setting.khojApiKey}` }
// Get search results for markdown and pdf files
let mdResponse = await request(`${searchUrl}&t=markdown`);
let pdfResponse = await request(`${searchUrl}&t=pdf`);
let mdResponse = await request({ url: `${searchUrl}&t=markdown`, headers: headers });
let pdfResponse = await request({ url: `${searchUrl}&t=pdf`, headers: headers });
// Parse search results
let mdData = JSON.parse(mdResponse)

View file

@ -3,22 +3,20 @@ import Khoj from 'src/main';
import { updateContentIndex } from './utils';
export interface KhojSetting {
enableOfflineChat: boolean;
openaiApiKey: string;
resultsCount: number;
khojUrl: string;
khojApiKey: string;
connectedToBackend: boolean;
autoConfigure: boolean;
lastSyncedFiles: TFile[];
}
export const DEFAULT_SETTINGS: KhojSetting = {
enableOfflineChat: false,
resultsCount: 6,
khojUrl: 'http://127.0.0.1:42110',
khojApiKey: '',
connectedToBackend: false,
autoConfigure: true,
openaiApiKey: '',
lastSyncedFiles: []
}
@ -49,21 +47,12 @@ export class KhojSettingTab extends PluginSettingTab {
containerEl.firstElementChild?.setText(this.getBackendStatusMessage());
}));
new Setting(containerEl)
.setName('OpenAI API Key')
.setDesc('Use OpenAI for Khoj Chat with your API key.')
.setName('Khoj API Key')
.setDesc('Use Khoj Cloud with your Khoj API Key')
.addText(text => text
.setValue(`${this.plugin.settings.openaiApiKey}`)
.setValue(`${this.plugin.settings.khojApiKey}`)
.onChange(async (value) => {
this.plugin.settings.openaiApiKey = value.trim();
await this.plugin.saveSettings();
}));
new Setting(containerEl)
.setName('Enable Offline Chat')
.setDesc('Chat privately without an internet connection. Enabling this will use offline chat even if OpenAI is configured.')
.addToggle(toggle => toggle
.setValue(this.plugin.settings.enableOfflineChat)
.onChange(async (value) => {
this.plugin.settings.enableOfflineChat = value;
this.plugin.settings.khojApiKey = value.trim();
await this.plugin.saveSettings();
}));
new Setting(containerEl)
@ -78,8 +67,8 @@ export class KhojSettingTab extends PluginSettingTab {
await this.plugin.saveSettings();
}));
new Setting(containerEl)
.setName('Auto Configure')
.setDesc('Automatically configure the Khoj backend.')
.setName('Auto Sync')
.setDesc('Automatically index your vault with Khoj.')
.addToggle(toggle => toggle
.setValue(this.plugin.settings.autoConfigure)
.onChange(async (value) => {
@ -88,7 +77,7 @@ export class KhojSettingTab extends PluginSettingTab {
}));
let indexVaultSetting = new Setting(containerEl);
indexVaultSetting
.setName('Index Vault')
.setName('Force Sync')
.setDesc('Manually force Khoj to re-index your Obsidian Vault.')
.addButton(button => button
.setButtonText('Update')

View file

@ -1,4 +1,4 @@
import { FileSystemAdapter, Notice, RequestUrlParam, request, Vault, Modal, TFile } from 'obsidian';
import { FileSystemAdapter, Notice, Vault, Modal, TFile } from 'obsidian';
import { KhojSetting } from 'src/settings'
export function getVaultAbsolutePath(vault: Vault): string {
@ -9,26 +9,6 @@ export function getVaultAbsolutePath(vault: Vault): string {
return '';
}
type OpenAIType = null | {
"chat-model": string;
"api-key": string;
};
type OfflineChatType = null | {
"chat-model": string;
"enable-offline-chat": boolean;
};
interface ProcessorData {
conversation: {
"conversation-logfile": string;
openai: OpenAIType;
"offline-chat": OfflineChatType;
"tokenizer": null | string;
"max-prompt-size": null | number;
};
}
function fileExtensionToMimeType (extension: string): string {
switch (extension) {
case 'pdf':
@ -78,7 +58,7 @@ export async function updateContentIndex(vault: Vault, setting: KhojSetting, las
const response = await fetch(`${setting.khojUrl}/api/v1/index/update?force=${regenerate}&client=obsidian`, {
method: 'POST',
headers: {
'x-api-key': 'secret',
'Authorization': `Bearer ${setting.khojApiKey}`,
},
body: formData,
});
@ -92,100 +72,6 @@ export async function updateContentIndex(vault: Vault, setting: KhojSetting, las
return files;
}
export async function configureKhojBackend(vault: Vault, setting: KhojSetting, notify: boolean = true) {
let khojConfigUrl = `${setting.khojUrl}/api/config/data`;
// Check if khoj backend is configured, note if cannot connect to backend
let khoj_already_configured = await request(khojConfigUrl)
.then(response => {
setting.connectedToBackend = true;
return response !== "null"
})
.catch(error => {
setting.connectedToBackend = false;
if (notify)
new Notice(`Ensure Khoj backend is running and Khoj URL is pointing to it in the plugin settings.\n\n${error}`);
})
// Short-circuit configuring khoj if unable to connect to khoj backend
if (!setting.connectedToBackend) return;
// Set index name from the path of the current vault
// Get default config fields from khoj backend
let defaultConfig = await request(`${khojConfigUrl}/default`).then(response => JSON.parse(response));
let khojDefaultChatDirectory = getIndexDirectoryFromBackendConfig(defaultConfig["processor"]["conversation"]["conversation-logfile"]);
let khojDefaultOpenAIChatModelName = defaultConfig["processor"]["conversation"]["openai"]["chat-model"];
let khojDefaultOfflineChatModelName = defaultConfig["processor"]["conversation"]["offline-chat"]["chat-model"];
// Get current config if khoj backend configured, else get default config from khoj backend
await request(khoj_already_configured ? khojConfigUrl : `${khojConfigUrl}/default`)
.then(response => JSON.parse(response))
.then(data => {
let conversationLogFile = data?.["processor"]?.["conversation"]?.["conversation-logfile"] ?? `${khojDefaultChatDirectory}/conversation.json`;
let processorData: ProcessorData = {
"conversation": {
"conversation-logfile": conversationLogFile,
"openai": null,
"offline-chat": {
"chat-model": khojDefaultOfflineChatModelName,
"enable-offline-chat": setting.enableOfflineChat,
},
"tokenizer": null,
"max-prompt-size": null,
}
}
// If the Open AI API Key was configured in the plugin settings
if (!!setting.openaiApiKey) {
let openAIChatModel = data?.["processor"]?.["conversation"]?.["openai"]?.["chat-model"] ?? khojDefaultOpenAIChatModelName;
processorData = {
"conversation": {
"conversation-logfile": conversationLogFile,
"openai": {
"chat-model": openAIChatModel,
"api-key": setting.openaiApiKey,
},
"offline-chat": {
"chat-model": khojDefaultOfflineChatModelName,
"enable-offline-chat": setting.enableOfflineChat,
},
"tokenizer": null,
"max-prompt-size": null,
},
}
}
// Set khoj processor config to conversation processor config
data["processor"] = processorData;
// Save updated config and refresh index on khoj backend
updateKhojBackend(setting.khojUrl, data);
if (!khoj_already_configured)
console.log(`Khoj: Created khoj backend config:\n${JSON.stringify(data)}`)
else
console.log(`Khoj: Updated khoj backend config:\n${JSON.stringify(data)}`)
})
.catch(error => {
if (notify)
new Notice(`Failed to configure Khoj backend. Contact developer on Github.\n\nError: ${error}`);
})
}
export async function updateKhojBackend(khojUrl: string, khojConfig: Object) {
// POST khojConfig to khojConfigUrl
let requestContent: RequestUrlParam = {
url: `${khojUrl}/api/config/data`,
body: JSON.stringify(khojConfig),
method: 'POST',
contentType: 'application/json',
};
// Save khojConfig on khoj backend at khojConfigUrl
request(requestContent);
}
function getIndexDirectoryFromBackendConfig(filepath: string) {
return filepath.split("/").slice(0, -1).join("/");
}
export async function createNote(name: string, newLeaf = false): Promise<void> {
try {
let pathPrefix: string

View file

@ -4,6 +4,7 @@ import logging
import json
from enum import Enum
from typing import Optional
from fastapi import Request
import requests
import os
@ -45,9 +46,10 @@ class UserAuthenticationBackend(AuthenticationBackend):
def __init__(
self,
):
from database.models import KhojUser
from database.models import KhojUser, KhojApiUser
self.khojuser_manager = KhojUser.objects
self.khojapiuser_manager = KhojApiUser.objects
self._initialize_default_user()
super().__init__()
@ -59,13 +61,20 @@ class UserAuthenticationBackend(AuthenticationBackend):
password="default",
)
async def authenticate(self, request):
async def authenticate(self, request: Request):
current_user = request.session.get("user")
if current_user and current_user.get("email"):
user = await self.khojuser_manager.filter(email=current_user.get("email")).afirst()
if user:
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
elif state.anonymous_mode:
if len(request.headers.get("Authorization", "").split("Bearer ")) == 2:
# Get bearer token from header
bearer_token = request.headers["Authorization"].split("Bearer ")[1]
# Get user owning token
user_with_token = await self.khojapiuser_manager.filter(token=bearer_token).select_related("user").afirst()
if user_with_token:
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user)
if state.anonymous_mode:
user = await self.khojuser_manager.filter(username="default").afirst()
if user:
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)

View file

@ -288,6 +288,10 @@
</div>
</div>
<div class="section general-settings">
<div id="khoj-api-key-section" title="Use Khoj cloud with your Khoj API Key">
<button id="generate-api-key" onclick="generateAPIKey()">Generate API Key</button>
<div id="api-key-list"></div>
</div>
<div id="results-count" title="Number of items to show in search and use for chat response">
<label for="results-count-slider">Results Count: <span id="results-count-value">5</span></label>
<input type="range" id="results-count-slider" name="results-count-slider" min="1" max="10" step="1" value="5">
@ -311,6 +315,7 @@
</div>
</div>
<script>
function clearContentType(content_type) {
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
fetch('/api/delete/config/data/content_type/' + content_type, {
@ -361,40 +366,42 @@
})
.then(response => response.json())
.then(data => {
if (data.status == "ok") {
// Toggle the Enabled/Disabled UI based on the action/response.
var enableLocalLLLMButton = document.getElementById("set-enable-offline-chat");
var disableLocalLLLMButton = document.getElementById("clear-enable-offline-chat");
var configuredIcon = document.getElementById("configured-icon-conversation-enable-offline-chat");
var toggleEnableLocalLLLMButton = document.getElementById("toggle-enable-offline-chat");
toggleEnableLocalLLLMButton.classList.remove("enabled");
toggleEnableLocalLLLMButton.classList.add("disabled");
if (enable) {
enableLocalLLLMButton.classList.add("disabled");
enableLocalLLLMButton.classList.remove("enabled");
configuredIcon.classList.add("enabled");
configuredIcon.classList.remove("disabled");
disableLocalLLLMButton.classList.remove("disabled");
disableLocalLLLMButton.classList.add("enabled");
} else {
enableLocalLLLMButton.classList.remove("disabled");
enableLocalLLLMButton.classList.add("enabled");
configuredIcon.classList.remove("enabled");
configuredIcon.classList.add("disabled");
disableLocalLLLMButton.classList.add("disabled");
disableLocalLLLMButton.classList.remove("enabled");
}
if (data.status != "ok") {
featuresHintText.innerHTML = `🚨 Failed to ${enable ? "enable": "disable"} offline chat model! Inform server admins.`;
enable = !enable;
} else {
featuresHintText.classList.remove("show");
featuresHintText.innerHTML = "";
}
// Toggle the Enabled/Disabled UI based on the action/response.
var enableLocalLLLMButton = document.getElementById("set-enable-offline-chat");
var disableLocalLLLMButton = document.getElementById("clear-enable-offline-chat");
var configuredIcon = document.getElementById("configured-icon-conversation-enable-offline-chat");
var toggleEnableLocalLLLMButton = document.getElementById("toggle-enable-offline-chat");
toggleEnableLocalLLLMButton.classList.remove("enabled");
toggleEnableLocalLLLMButton.classList.add("disabled");
if (enable) {
enableLocalLLLMButton.classList.add("disabled");
enableLocalLLLMButton.classList.remove("enabled");
configuredIcon.classList.add("enabled");
configuredIcon.classList.remove("disabled");
disableLocalLLLMButton.classList.remove("disabled");
disableLocalLLLMButton.classList.add("enabled");
} else {
enableLocalLLLMButton.classList.remove("disabled");
enableLocalLLLMButton.classList.add("enabled");
configuredIcon.classList.remove("enabled");
configuredIcon.classList.add("disabled");
disableLocalLLLMButton.classList.add("disabled");
disableLocalLLLMButton.classList.remove("enabled");
}
})
}
@ -501,5 +508,57 @@
resultsCountSlider.value = storedResultsCount;
resultsCountValue.textContent = storedResultsCount;
}
function generateAPIKey() {
    // Ask the server to create a new API key for the current user
    // and append a row for it (key text + Delete button) to the key list.
    const apiKeyList = document.getElementById("api-key-list");
    fetch('/auth/token', {
        method: 'POST',
        headers: {
            'Content-Type': 'application/json',
        },
    })
    .then(response => {
        // Surface server-side failures instead of trying to render an error payload
        if (!response.ok) {
            throw new Error(`Failed to generate API key (HTTP ${response.status})`);
        }
        return response.json();
    })
    .then(tokenObj => {
        // Build the row with DOM APIs rather than string-built HTML, so the token
        // is never parsed as markup or as part of an inline script, and the
        // existing rows are not re-parsed on every insert.
        const apiKeyItem = document.createElement("div");
        apiKeyItem.id = `api-key-item-${tokenObj.token}`;
        apiKeyItem.className = "api-key-item";

        const apiKeyText = document.createElement("span");
        apiKeyText.className = "api-key";
        apiKeyText.textContent = tokenObj.token;

        const deleteButton = document.createElement("button");
        deleteButton.className = "delete-api-key";
        deleteButton.textContent = "Delete";
        deleteButton.addEventListener("click", () => deleteAPIKey(tokenObj.token));

        apiKeyItem.appendChild(apiKeyText);
        apiKeyItem.appendChild(deleteButton);
        apiKeyList.appendChild(apiKeyItem);
    })
    .catch(error => {
        // Covers network errors, non-OK responses and non-JSON bodies
        console.error("Khoj: Unable to generate API key.", error);
    });
}
function deleteAPIKey(token) {
    // Delete the given API key on the server, then remove its row from the key list.
    const apiKeyList = document.getElementById("api-key-list");
    // Encode the token so reserved characters (`&`, `#`, `+`, ...) cannot
    // truncate or alter the query string sent to the server.
    fetch(`/auth/token?token=${encodeURIComponent(token)}`, {
        method: 'DELETE',
    })
    .then(response => {
        // Keep the row when the server did not confirm the deletion
        if (!response.ok) return;

        const apiKeyItem = document.getElementById(`api-key-item-${token}`);
        // The row may already be gone (e.g. Delete clicked twice);
        // removeChild(null) would throw a TypeError.
        if (apiKeyItem) {
            apiKeyList.removeChild(apiKeyItem);
        }
    });
}
function listApiKeys() {
    // Fetch the current user's API keys and render one row per key
    // (key text + Delete button), replacing whatever the key list showed before.
    const apiKeyList = document.getElementById("api-key-list");
    fetch('/auth/token')
        .then(response => {
            // Surface server-side failures instead of trying to render an error payload
            if (!response.ok) {
                throw new Error(`Failed to list API keys (HTTP ${response.status})`);
            }
            return response.json();
        })
        .then(tokens => {
            // Build rows with DOM APIs rather than string-built HTML, so tokens are
            // never parsed as markup or as part of an inline script.
            const apiKeyItems = tokens.map(tokenObj => {
                const apiKeyItem = document.createElement("div");
                apiKeyItem.id = `api-key-item-${tokenObj.token}`;
                apiKeyItem.className = "api-key-item";

                const apiKeyText = document.createElement("span");
                apiKeyText.className = "api-key";
                apiKeyText.textContent = tokenObj.token;

                const deleteButton = document.createElement("button");
                deleteButton.className = "delete-api-key";
                deleteButton.textContent = "Delete";
                deleteButton.addEventListener("click", () => deleteAPIKey(tokenObj.token));

                apiKeyItem.appendChild(apiKeyText);
                apiKeyItem.appendChild(deleteButton);
                return apiKeyItem;
            });

            // Replace the previous contents of the list with the fresh rows
            apiKeyList.innerHTML = "";
            apiKeyItems.forEach(apiKeyItem => apiKeyList.appendChild(apiKeyItem));
        })
        .catch(error => {
            // Covers network errors, non-OK responses and non-JSON / non-array bodies
            console.error("Khoj: Unable to list API keys.", error);
        });
}
// List user's API keys on page load
listApiKeys();
</script>
{% endblock %}

View file

@ -98,9 +98,10 @@ def run():
# Mount Django and Static Files
app.mount("/django", django_app, name="django")
if not os.path.exists("static"):
os.mkdir("static")
app.mount("/static", StaticFiles(directory="static"), name="static")
static_dir = "static"
if not os.path.exists(static_dir):
os.mkdir(static_dir)
app.mount(f"/{static_dir}", StaticFiles(directory=static_dir), name=static_dir)
# Configure Middleware
configure_middleware(app)

View file

@ -6,17 +6,23 @@ logger = logging.getLogger(__name__)
def download_model(model_name: str):
    """Download the named GPT4All chat model if needed and load it.

    The model is fetched first, then loaded exactly once onto the device
    chosen below, so the GPU check always runs before any load is attempted.

    Args:
        model_name: Name of the GPT4All chat model to download and load.

    Returns:
        The loaded ``gpt4all.GPT4All`` chat model.

    Raises:
        ModuleNotFoundError: If the ``gpt4all`` package is not installed.
    """
    try:
        import gpt4all
    except ModuleNotFoundError as e:
        logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
        raise e

    # Download the chat model
    chat_model_config = gpt4all.GPT4All.retrieve_model(model_name=model_name, allow_download=True)

    # Decide whether to load model to GPU or CPU
    try:
        # Check if machine has GPU and GPU has enough free memory to load the chat model
        # NOTE(review): relies on list_gpu returning an empty result or raising ValueError
        # when no suitable GPU exists - confirm against the installed gpt4all version.
        device = "gpu" if gpt4all.pyllmodel.LLModel().list_gpu(chat_model_config["path"]) else "cpu"
    except ValueError:
        device = "cpu"

    # Now load the downloaded chat model onto appropriate device.
    # allow_download=False because the model was already retrieved above.
    chat_model = gpt4all.GPT4All(model_name=model_name, device=device, allow_download=False)
    logger.debug(f"Loaded chat model to {device.upper()}.")

    return chat_model

View file

@ -1,28 +1,17 @@
from typing import List
import torch
from langchain.embeddings import HuggingFaceEmbeddings
from sentence_transformers import CrossEncoder
from khoj.utils.helpers import get_device
from khoj.utils.rawconfig import SearchResponse
class EmbeddingsModel:
def __init__(self):
self.model_name = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1"
encode_kwargs = {"normalize_embeddings": True}
# encode_kwargs = {}
if torch.cuda.is_available():
# Use CUDA GPU
device = torch.device("cuda:0")
elif torch.backends.mps.is_available():
# Use Apple M1 Metal Acceleration
device = torch.device("mps")
else:
device = torch.device("cpu")
model_kwargs = {"device": device}
self.model_name = "thenlper/gte-small"
encode_kwargs = {"normalize_embeddings": True, "show_progress_bar": True}
model_kwargs = {"device": get_device()}
self.embeddings_model = HuggingFaceEmbeddings(
model_name=self.model_name, encode_kwargs=encode_kwargs, model_kwargs=model_kwargs
)
@ -37,19 +26,7 @@ class EmbeddingsModel:
class CrossEncoderModel:
def __init__(self):
self.model_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"
if torch.cuda.is_available():
# Use CUDA GPU
device = torch.device("cuda:0")
elif torch.backends.mps.is_available():
# Use Apple M1 Metal Acceleration
device = torch.device("mps")
else:
device = torch.device("cpu")
self.cross_encoder_model = CrossEncoder(model_name=self.model_name, device=device)
self.cross_encoder_model = CrossEncoder(model_name=self.model_name, device=get_device())
def predict(self, query, hits: List[SearchResponse]):
cross__inp = [[query, hit.additional["compiled"]] for hit in hits]

View file

@ -126,21 +126,21 @@ if not state.demo:
state.config.search_type = SearchConfig.parse_obj(constants.default_config["search-type"])
@api.get("/config/data", response_model=FullConfig)
@requires(["authenticated"], redirect="login_page")
@requires(["authenticated"])
def get_config_data(request: Request):
user = request.user.object if request.user.is_authenticated else None
enabled_content = EmbeddingsAdapters.get_unique_file_types(user)
user = request.user.object
EmbeddingsAdapters.get_unique_file_types(user)
return state.config
@api.post("/config/data")
@requires(["authenticated"], redirect="login_page")
@requires(["authenticated"])
async def set_config_data(
request: Request,
updated_config: FullConfig,
client: Optional[str] = None,
):
user = request.user.object if request.user.is_authenticated else None
user = request.user.object
await map_config_to_db(updated_config, user)
configuration_update_metadata = {}
@ -167,7 +167,7 @@ if not state.demo:
return state.config
@api.post("/config/data/content_type/github", status_code=200)
@requires(["authenticated"], redirect="login_page")
@requires(["authenticated"])
async def set_content_config_github_data(
request: Request,
updated_config: Union[GithubContentConfig, None],
@ -175,7 +175,7 @@ if not state.demo:
):
_initialize_config()
user = request.user.object if request.user.is_authenticated else None
user = request.user.object
await adapters.set_user_github_config(
user=user,
@ -194,7 +194,7 @@ if not state.demo:
return {"status": "ok"}
@api.post("/config/data/content_type/notion", status_code=200)
@requires(["authenticated"], redirect="login_page")
@requires(["authenticated"])
async def set_content_config_notion_data(
request: Request,
updated_config: Union[NotionContentConfig, None],
@ -202,7 +202,7 @@ if not state.demo:
):
_initialize_config()
user = request.user.object if request.user.is_authenticated else None
user = request.user.object
await adapters.set_notion_config(
user=user,
@ -220,13 +220,13 @@ if not state.demo:
return {"status": "ok"}
@api.post("/delete/config/data/content_type/{content_type}", status_code=200)
@requires(["authenticated"], redirect="login_page")
@requires(["authenticated"])
async def remove_content_config_data(
request: Request,
content_type: str,
client: Optional[str] = None,
):
user = request.user.object if request.user.is_authenticated else None
user = request.user.object
update_telemetry_state(
request=request,
@ -247,7 +247,7 @@ if not state.demo:
return {"status": "ok"}
@api.post("/delete/config/data/processor/conversation/openai", status_code=200)
@requires(["authenticated"], redirect="login_page")
@requires(["authenticated"])
async def remove_processor_conversation_config_data(
request: Request,
client: Optional[str] = None,
@ -267,7 +267,7 @@ if not state.demo:
return {"status": "ok"}
@api.post("/config/data/content_type/{content_type}", status_code=200)
@requires(["authenticated"], redirect="login_page")
@requires(["authenticated"])
async def set_content_config_data(
request: Request,
content_type: str,
@ -276,7 +276,7 @@ if not state.demo:
):
_initialize_config()
user = request.user.object if request.user.is_authenticated else None
user = request.user.object
content_object = map_config_to_object(content_type)
await adapters.set_text_content_config(user, content_object, updated_config)
@ -292,7 +292,7 @@ if not state.demo:
return {"status": "ok"}
@api.post("/config/data/processor/conversation/openai", status_code=200)
@requires(["authenticated"], redirect="login_page")
@requires(["authenticated"])
async def set_processor_openai_config_data(
request: Request,
updated_config: Union[OpenAIProcessorConfig, None],
@ -315,6 +315,7 @@ if not state.demo:
return {"status": "ok"}
@api.post("/config/data/processor/conversation/offline_chat", status_code=200)
@requires(["authenticated"])
async def set_processor_enable_offline_chat_config_data(
request: Request,
enable_offline_chat: bool,
@ -323,24 +324,29 @@ if not state.demo:
):
user = request.user.object
if enable_offline_chat:
conversation_config = ConversationProcessorConfig(
offline_chat=OfflineChatProcessorConfig(
enable_offline_chat=enable_offline_chat,
chat_model=offline_chat_model,
try:
if enable_offline_chat:
conversation_config = ConversationProcessorConfig(
offline_chat=OfflineChatProcessorConfig(
enable_offline_chat=enable_offline_chat,
chat_model=offline_chat_model,
)
)
)
await sync_to_async(ConversationAdapters.set_conversation_processor_config)(user, conversation_config)
await sync_to_async(ConversationAdapters.set_conversation_processor_config)(user, conversation_config)
offline_chat = await ConversationAdapters.get_offline_chat(user)
chat_model = offline_chat.chat_model
if state.gpt4all_processor_config is None:
state.gpt4all_processor_config = GPT4AllProcessorModel(chat_model=chat_model)
offline_chat = await ConversationAdapters.get_offline_chat(user)
chat_model = offline_chat.chat_model
if state.gpt4all_processor_config is None:
state.gpt4all_processor_config = GPT4AllProcessorModel(chat_model=chat_model)
else:
await sync_to_async(ConversationAdapters.clear_offline_chat_conversation_config)(user)
state.gpt4all_processor_config = None
else:
await sync_to_async(ConversationAdapters.clear_offline_chat_conversation_config)(user)
state.gpt4all_processor_config = None
except Exception as e:
logger.error(f"Error updating offline chat config: {e}", exc_info=True)
return {"status": "error", "message": str(e)}
update_telemetry_state(
request=request,
@ -360,11 +366,11 @@ def get_default_config_data():
@api.get("/config/types", response_model=List[str])
@requires(["authenticated"], redirect="login_page")
@requires(["authenticated"])
def get_config_types(
request: Request,
):
user = request.user.object if request.user.is_authenticated else None
user = request.user.object
enabled_file_types = EmbeddingsAdapters.get_unique_file_types(user)
@ -382,7 +388,7 @@ def get_config_types(
@api.get("/search", response_model=List[SearchResponse])
@requires(["authenticated"], redirect="login_page")
@requires(["authenticated"])
async def search(
q: str,
request: Request,
@ -396,7 +402,7 @@ async def search(
referer: Optional[str] = Header(None),
host: Optional[str] = Header(None),
):
user = request.user.object if request.user.is_authenticated else None
user = request.user.object
start_time = time.time()
# Run validation checks
@ -513,7 +519,7 @@ async def search(
@api.get("/update")
@requires(["authenticated"], redirect="login_page")
@requires(["authenticated"])
def update(
request: Request,
t: Optional[SearchType] = None,
@ -523,7 +529,7 @@ def update(
referer: Optional[str] = Header(None),
host: Optional[str] = Header(None),
):
user = request.user.object if request.user.is_authenticated else None
user = request.user.object
if not state.config:
error_msg = f"🚨 Khoj is not configured.\nConfigure it via http://localhost:42110/config, plugins or by editing {state.config_file}."
logger.warning(error_msg)
@ -557,7 +563,7 @@ def update(
@api.get("/chat/history")
@requires(["authenticated"], redirect="login_page")
@requires(["authenticated"])
def chat_history(
request: Request,
client: Optional[str] = None,
@ -585,7 +591,7 @@ def chat_history(
@api.get("/chat/options", response_class=Response)
@requires(["authenticated"], redirect="login_page")
@requires(["authenticated"])
async def chat_options(
request: Request,
client: Optional[str] = None,
@ -610,7 +616,7 @@ async def chat_options(
@api.get("/chat", response_class=Response)
@requires(["authenticated"], redirect="login_page")
@requires(["authenticated"])
async def chat(
request: Request,
q: str,

View file

@ -1,22 +1,29 @@
# Standard Packages
import logging
import json
import os
from typing import Optional
# External Packages
from fastapi import APIRouter
from starlette.config import Config
from starlette.requests import Request
from starlette.responses import HTMLResponse, RedirectResponse, Response
from starlette.authentication import requires
from authlib.integrations.starlette_client import OAuth, OAuthError
from google.oauth2 import id_token
from google.auth.transport import requests as google_requests
from database.adapters import get_or_create_user
# Internal Packages
from database.adapters import get_khoj_tokens, get_or_create_user, create_khoj_token, delete_khoj_token
from khoj.utils import state
logger = logging.getLogger(__name__)
auth_router = APIRouter()
if not os.environ.get("GOOGLE_CLIENT_ID") or not os.environ.get("GOOGLE_CLIENT_SECRET"):
if not state.anonymous_mode and not (os.environ.get("GOOGLE_CLIENT_ID") and os.environ.get("GOOGLE_CLIENT_SECRET")):
logger.info("Please set GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET environment variables to use Google OAuth")
else:
config = Config(environ=os.environ)
@ -39,6 +46,31 @@ async def login(request: Request):
return await oauth.google.authorize_redirect(request, redirect_uri)
@auth_router.post("/token")
@requires(["authenticated"], redirect="login_page")
async def generate_token(request: Request, token_name: Optional[str] = None) -> str:
    """Create an API token owned by the authenticated user.

    A truthy ``token_name`` is forwarded as the token's name; otherwise
    ``create_khoj_token`` is called without a name so its own default applies.
    """
    # Only pass `name` along when the caller supplied a non-empty one
    name_kwargs = {"name": token_name} if token_name else {}
    return await create_khoj_token(user=request.user.object, **name_kwargs)
@auth_router.get("/token")
@requires(["authenticated"], redirect="login_page")
def get_tokens(request: Request):
    """Return the API tokens that belong to the authenticated user."""
    current_user = request.user.object
    return get_khoj_tokens(user=current_user)
@auth_router.delete("/token")
@requires(["authenticated"], redirect="login_page")
async def delete_token(request: Request, token: str) -> str:
    """Delete the given API token on behalf of the authenticated user.

    The request's user is passed to ``delete_khoj_token`` along with the token.
    """
    token_owner = request.user.object
    return await delete_khoj_token(user=token_owner, token=token)
@auth_router.post("/redirect")
async def auth(request: Request):
form = await request.form()

View file

@ -4,9 +4,9 @@ from typing import Optional, Union, Dict
import asyncio
# External Packages
from fastapi import APIRouter, HTTPException, Header, Request, Response, UploadFile
from fastapi import APIRouter, Header, Request, Response, UploadFile
from pydantic import BaseModel
from khoj.routers.helpers import update_telemetry_state
from starlette.authentication import requires
# Internal Packages
from khoj.utils import state, constants
@ -17,6 +17,7 @@ from khoj.processor.github.github_to_jsonl import GithubToJsonl
from khoj.processor.notion.notion_to_jsonl import NotionToJsonl
from khoj.processor.plaintext.plaintext_to_jsonl import PlaintextToJsonl
from khoj.search_type import text_search, image_search
from khoj.routers.helpers import update_telemetry_state
from khoj.utils.yaml import save_config_to_file_updated_state
from khoj.utils.config import SearchModels
from khoj.utils.helpers import LRU, get_file_type
@ -57,10 +58,10 @@ class IndexerInput(BaseModel):
@indexer.post("/update")
@requires(["authenticated"])
async def update(
request: Request,
files: list[UploadFile],
x_api_key: str = Header(None),
force: bool = False,
t: Optional[Union[state.SearchType, str]] = None,
client: Optional[str] = None,
@ -68,9 +69,7 @@ async def update(
referer: Optional[str] = Header(None),
host: Optional[str] = Header(None),
):
user = request.user.object if request.user.is_authenticated else None
if x_api_key != "secret":
raise HTTPException(status_code=401, detail="Invalid API Key")
user = request.user.object
try:
logger.info(f"📬 Updating content index via API call by {client} client")
org_files: Dict[str, str] = {}

View file

@ -135,7 +135,7 @@ if not state.demo:
@web_client.get("/config/content_type/github", response_class=HTMLResponse)
@requires(["authenticated"], redirect="login_page")
def github_config_page(request: Request):
user = request.user.object if request.user.is_authenticated else None
user = request.user.object
current_github_config = get_user_github_config(user)
if current_github_config:
@ -164,7 +164,7 @@ if not state.demo:
@web_client.get("/config/content_type/notion", response_class=HTMLResponse)
@requires(["authenticated"], redirect="login_page")
def notion_config_page(request: Request):
user = request.user.object if request.user.is_authenticated else None
user = request.user.object
current_notion_config = get_user_notion_config(user)
current_config = NotionContentConfig(
@ -184,7 +184,7 @@ if not state.demo:
return templates.TemplateResponse("config.html", context={"request": request})
object = map_config_to_object(content_type)
user = request.user.object if request.user.is_authenticated else None
user = request.user.object
config = object.objects.filter(user=user).first()
if config == None:
config = object.objects.create(user=user)

View file

@ -6,12 +6,13 @@ import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Union, Any
from khoj.processor.conversation.gpt4all.utils import download_model
# External Packages
import torch
from khoj.utils.rawconfig import OfflineChatProcessorConfig
# Internal Packages
from khoj.processor.conversation.gpt4all.utils import download_model
logger = logging.getLogger(__name__)
@ -88,3 +89,4 @@ class GPT4AllProcessorModel:
except ValueError as e:
self.loaded_model = None
logger.error(f"Error while loading offline chat model: {e}", exc_info=True)
raise e

View file

@ -10,7 +10,7 @@ from os import path
import os
from pathlib import Path
import platform
import sys
import random
from time import perf_counter
import torch
from typing import Optional, Union, TYPE_CHECKING
@ -254,6 +254,18 @@ def log_telemetry(
return request_body
def get_device() -> torch.device:
    """Get device to run model on.

    Devices are tried in order of preference: CUDA GPU, Apple Metal (MPS), CPU.

    Returns:
        torch.device: ``cuda:0`` when CUDA is available, ``mps`` when the Metal
        backend exists and is available, otherwise ``cpu``.
    """
    if torch.cuda.is_available():
        # Use CUDA GPU
        return torch.device("cuda:0")

    # `torch.backends.mps` is absent on PyTorch builds that predate Metal support.
    # Look it up defensively so those builds fall back to CPU instead of raising AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        # Use Apple M1 Metal Acceleration
        return torch.device("mps")

    return torch.device("cpu")
class ConversationCommand(str, Enum):
Default = "default"
General = "general"
@ -267,3 +279,29 @@ command_descriptions = {
ConversationCommand.Default: "The default command when no command specified. It intelligently auto-switches between general and notes mode.",
ConversationCommand.Help: "Display a help message with all available commands and other metadata.",
}
def generate_random_name():
    """Return a random two-word name of the form "<adjective> <noun>"."""
    # Word pools the name is drawn from
    adjectives = (
        "happy",
        "irritated",
        "annoyed",
        "calm",
        "brave",
        "scared",
        "energetic",
        "chivalrous",
        "kind",
        "grumpy",
    )
    nouns = ("dog", "cat", "falcon", "whale", "turtle", "rabbit", "hamster", "snake", "spider", "elephant")

    # Draw the adjective first, then the noun, and join them with a single space
    return f"{random.choice(adjectives)} {random.choice(nouns)}"

View file

@ -1,7 +1,6 @@
# Standard Packages
import threading
from typing import List, Dict
from packaging import version
from collections import defaultdict
# External Packages
@ -11,7 +10,7 @@ from pathlib import Path
# Internal Packages
from khoj.utils import config as utils_config
from khoj.utils.config import ContentIndex, SearchModels, GPT4AllProcessorModel
from khoj.utils.helpers import LRU
from khoj.utils.helpers import LRU, get_device
from khoj.utils.rawconfig import FullConfig
from khoj.processor.embeddings import EmbeddingsModel, CrossEncoderModel
@ -35,12 +34,4 @@ telemetry: List[Dict[str, str]] = []
demo: bool = False
khoj_version: str = None
anonymous_mode: bool = False
if torch.cuda.is_available():
# Use CUDA GPU
device = torch.device("cuda:0")
elif version.parse(torch.__version__) >= version.parse("1.13.0.dev") and torch.backends.mps.is_available():
# Use Apple M1 Metal Acceleration
device = torch.device("mps")
else:
device = torch.device("cpu")
device = get_device()

View file

@ -28,6 +28,7 @@ from khoj.utils import state, fs_syncer
from khoj.routers.indexer import configure_content
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
from database.models import (
KhojApiUser,
LocalOrgConfig,
LocalMarkdownConfig,
LocalPlaintextConfig,
@ -76,13 +77,26 @@ def default_user2():
if KhojUser.objects.filter(username="default").exists():
return KhojUser.objects.get(username="default")
return UserFactory(
return KhojUser.objects.create(
username="default",
email="default@example.com",
password="default",
)
# NOTE(review): pytest documents that marks have no effect on fixtures; confirm whether
# the django_db mark below is still wanted or the tests' own marks are what grant DB access.
@pytest.mark.django_db
@pytest.fixture
def api_user(default_user):
    """Provide the KhojApiUser record belonging to `default_user`, creating it when absent.

    A newly created record is named "api-key" and carries the token "kk-secret".
    """
    already_present = KhojApiUser.objects.filter(user=default_user).exists()
    if not already_present:
        return KhojApiUser.objects.create(
            user=default_user,
            name="api-key",
            token="kk-secret",
        )

    return KhojApiUser.objects.get(user=default_user)
@pytest.fixture(scope="session")
def search_models(search_config: SearchConfig):
search_models = SearchModels()
@ -176,7 +190,7 @@ def chat_client(search_config: SearchConfig, default_user2: KhojUser):
if os.getenv("OPENAI_API_KEY"):
OpenAIProcessorConversationConfigFactory(user=default_user2)
state.anonymous_mode = True
state.anonymous_mode = False
app = FastAPI()
@ -219,7 +233,7 @@ def fastapi_app():
def client(
content_config: ContentConfig,
search_config: SearchConfig,
default_user: KhojUser,
api_user: KhojApiUser,
):
state.config.content_type = content_config
state.config.search_type = search_config
@ -231,7 +245,7 @@ def client(
OrgToJsonl,
get_sample_data("org"),
regenerate=False,
user=default_user,
user=api_user.user,
)
state.content_index.image = image_search.setup(
content_config.image, state.search_models.image_search, regenerate=False
@ -240,11 +254,11 @@ def client(
PlaintextToJsonl,
get_sample_data("plaintext"),
regenerate=False,
user=default_user,
user=api_user.user,
)
ConversationProcessorConfigFactory(user=default_user)
state.anonymous_mode = True
ConversationProcessorConfigFactory(user=api_user.user)
state.anonymous_mode = False
configure_routes(app)
configure_middleware(app)
@ -253,13 +267,8 @@ def client(
@pytest.fixture(scope="function")
def client_offline_chat(
search_config: SearchConfig,
content_config: ContentConfig,
default_user2: KhojUser,
):
def client_offline_chat(search_config: SearchConfig, default_user2: KhojUser):
# Initialize app state
state.config.content_type = md_content_config
state.config.search_type = search_config
state.SearchType = configure_search_types(state.config)
@ -269,9 +278,6 @@ def client_offline_chat(
user=default_user2,
)
# Index Markdown Content for Search
state.search_models.image_search = image_search.initialize_model(search_config.image)
all_files = fs_syncer.collect_files(user=default_user2)
configure_content(
state.content_index, state.config.content_type, all_files, state.search_models, user=default_user2
@ -283,6 +289,8 @@ def client_offline_chat(
state.anonymous_mode = True
app = FastAPI()
configure_routes(app)
configure_middleware(app)
app.mount("/static", StaticFiles(directory=web_directory), name="static")

View file

@ -3,6 +3,7 @@ import os
from database.models import (
KhojUser,
KhojApiUser,
ConversationProcessorConfig,
OfflineChatProcessorConversationConfig,
OpenAIProcessorConversationConfig,
@ -20,6 +21,15 @@ class UserFactory(factory.django.DjangoModelFactory):
uuid = factory.Faker("uuid4")
class ApiUserFactory(factory.django.DjangoModelFactory):
    """Test factory producing KhojApiUser instances."""

    class Meta:
        model = KhojApiUser

    # No owner is set by default; presumably callers pass `user=...` explicitly (verify against usage)
    user = None

    # Placeholder values generated by Faker's "name" and "password" providers
    name = factory.Faker("name")
    token = factory.Faker("password")
class ConversationProcessorConfigFactory(factory.django.DjangoModelFactory):
class Meta:
model = ConversationProcessorConfig

View file

@ -22,49 +22,115 @@ from database.adapters import EmbeddingsAdapters
# Test
# ----------------------------------------------------------------------------------------------------
def test_search_with_invalid_content_type(client):
@pytest.mark.django_db(transaction=True)
def test_search_with_no_auth_key(client):
# Arrange
user_query = quote("How to call Khoj from Emacs?")
# Act
response = client.get(f"/api/search?q={user_query}&t=invalid_content_type")
response = client.get(f"/api/search?q={user_query}")
# Assert
assert response.status_code == 403
@pytest.mark.django_db(transaction=True)
def test_search_with_invalid_auth_key(client):
    """A search request sent with the bearer token "invalid-token" is answered with HTTP 403."""
    # Arrange
    user_query = quote("How to call Khoj from Emacs?")
    auth_headers = {"Authorization": "Bearer invalid-token"}

    # Act
    search_url = f"/api/search?q={user_query}"
    result = client.get(search_url, headers=auth_headers)

    # Assert
    assert result.status_code == 403
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_search_with_invalid_content_type(client):
    """A search sent with the "kk-secret" token and an unknown content type yields HTTP 422."""
    # Arrange
    user_query = quote("How to call Khoj from Emacs?")
    auth_headers = {"Authorization": "Bearer kk-secret"}

    # Act
    search_url = f"/api/search?q={user_query}&t=invalid_content_type"
    result = client.get(search_url, headers=auth_headers)

    # Assert
    assert result.status_code == 422
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_search_with_valid_content_type(client):
for content_type in ["all", "org", "markdown", "image", "pdf", "github", "notion"]:
headers = {"Authorization": "Bearer kk-secret"}
for content_type in ["all", "org", "markdown", "image", "pdf", "github", "notion", "plaintext"]:
# Act
response = client.get(f"/api/search?q=random&t={content_type}")
response = client.get(f"/api/search?q=random&t={content_type}", headers=headers)
# Assert
assert response.status_code == 200, f"Returned status: {response.status_code} for content type: {content_type}"
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_index_update_with_no_auth_key(client):
    """Posting files to the index update endpoint without any auth header yields HTTP 403."""
    # Arrange
    sample_files = get_sample_files_data()

    # Act
    result = client.post("/api/v1/index/update", files=sample_files)

    # Assert
    assert result.status_code == 403
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_index_update_with_invalid_auth_key(client):
    """Posting files to the index update endpoint with the token "kk-invalid-token" yields HTTP 403."""
    # Arrange
    auth_headers = {"Authorization": "Bearer kk-invalid-token"}
    sample_files = get_sample_files_data()

    # Act
    result = client.post("/api/v1/index/update", files=sample_files, headers=auth_headers)

    # Assert
    assert result.status_code == 403
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_update_with_invalid_content_type(client):
# Arrange
headers = {"Authorization": "Bearer kk-secret"}
# Act
response = client.get(f"/api/update?t=invalid_content_type")
response = client.get(f"/api/update?t=invalid_content_type", headers=headers)
# Assert
assert response.status_code == 422
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_regenerate_with_invalid_content_type(client):
# Arrange
headers = {"Authorization": "Bearer kk-secret"}
# Act
response = client.get(f"/api/update?force=true&t=invalid_content_type")
response = client.get(f"/api/update?force=true&t=invalid_content_type", headers=headers)
# Assert
assert response.status_code == 422
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_index_update(client):
# Arrange
files = get_sample_files_data()
headers = {"x-api-key": "secret"}
headers = {"Authorization": "Bearer kk-secret"}
# Act
response = client.post("/api/v1/index/update", files=files, headers=headers)
@ -74,29 +140,33 @@ def test_index_update(client):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_regenerate_with_valid_content_type(client):
for content_type in ["all", "org", "markdown", "image", "pdf", "notion"]:
# Arrange
files = get_sample_files_data()
headers = {"x-api-key": "secret"}
headers = {"Authorization": "Bearer kk-secret"}
# Act
response = client.post(f"/api/v1/index/update?t={content_type}", files=files, headers=headers)
# Assert
assert response.status_code == 200, f"Returned status: {response.status_code} for content type: {content_type}"
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_regenerate_with_github_fails_without_pat(client):
# Act
response = client.get(f"/api/update?force=true&t=github")
headers = {"Authorization": "Bearer kk-secret"}
response = client.get(f"/api/update?force=true&t=github", headers=headers)
# Arrange
files = get_sample_files_data()
headers = {"x-api-key": "secret"}
# Act
response = client.post(f"/api/v1/index/update?t=github", files=files, headers=headers)
# Assert
assert response.status_code == 200, f"Returned status: {response.status_code} for content type: github"
@ -116,16 +186,17 @@ def test_get_configured_types_via_api(client, sample_org_data):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_get_api_config_types(client, search_config: SearchConfig, sample_org_data, default_user2: KhojUser):
def test_get_api_config_types(client, sample_org_data, default_user: KhojUser):
# Arrange
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False, user=default_user2)
headers = {"Authorization": "Bearer kk-secret"}
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False, user=default_user)
# Act
response = client.get(f"/api/config/types")
response = client.get(f"/api/config/types", headers=headers)
# Assert
assert response.status_code == 200
assert response.json() == ["all", "org", "image"]
assert response.json() == ["all", "org", "image", "plaintext"]
# ----------------------------------------------------------------------------------------------------
@ -135,6 +206,7 @@ def test_get_configured_types_with_no_content_config(fastapi_app: FastAPI):
state.SearchType = configure_search_types(config)
original_config = state.config.content_type
state.config.content_type = None
state.anonymous_mode = True
configure_routes(fastapi_app)
client = TestClient(fastapi_app)
@ -154,6 +226,7 @@ def test_get_configured_types_with_no_content_config(fastapi_app: FastAPI):
@pytest.mark.django_db(transaction=True)
def test_image_search(client, content_config: ContentConfig, search_config: SearchConfig):
# Arrange
headers = {"Authorization": "Bearer kk-secret"}
search_models.image_search = image_search.initialize_model(search_config.image)
content_index.image = image_search.setup(
content_config.image, search_models.image_search.image_encoder, regenerate=False
@ -166,7 +239,7 @@ def test_image_search(client, content_config: ContentConfig, search_config: Sear
for query, expected_image_name in query_expected_image_pairs:
# Act
response = client.get(f"/api/search?q={query}&n=1&t=image")
response = client.get(f"/api/search?q={query}&n=1&t=image", headers=headers)
# Assert
assert response.status_code == 200
@ -179,13 +252,14 @@ def test_image_search(client, content_config: ContentConfig, search_config: Sear
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_notes_search(client, search_config: SearchConfig, sample_org_data, default_user2: KhojUser):
def test_notes_search(client, search_config: SearchConfig, sample_org_data, default_user: KhojUser):
# Arrange
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False, user=default_user2)
headers = {"Authorization": "Bearer kk-secret"}
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False, user=default_user)
user_query = quote("How to git install application?")
# Act
response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true")
response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true", headers=headers)
# Assert
assert response.status_code == 200
@ -197,19 +271,20 @@ def test_notes_search(client, search_config: SearchConfig, sample_org_data, defa
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_notes_search_with_only_filters(
client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data, default_user2: KhojUser
client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data, default_user: KhojUser
):
# Arrange
headers = {"Authorization": "Bearer kk-secret"}
text_search.setup(
OrgToJsonl,
sample_org_data,
regenerate=False,
user=default_user2,
user=default_user,
)
user_query = quote('+"Emacs" file:"*.org"')
# Act
response = client.get(f"/api/search?q={user_query}&n=1&t=org")
response = client.get(f"/api/search?q={user_query}&n=1&t=org", headers=headers)
# Assert
assert response.status_code == 200
@ -220,13 +295,14 @@ def test_notes_search_with_only_filters(
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_notes_search_with_include_filter(client, sample_org_data, default_user2: KhojUser):
def test_notes_search_with_include_filter(client, sample_org_data, default_user: KhojUser):
# Arrange
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False, user=default_user2)
headers = {"Authorization": "Bearer kk-secret"}
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False, user=default_user)
user_query = quote('How to git install application? +"Emacs"')
# Act
response = client.get(f"/api/search?q={user_query}&n=1&t=org")
response = client.get(f"/api/search?q={user_query}&n=1&t=org", headers=headers)
# Assert
assert response.status_code == 200
@ -237,18 +313,19 @@ def test_notes_search_with_include_filter(client, sample_org_data, default_user2
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_notes_search_with_exclude_filter(client, sample_org_data, default_user2: KhojUser):
def test_notes_search_with_exclude_filter(client, sample_org_data, default_user: KhojUser):
# Arrange
headers = {"Authorization": "Bearer kk-secret"}
text_search.setup(
OrgToJsonl,
sample_org_data,
regenerate=False,
user=default_user2,
user=default_user,
)
user_query = quote('How to git install application? -"clone"')
# Act
response = client.get(f"/api/search?q={user_query}&n=1&t=org")
response = client.get(f"/api/search?q={user_query}&n=1&t=org", headers=headers)
# Assert
assert response.status_code == 200
@ -261,16 +338,17 @@ def test_notes_search_with_exclude_filter(client, sample_org_data, default_user2
@pytest.mark.django_db(transaction=True)
def test_different_user_data_not_accessed(client, sample_org_data, default_user: KhojUser):
# Arrange
headers = {"Authorization": "Bearer kk-token"} # Token for default_user2
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False, user=default_user)
user_query = quote("How to git install application?")
# Act
response = client.get(f"/api/search?q={user_query}&n=1&t=org")
response = client.get(f"/api/search?q={user_query}&n=1&t=org", headers=headers)
# Assert
assert response.status_code == 200
assert response.status_code == 403
# assert actual response has no data as the default_user is different from the user making the query (anonymous)
assert len(response.json()) == 0
assert len(response.json()) == 1 and response.json()["detail"] == "Forbidden"
def get_sample_files_data():