Create Billing integration. Improve Settings pages on Desktop, Web apps (#537)

### Major
- Expose Billing via Stripe on Khoj Web app for Khoj Cloud subscription
  - Expose card on web app config page to manage subscription to Khoj cloud
  - Create API webhook, endpoints for subscription payments using Stripe
- Put Computer files to index into Card under Content section
  - Show file type icons for each indexed file in config card of web app
  - Enable deleting all indexed desktop files from Khoj via Desktop app
  - Create config page on web app to manage computer files indexed by Khoj
- Track data source (computer, github, notion) of each entry
  - Update content by source via API. Make web client use this API for config
  - Store the data source of each entry in database

### Cleanup
- Set content enabled status on update via config buttons on web app
- Delete deprecated content config pages for local files from web client
- Rename Sync button, Force Sync toggle to Save, Save All buttons

### Fixes
- Prevent Desktop app triggering multiple simultaneous syncs to server
- Upgrade langchain version, as required by the added support for OCR-ing PDFs
- Bubble up content indexing errors to notify user on client apps
This commit is contained in:
Debanjum 2023-11-08 19:55:35 -08:00 committed by GitHub
commit 1d3bdf8fdb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
40 changed files with 758 additions and 414 deletions

View file

@ -9,6 +9,6 @@ The Github integration allows you to index as many repositories as you want. It'
## Use the Github plugin
1. Generate a [classic PAT (personal access token)](https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens) from [Github](https://github.com/settings/tokens) with at least the `repo` and `admin:org` scopes.
2. Navigate to [http://localhost:42110/config/content_type/github](http://localhost:42110/config/content_type/github) to configure your Github settings. Enter in your PAT, along with details for each repository you want to index.
2. Navigate to [http://localhost:42110/config/content-source/github](http://localhost:42110/config/content-source/github) to configure your Github settings. Enter in your PAT, along with details for each repository you want to index.
3. Click `Save`. Go back to the settings page and click `Configure`.
4. Go to [http://localhost:42110/](http://localhost:42110/) and start searching!

View file

@ -8,7 +8,7 @@ We haven't setup a fancy integration with OAuth yet, so this integration still r
![setup_new_integration](https://github.com/khoj-ai/khoj/assets/65192171/b056e057-d4dc-47dc-aad3-57b59a22c68b)
3. Share all the workspaces that you want to integrate with the Khoj integration you just made in the previous step
![enable_workspace](https://github.com/khoj-ai/khoj/assets/65192171/98290303-b5b8-4cb0-b32c-f68c6923a3d0)
4. In the first step, you generated an API key. Use the newly generated API Key in your Khoj settings, by default at http://localhost:42110/config/content_type/notion. Click `Save`.
4. In the first step, you generated an API key. Use the newly generated API Key in your Khoj settings, by default at http://localhost:42110/config/content-source/notion. Click `Save`.
5. Click `Configure` in http://localhost:42110/config to index your Notion workspace(s).
That's it! You should be ready to start searching and chatting. Make sure you've configured your OpenAI API Key for chat.

View file

@ -55,7 +55,7 @@ dependencies = [
"torch == 2.0.1",
"uvicorn == 0.17.6",
"aiohttp == 3.8.5",
"langchain >= 0.0.187",
"langchain >= 0.0.331",
"requests >= 2.26.0",
"bs4 >= 0.0.1",
"anyio == 3.7.1",
@ -73,7 +73,8 @@ dependencies = [
"gunicorn == 21.2.0",
"lxml == 4.9.3",
"tzdata == 2023.3",
"rapidocr-onnxruntime == 1.3.8"
"rapidocr-onnxruntime == 1.3.8",
"stripe == 7.3.0",
]
dynamic = ["version"]

View file

@ -21,7 +21,7 @@ BASE_DIR = Path(__file__).resolve().parent.parent.parent
# See https://docs.djangoproject.com/en/4.2/howto/deployment/checklist/
# SECURITY WARNING: keep the secret key used in production secret!
SECRET_KEY = os.getenv("DJANGO_SECRET_KEY")
SECRET_KEY = os.getenv("KHOJ_DJANGO_SECRET_KEY")
# SECURITY WARNING: don't run with debug turned on in production!
DEBUG = os.getenv("DJANGO_DEBUG", "False") == "True"

View file

@ -1,8 +1,8 @@
from typing import Type, TypeVar, List
from datetime import date
from typing import Optional, Type, TypeVar, List
from datetime import date, datetime, timedelta
import secrets
from typing import Type, TypeVar, List
from datetime import date
from datetime import date, timezone
from django.db import models
from django.contrib.sessions.backends.db import SessionStore
@ -30,6 +30,7 @@ from database.models import (
GithubRepoConfig,
Conversation,
ChatModelOptions,
Subscription,
UserConversationConfig,
OpenAIProcessorConversationConfig,
OfflineChatProcessorConversationConfig,
@ -103,6 +104,57 @@ async def create_google_user(token: dict) -> KhojUser:
return user
def get_user_subscription(email: str) -> Optional[Subscription]:
    """Look up the subscription belonging to the user with the given email.

    Returns the first matching subscription, or None when there is none.
    """
    subscriptions_for_email = Subscription.objects.filter(user__email=email)
    return subscriptions_for_email.first()
async def set_user_subscription(
    email: str, is_recurring=None, renewal_date=None, type="standard"
) -> Optional[Subscription]:
    """Create or update the subscription of the user with the given email.

    Args:
        email: Email of the user whose subscription should be set.
        is_recurring: New value of the recurring flag. None means "not provided":
            an existing value is kept, a new subscription uses the model default.
        renewal_date: New renewal date. None means "not provided" (see above).
            False explicitly clears the renewal date.
        type: Subscription type to store.

    Returns:
        The created or updated subscription, or None when no user has this email.
    """
    user_subscription = await Subscription.objects.filter(user__email=email).afirst()

    if not user_subscription:
        user = await get_user_by_email(email)
        if not user:
            return None
        # Only forward values the caller actually provided. Passing None for
        # is_recurring, or the False sentinel for renewal_date, would override
        # the model defaults with values those fields do not accept.
        new_fields = {"user": user, "type": type}
        if is_recurring is not None:
            new_fields["is_recurring"] = is_recurring
        if renewal_date is not None and renewal_date is not False:
            new_fields["renewal_date"] = renewal_date
        return await Subscription.objects.acreate(**new_fields)

    # Update the existing subscription in place
    user_subscription.type = type
    if is_recurring is not None:
        user_subscription.is_recurring = is_recurring
    if renewal_date is False:
        # False is the sentinel for "clear the renewal date"
        user_subscription.renewal_date = None
    elif renewal_date is not None:
        user_subscription.renewal_date = renewal_date
    await user_subscription.asave()
    return user_subscription
def get_user_subscription_state(user_subscription: "Optional[Subscription]") -> str:
    """Get subscription state of user

    Valid state transitions: trial -> subscribed <-> unsubscribed OR expired

    Args:
        user_subscription: The user's subscription, or None if they have none.

    Returns:
        One of "trial", "subscribed", "unsubscribed", "expired" or "invalid".
        "invalid" covers combinations that fit no other state: a recurring
        subscription whose renewal date has passed, or a non-trial
        subscription without a renewal date.
    """
    if not user_subscription:
        return "trial"
    if user_subscription.type == Subscription.Type.TRIAL:
        return "trial"

    renewal_date = user_subscription.renewal_date
    if renewal_date is None:
        # The renewal date is nullable; comparing None with a datetime would raise TypeError
        return "invalid"

    # Evaluate the clock once so every branch compares against the same instant
    is_active = renewal_date >= datetime.now(tz=timezone.utc)
    if user_subscription.is_recurring:
        return "subscribed" if is_active else "invalid"
    return "unsubscribed" if is_active else "expired"
async def get_user_by_email(email: str) -> Optional[KhojUser]:
    """Fetch the user with the given email.

    Returns the first matching user, or None when no user has this email.
    """
    return await KhojUser.objects.filter(email=email).afirst()
async def get_user_by_token(token: dict) -> KhojUser:
google_user = await GoogleUser.objects.filter(sub=token.get("sub")).select_related("user").afirst()
if not google_user:
@ -287,13 +339,21 @@ class EntryAdapters:
return deleted_count
@staticmethod
def delete_all_entries(user: KhojUser, file_type: str = None):
def delete_all_entries_by_type(user: KhojUser, file_type: str = None):
    """Delete the user's entries and return how many were deleted.

    When file_type is given, only entries of that file type are deleted;
    otherwise all of the user's entries are.
    """
    criteria = {"user": user}
    if file_type is not None:
        criteria["file_type"] = file_type
    deleted_count, _ = Entry.objects.filter(**criteria).delete()
    return deleted_count
@staticmethod
def delete_all_entries(user: KhojUser, file_source: str = None):
    """Delete the user's entries and return how many were deleted.

    When file_source is given, only entries from that source are deleted;
    otherwise all of the user's entries are.
    """
    if file_source is None:
        entries_to_delete = Entry.objects.filter(user=user)
    else:
        entries_to_delete = Entry.objects.filter(user=user, file_source=file_source)
    deleted_count, _ = entries_to_delete.delete()
    return deleted_count
@staticmethod
def get_existing_entry_hashes_by_file(user: KhojUser, file_path: str):
    """Return the `hashed_value` of each of the user's entries stored for the given file path."""
    entries_in_file = Entry.objects.filter(user=user, file_path=file_path)
    return entries_in_file.values_list("hashed_value", flat=True)
@ -318,8 +378,12 @@ class EntryAdapters:
return await Entry.objects.filter(user=user, file_path=file_path).adelete()
@staticmethod
def aget_all_filenames(user: KhojUser):
return Entry.objects.filter(user=user).distinct("file_path").values_list("file_path", flat=True)
def aget_all_filenames_by_source(user: KhojUser, file_source: str):
    """Return the distinct file paths of the user's entries that come from the given source.

    Note: despite the "a" prefix this is a plain (non-async) function.
    """
    entries_from_source = Entry.objects.filter(user=user, file_source=file_source)
    one_entry_per_path = entries_from_source.distinct("file_path")
    return one_entry_per_path.values_list("file_path", flat=True)
@staticmethod
async def adelete_all_entries(user: KhojUser):
@ -384,3 +448,7 @@ class EntryAdapters:
@staticmethod
def get_unique_file_types(user: KhojUser):
    """Return the distinct `file_type` values found across the user's entries."""
    file_types = Entry.objects.filter(user=user).values_list("file_type", flat=True)
    return file_types.distinct()
@staticmethod
def get_unique_file_source(user: KhojUser):
    """Return the distinct `file_source` values found across the user's entries."""
    file_sources = Entry.objects.filter(user=user).values_list("file_source", flat=True)
    return file_sources.distinct()

View file

@ -0,0 +1,21 @@
# Generated by Django 4.2.5 on 2023-11-07 07:24
from django.db import migrations, models
class Migration(migrations.Migration):
    """Add the `file_source` field to the `entry` model."""

    dependencies = [("database", "0011_merge_20231102_0138")]

    operations = [
        migrations.AddField(
            model_name="entry",
            name="file_source",
            field=models.CharField(
                max_length=30,
                default="computer",
                choices=[
                    ("computer", "Computer"),
                    ("notion", "Notion"),
                    ("github", "Github"),
                ],
            ),
        ),
    ]

View file

@ -0,0 +1,37 @@
# Generated by Django 4.2.5 on 2023-11-09 01:27
from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration):
    """Create the `Subscription` model, linked one-to-one to the user model."""

    dependencies = [("database", "0012_entry_file_source")]

    operations = [
        migrations.CreateModel(
            name="Subscription",
            fields=[
                (
                    "id",
                    models.BigAutoField(
                        auto_created=True,
                        primary_key=True,
                        serialize=False,
                        verbose_name="ID",
                    ),
                ),
                ("created_at", models.DateTimeField(auto_now_add=True)),
                ("updated_at", models.DateTimeField(auto_now=True)),
                (
                    "type",
                    models.CharField(
                        max_length=20,
                        default="trial",
                        choices=[("trial", "Trial"), ("standard", "Standard")],
                    ),
                ),
                ("is_recurring", models.BooleanField(default=False)),
                ("renewal_date", models.DateTimeField(null=True, default=None)),
                (
                    "user",
                    models.OneToOneField(
                        to=settings.AUTH_USER_MODEL,
                        on_delete=django.db.models.deletion.CASCADE,
                    ),
                ),
            ],
            options={"abstract": False},
        ),
    ]

View file

@ -46,6 +46,17 @@ class KhojApiUser(models.Model):
accessed_at = models.DateTimeField(null=True, default=None)
class Subscription(BaseModel):
    """Subscription of a Khoj user. Each user has at most one (one-to-one link)."""

    class Type(models.TextChoices):
        TRIAL = "trial"
        STANDARD = "standard"

    # Deleting the user also deletes their subscription
    user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
    # New subscriptions start as trials unless a type is given
    type = models.CharField(choices=Type.choices, default=Type.TRIAL, max_length=20)
    is_recurring = models.BooleanField(default=False)
    # Nullable: no renewal date is set by default
    renewal_date = models.DateTimeField(default=None, null=True)
class NotionConfig(BaseModel):
token = models.CharField(max_length=200)
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
@ -131,11 +142,17 @@ class Entry(BaseModel):
GITHUB = "github"
CONVERSATION = "conversation"
class EntrySource(models.TextChoices):
COMPUTER = "computer"
NOTION = "notion"
GITHUB = "github"
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True)
embeddings = VectorField(dimensions=384)
raw = models.TextField()
compiled = models.TextField()
heading = models.CharField(max_length=1000, default=None, null=True, blank=True)
file_source = models.CharField(max_length=30, choices=EntrySource.choices, default=EntrySource.COMPUTER)
file_type = models.CharField(max_length=30, choices=EntryType.choices, default=EntryType.PLAINTEXT)
file_path = models.CharField(max_length=400, default=None, null=True, blank=True)
file_name = models.CharField(max_length=400, default=None, null=True, blank=True)

View file

@ -192,9 +192,9 @@
.then(response => {
const reader = response.body.getReader();
const decoder = new TextDecoder();
let references = null;
function readStream() {
let references = null;
reader.read().then(({ done, value }) => {
if (done) {
// Evaluate the contents of new_response_text.innerHTML after all the data has been streamed

View file

@ -91,11 +91,13 @@
</div>
<div class="section-action-row">
<div class="card-description-row">
<button id="sync-data">Sync</button>
<button id="sync-data" class="sync-data">💾 Save</button>
</div>
<div class="card-description-row sync-force-toggle">
<input id="sync-force" type="checkbox" name="sync-force" value="force">
<label for="sync-force">Force Sync</label>
<div class="card-description-row">
<button id="sync-force" class="sync-data">💾 Save All</button>
</div>
<div class="card-description-row">
<button id="delete-all" class="sync-data">🗑️ Delete All</button>
</div>
</div>
<div id="loading-bar" style="display: none;"></div>
@ -336,7 +338,7 @@
padding: 4px;
cursor: pointer;
}
#sync-data {
button.sync-data {
background-color: var(--primary);
border: none;
color: var(--main-text-color);
@ -351,7 +353,7 @@
box-shadow: 0px 5px 0px var(--background-color);
}
#sync-data:hover {
button.sync-data:hover {
background-color: var(--primary-hover);
box-shadow: 0px 3px 0px var(--background-color);
}

View file

@ -67,6 +67,7 @@ const schema = {
}
};
// Tracks whether a sync to the Khoj server is currently in progress.
// Must be declared with `let`: pushDataToKhoj reassigns it when a sync starts and ends,
// and assigning to a `const` binding throws a TypeError at runtime.
let syncing = false;
var state = {}
const store = new Store({ schema });
@ -110,6 +111,15 @@ function filenameToMimeType (filename) {
}
function pushDataToKhoj (regenerate = false) {
// Don't sync if token or hostURL is not set or if already syncing
if (store.get('khojToken') === '' || store.get('hostURL') === '' || syncing === true) {
const win = BrowserWindow.getAllWindows()[0];
if (win) win.webContents.send('update-state', state);
return;
} else {
syncing = true;
}
let filesToPush = [];
const files = store.get('files') || [];
const folders = store.get('folders') || [];
@ -192,11 +202,13 @@ function pushDataToKhoj (regenerate = false) {
})
.finally(() => {
// Syncing complete
syncing = false;
const win = BrowserWindow.getAllWindows()[0];
if (win) win.webContents.send('update-state', state);
});
} else {
// Syncing complete
syncing = false;
const win = BrowserWindow.getAllWindows()[0];
if (win) win.webContents.send('update-state', state);
}
@ -306,6 +318,19 @@ async function syncData (regenerate = false) {
}
}
// Clear the stored lists of files and folders to index,
// then push the now-empty selection to Khoj with regenerate enabled
async function deleteAllFiles () {
    try {
        for (const storeKey of ['files', 'folders']) {
            store.set(storeKey, []);
        }
        pushDataToKhoj(true);
        console.log('Pushing data to Khoj at: ', new Date());
    } catch (err) {
        console.error(err);
    }
}
let firstRun = true;
let win = null;
const createWindow = (tab = 'chat.html') => {
@ -386,6 +411,7 @@ app.whenReady().then(() => {
ipcMain.handle('syncData', (event, regenerate) => {
syncData(regenerate);
});
ipcMain.handle('deleteAllFiles', deleteAllFiles);
createWindow()

View file

@ -45,7 +45,8 @@ contextBridge.exposeInMainWorld('hostURLAPI', {
})
contextBridge.exposeInMainWorld('syncDataAPI', {
syncData: (regenerate) => ipcRenderer.invoke('syncData', regenerate)
syncData: (regenerate) => ipcRenderer.invoke('syncData', regenerate),
deleteAllFiles: () => ipcRenderer.invoke('deleteAllFiles')
})
contextBridge.exposeInMainWorld('tokenAPI', {

View file

@ -196,9 +196,19 @@ khojKeyInput.addEventListener('blur', async () => {
});
const syncButton = document.getElementById('sync-data');
const syncForceToggle = document.getElementById('sync-force');
syncButton.addEventListener('click', async () => {
loadingBar.style.display = 'block';
const regenerate = syncForceToggle.checked;
await window.syncDataAPI.syncData(regenerate);
await window.syncDataAPI.syncData(false);
});
// Clicking the 'sync-force' button shows the loading bar and requests a sync with regenerate enabled
const syncForceButton = document.getElementById('sync-force');
syncForceButton.addEventListener('click', async function () {
    loadingBar.style.display = 'block';
    const regenerate = true;
    await window.syncDataAPI.syncData(regenerate);
});
// Clicking the 'delete-all' button shows the loading bar and asks the main process to delete all files
const deleteAllButton = document.getElementById('delete-all');
deleteAllButton.addEventListener('click', async function () {
    loadingBar.style.display = 'block';
    await window.syncDataAPI.deleteAllFiles();
});

View file

@ -1,5 +1,4 @@
# Standard Packages
import sys
import logging
import json
from enum import Enum
@ -109,7 +108,6 @@ def configure_server(
state.search_models = configure_search(state.search_models, state.config.search_type)
initialize_content(regenerate, search_type, init, user)
except Exception as e:
logger.error(f"🚨 Failed to configure search models", exc_info=True)
raise e
finally:
state.config_lock.release()
@ -125,7 +123,7 @@ def initialize_content(regenerate: bool, search_type: Optional[SearchType] = Non
else:
logger.info("📬 Updating content index...")
all_files = collect_files(user=user)
state.content_index = configure_content(
state.content_index, status = configure_content(
state.content_index,
state.config.content_type,
all_files,
@ -134,8 +132,9 @@ def initialize_content(regenerate: bool, search_type: Optional[SearchType] = Non
search_type,
user=user,
)
if not status:
raise RuntimeError("Failed to update content index")
except Exception as e:
logger.error(f"🚨 Failed to index content", exc_info=True)
raise e
@ -146,10 +145,14 @@ def configure_routes(app):
from khoj.routers.web_client import web_client
from khoj.routers.indexer import indexer
from khoj.routers.auth import auth_router
from khoj.routers.subscription import subscription_router
app.include_router(api, prefix="/api")
app.include_router(api_beta, prefix="/api/beta")
app.include_router(indexer, prefix="/api/v1/index")
if state.billing_enabled:
logger.info("💳 Enabled Billing")
app.include_router(subscription_router, prefix="/api/subscription")
app.include_router(web_client)
app.include_router(auth_router, prefix="/auth")
@ -165,13 +168,15 @@ def update_search_index():
logger.info("📬 Updating content index via Scheduler")
for user in get_all_users():
all_files = collect_files(user=user)
state.content_index = configure_content(
state.content_index, success = configure_content(
state.content_index, state.config.content_type, all_files, state.search_models, user=user
)
all_files = collect_files(user=None)
state.content_index = configure_content(
state.content_index, success = configure_content(
state.content_index, state.config.content_type, all_files, state.search_models, user=None
)
if not success:
raise RuntimeError("Failed to update content index")
logger.info("📪 Content index updated via Scheduler")
except Exception as e:
logger.error(f"🚨 Error updating content index via Scheduler: {e}", exc_info=True)

Binary file not shown.

After

Width:  |  Height:  |  Size: 10 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 19 KiB

View file

@ -209,23 +209,27 @@
border: none;
color: var(--flower);
padding: 4px;
width: 32px;
margin-bottom: 0px
}
div.file-element {
display: grid;
grid-template-columns: 1fr auto;
grid-template-columns: 1fr 5fr 1fr;
border: 1px solid rgb(229, 229, 229);
border-radius: 4px;
box-shadow: 0px 1px 3px 0px rgba(0,0,0,0.1),0px 1px 2px -1px rgba(0,0,0,0.8);
padding: 4px;
padding: 4px 0;
margin-bottom: 8px;
justify-items: center;
align-items: center;
}
div.remove-button-container {
text-align: right;
}
button.card-button.happy {
.card-button.happy {
color: var(--leaf);
}

View file

@ -187,9 +187,9 @@
.then(response => {
const reader = response.body.getReader();
const decoder = new TextDecoder();
let references = null;
function readStream() {
let references = null;
reader.read().then(({ done, value }) => {
if (done) {
// Evaluate the contents of new_response_text.innerHTML after all the data has been streamed

View file

@ -3,23 +3,57 @@
<div class="page">
<div class="section">
<h2 class="section-title">Plugins</h2>
<h2 class="section-title">Content</h2>
<div class="section-cards">
<div class="card">
<div class="card-title-row">
<img class="card-icon" src="/static/assets/icons/computer.png" alt="Computer">
<h3 id="card-title-computer" class="card-title">
Files
<img id="configured-icon-computer"
style="display: {% if not current_model_state.computer %}none{% endif %}"
class="configured-icon"
src="/static/assets/icons/confirm-icon.svg"
alt="Configured">
</h3>
</div>
<div class="card-description-row">
<p class="card-description">Manage files from your computer</p>
</div>
<div class="card-action-row">
<a class="card-button" href="/config/content-source/computer">
{% if current_model_state.computer %}
Update
{% else %}
Setup
{% endif %}
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
</a>
</div>
<div id="clear-computer" class="card-action-row"
style="display: {% if not current_model_state.computer %}none{% endif %}">
<button class="card-button" onclick="clearContentType('computer')">
Disable
</button>
</div>
</div>
<div class="card">
<div class="card-title-row">
<img class="card-icon" src="/static/assets/icons/github.svg" alt="Github">
<h3 class="card-title">
Github
{% if current_model_state.github == True %}
<img id="configured-icon-github" class="configured-icon" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
{% endif %}
<img id="configured-icon-github"
class="configured-icon"
src="/static/assets/icons/confirm-icon.svg"
alt="Configured"
style="display: {% if not current_model_state.github %}none{% endif %}">
</h3>
</div>
<div class="card-description-row">
<p class="card-description">Set repositories to index</p>
</div>
<div class="card-action-row">
<a class="card-button" href="/config/content_type/github">
<a class="card-button" href="/config/content-source/github">
{% if current_model_state.github %}
Update
{% else %}
@ -28,30 +62,32 @@
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
</a>
</div>
{% if current_model_state.github %}
<div id="clear-github" class="card-action-row">
<button class="card-button" onclick="clearContentType('github')">
Disable
</button>
</div>
{% endif %}
<div id="clear-github"
class="card-action-row"
style="display: {% if not current_model_state.github %}none{% endif %}">
<button class="card-button" onclick="clearContentType('github')">
Disable
</button>
</div>
</div>
<div class="card">
<div class="card-title-row">
<img class="card-icon" src="/static/assets/icons/notion.svg" alt="Notion">
<h3 class="card-title">
Notion
{% if current_model_state.notion == True %}
<img id="configured-icon-notion" class="configured-icon" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
{% endif %}
<img id="configured-icon-notion"
class="configured-icon"
src="/static/assets/icons/confirm-icon.svg"
alt="Configured"
style="display: {% if not current_model_state.notion %}none{% endif %}">
</h3>
</div>
<div class="card-description-row">
<p class="card-description">Configure your settings from Notion</p>
<p class="card-description">Sync your Notion pages</p>
</div>
<div class="card-action-row">
<a class="card-button" href="/config/content_type/notion">
{% if current_model_state.content %}
<a class="card-button" href="/config/content-source/notion">
{% if current_model_state.notion %}
Update
{% else %}
Setup
@ -59,13 +95,13 @@
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
</a>
</div>
{% if current_model_state.notion %}
<div id="clear-notion" class="card-action-row">
<button class="card-button" onclick="clearContentType('notion')">
Disable
</button>
</div>
{% endif %}
<div id="clear-notion"
class="card-action-row"
style="display: {% if not current_model_state.notion %}none{% endif %}">
<button class="card-button" onclick="clearContentType('notion')">
Disable
</button>
</div>
</div>
</div>
</div>
@ -77,7 +113,7 @@
<div class="card-title-row">
<img class="card-icon" src="/static/assets/icons/chat.svg" alt="Chat">
<h3 class="card-title">
Chat Model
Chat
</h3>
</div>
<div class="card-description-row">
@ -122,16 +158,69 @@
</div>
</div>
</div>
{% if billing_enabled %}
<div class="section">
<h2 class="section-title">Manage Data</h2>
<div class="section-manage-files">
<div id="delete-all-files" class="delete-all=files">
<button id="delete-all-files" type="submit" title="Delete all indexed files">🗑️ Remove All</button>
</div>
<div class="indexed-files">
<h2 class="section-title">Billing</h2>
<div class="section-cards">
<div class="card">
<div class="card-title-row">
<img class="card-icon" src="/static/assets/icons/credit-card.png" alt="Credit Card">
<h3 class="card-title">
<span>Subscription</span>
<img id="configured-icon-subscription"
style="display: {% if subscription_state == 'trial' or subscription_state == 'expired' %}none{% endif %}"
class="configured-icon"
src="/static/assets/icons/confirm-icon.svg"
alt="Configured">
</h3>
</div>
<div class="card-description-row">
<p id="trial-description"
class="card-description"
style="display: {% if subscription_state != 'trial' %}none{% endif %}">
Subscribe to Khoj Cloud
</p>
<p id="unsubscribe-description"
class="card-description"
style="display: {% if subscription_state != 'subscribed' %}none{% endif %}">
You are <b>subscribed</b> to Khoj Cloud. Subscription will <b>renew</b> on <b>{{ subscription_renewal_date }}</b>
</p>
<p id="resubscribe-description"
class="card-description"
style="display: {% if subscription_state != 'unsubscribed' %}none{% endif %}">
You are <b>subscribed</b> to Khoj Cloud. Subscription will <b>expire</b> on <b>{{ subscription_renewal_date }}</b>
</p>
<p id="expire-description"
class="card-description"
style="display: {% if subscription_state != 'expired' %}none{% endif %}">
Subscribe to Khoj Cloud. Subscription <b>expired</b> on <b>{{ subscription_renewal_date }}</b>
</p>
</div>
<div class="card-action-row">
<button id="unsubscribe-button"
class="card-button"
onclick="unsubscribe()"
style="display: {% if subscription_state != 'subscribed' %}none{% endif %};">
Unsubscribe
</button>
<button id="resubscribe-button"
class="card-button happy"
onclick="resubscribe()"
style="display: {% if subscription_state != 'unsubscribed' %}none{% endif %};">
Resubscribe
</button>
<a id="subscribe-button"
class="card-button happy"
href="{{ khoj_cloud_subscription_url }}?prefilled_email={{ username }}"
style="display: {% if subscription_state == 'subscribed' or subscription_state == 'unsubscribed' %}none{% endif %};">
Subscribe
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
</a>
</div>
</div>
</div>
</div>
{% endif %}
<div class="section general-settings">
<div id="results-count" title="Number of items to show in search and use for chat response">
<label for="results-count-slider">Results Count: <span id="results-count-value">5</span></label>
@ -176,8 +265,9 @@
})
};
function clearContentType(content_type) {
fetch('/api/config/data/content_type/' + content_type, {
function clearContentType(content_source) {
fetch('/api/config/data/content-source/' + content_source, {
method: 'DELETE',
headers: {
'Content-Type': 'application/json',
@ -186,22 +276,54 @@
.then(response => response.json())
.then(data => {
if (data.status == "ok") {
var contentTypeClearButton = document.getElementById("clear-" + content_type);
contentTypeClearButton.style.display = "none";
var configuredIcon = document.getElementById("configured-icon-" + content_type);
if (configuredIcon) {
configuredIcon.style.display = "none";
}
var misconfiguredIcon = document.getElementById("misconfigured-icon-" + content_type);
if (misconfiguredIcon) {
misconfiguredIcon.style.display = "none";
}
document.getElementById("configured-icon-" + content_source).style.display = "none";
document.getElementById("clear-" + content_source).style.display = "none";
} else {
document.getElementById("configured-icon-" + content_source).style.display = "";
document.getElementById("clear-" + content_source).style.display = "";
}
})
};
// Ask the server to cancel the subscription.
// On success, swap the "unsubscribe" description and button for the "resubscribe" ones.
function unsubscribe() {
    const requestOptions = {
        method: 'PATCH',
        headers: {
            'Content-Type': 'application/json',
        },
    };
    fetch('/api/subscription?operation=cancel&email={{username}}', requestOptions)
    .then(response => response.json())
    .then(data => {
        if (!data.success) return;
        ["unsubscribe-description", "unsubscribe-button"].forEach(elementId => {
            document.getElementById(elementId).style.display = "none";
        });
        ["resubscribe-description", "resubscribe-button"].forEach(elementId => {
            document.getElementById(elementId).style.display = "";
        });
    })
}
// Ask the server to resume the subscription.
// On success, swap the "resubscribe" description and button for the "unsubscribe" ones.
function resubscribe() {
    const requestOptions = {
        method: 'PATCH',
        headers: {
            'Content-Type': 'application/json',
        },
    };
    fetch('/api/subscription?operation=resubscribe&email={{username}}', requestOptions)
    .then(response => response.json())
    .then(data => {
        if (!data.success) return;
        ["resubscribe-description", "resubscribe-button"].forEach(elementId => {
            document.getElementById(elementId).style.display = "none";
        });
        ["unsubscribe-description", "unsubscribe-button"].forEach(elementId => {
            document.getElementById(elementId).style.display = "";
        });
    })
}
var configure = document.getElementById("configure");
configure.addEventListener("click", function(event) {
event.preventDefault();
@ -243,6 +365,7 @@
if (data.detail != null) {
throw new Error(data.detail);
}
document.getElementById("status").innerHTML = emoji + " " + successText;
document.getElementById("status").style.display = "block";
button.disabled = false;
@ -255,6 +378,26 @@
button.disabled = false;
button.innerHTML = '⚠️ Unsuccessful';
});
content_sources = ["computer", "github", "notion"];
content_sources.forEach(content_source => {
fetch(`/api/config/data/${content_source}`, {
method: 'GET',
headers: {
'Content-Type': 'application/json',
}
})
.then(response => response.json())
.then(data => {
if (data.length > 0) {
document.getElementById("configured-icon-" + content_source).style.display = "";
document.getElementById("clear-" + content_source).style.display = "";
} else {
document.getElementById("configured-icon-" + content_source).style.display = "none";
document.getElementById("clear-" + content_source).style.display = "none";
}
});
});
}
// Setup the results count slider
@ -362,70 +505,5 @@
}
})
}
// Get all currently indexed files
function getAllFilenames() {
fetch('/api/config/data/all')
.then(response => response.json())
.then(data => {
var indexedFiles = document.getElementsByClassName("indexed-files")[0];
indexedFiles.innerHTML = "";
if (data.length == 0) {
document.getElementById("delete-all-files").style.display = "none";
indexedFiles.innerHTML = "<div>Use the <a href='https://download.khoj.dev'>Khoj Desktop client</a> to index files.</div>";
} else {
document.getElementById("delete-all-files").style.display = "block";
}
for (var filename of data) {
let fileElement = document.createElement("div");
fileElement.classList.add("file-element");
let fileNameElement = document.createElement("div");
fileNameElement.classList.add("content-name");
fileNameElement.innerHTML = filename;
fileElement.appendChild(fileNameElement);
let buttonContainer = document.createElement("div");
buttonContainer.classList.add("remove-button-container");
let removeFileButton = document.createElement("button");
removeFileButton.classList.add("remove-file-button");
removeFileButton.innerHTML = "🗑️";
removeFileButton.addEventListener("click", ((filename) => {
return () => {
removeFile(filename);
};
})(filename));
buttonContainer.appendChild(removeFileButton);
fileElement.appendChild(buttonContainer);
indexedFiles.appendChild(fileElement);
}
})
.catch((error) => {
console.error('Error:', error);
});
}
// Get all currently indexed files on page load
getAllFilenames();
let deleteAllFilesButton = document.getElementById("delete-all-files");
deleteAllFilesButton.addEventListener("click", function(event) {
event.preventDefault();
fetch('/api/config/data/all', {
method: 'DELETE',
headers: {
'Content-Type': 'application/json',
}
})
.then(response => response.json())
.then(data => {
if (data.status == "ok") {
getAllFilenames();
}
})
});
</script>
{% endblock %}

View file

@ -0,0 +1,129 @@
{% extends "base_config.html" %}
{% block content %}
<div class="page">
<div class="section">
<h2 class="section-title">
<img class="card-icon" src="/static/assets/icons/computer.png" alt="files">
<span class="card-title-text">Files</span>
<div class="instructions">
<p class="card-description">Manage files from your computer</p>
<p id="desktop-client" class="card-description">Download the <a href="https://download.khoj.dev">Khoj Desktop app</a> to sync files from your computer</p>
</div>
</h2>
<div class="section-manage-files">
<div id="delete-all-files" class="delete-all-files">
<button id="delete-all-files" type="submit" title="Remove all computer files from Khoj">🗑️ Delete all</button>
</div>
<div class="indexed-files">
</div>
</div>
</div>
</div>
<style>
#desktop-client {
font-weight: normal;
}
.indexed-files {
width: 100%;
}
.content-name {
font-size: smaller;
}
</style>
<script>
// Delete a single indexed file from Khoj and refresh the file list on success.
// `path` is the filename as listed by the server. It is URL-encoded so names
// containing characters such as "&", "#" or "+" are not misread as query
// string syntax.
function removeFile(path) {
    fetch('/api/config/data/file?filename=' + encodeURIComponent(path), {
        method: 'DELETE',
        headers: {
            'Content-Type': 'application/json',
        }
    })
    .then(response => response.ok ? response.json() : Promise.reject(response))
    .then(data => {
        if (data.status == "ok") {
            getAllComputerFilenames();
        }
    })
    .catch((error) => {
        // Non-2xx responses and network failures end up here instead of
        // surfacing as unhandled promise rejections.
        console.error('Error:', error);
    });
}
// Fetch the names of all files indexed from the user's computer and render one
// row per file (type icon, name, delete button) inside the first
// ".indexed-files" container. The "delete-all-files" control is hidden while
// there is nothing to delete.
function getAllComputerFilenames() {
    // Icon shown for each known file extension; anything else falls back to plaintext.
    // A Map (not an object literal) so extensions like "constructor" cannot hit inherited keys.
    const iconByExtension = new Map([
        ["org", "org.svg"],
        ["pdf", "pdf.svg"],
        ["markdown", "markdown.svg"],
        ["md", "markdown.svg"],
    ]);

    fetch('/api/config/data/computer')
        .then(response => response.json())
        .then(data => {
            var indexedFiles = document.getElementsByClassName("indexed-files")[0];
            indexedFiles.innerHTML = "";

            if (data.length == 0) {
                document.getElementById("delete-all-files").style.display = "none";
                indexedFiles.innerHTML = "<div class='card-description'>Use the <a href='https://download.khoj.dev'>Khoj Desktop client</a> to index files.</div>";
            } else {
                document.getElementById("delete-all-files").style.display = "block";
            }

            // `const` gives every iteration its own binding, so the click
            // handler below captures the right filename without an IIFE.
            for (const filename of data) {
                let fileElement = document.createElement("div");
                fileElement.classList.add("file-element");

                // Declared locally; the original assigned an undeclared
                // `image_name`, which leaked an implicit global.
                const fileExtension = filename.split('.').pop();
                const imageName = iconByExtension.get(fileExtension) || "plaintext.svg";

                let fileIconElement = document.createElement("img");
                fileIconElement.classList.add("card-icon");
                fileIconElement.src = `/static/assets/icons/${imageName}`;
                fileIconElement.alt = "File";
                fileElement.appendChild(fileIconElement);

                let fileNameElement = document.createElement("div");
                fileNameElement.classList.add("content-name");
                // Filenames are data, not markup: textContent keeps names
                // containing "<" or "&" from being parsed as HTML.
                fileNameElement.textContent = filename;
                fileElement.appendChild(fileNameElement);

                let buttonContainer = document.createElement("div");
                buttonContainer.classList.add("remove-button-container");

                let removeFileButton = document.createElement("button");
                removeFileButton.classList.add("remove-file-button");
                removeFileButton.innerHTML = "🗑️";
                removeFileButton.addEventListener("click", () => {
                    removeFile(filename);
                });

                buttonContainer.appendChild(removeFileButton);
                fileElement.appendChild(buttonContainer);
                indexedFiles.appendChild(fileElement);
            }
        })
        .catch((error) => {
            console.error('Error:', error);
        });
}
// Get all currently indexed files on page load
getAllComputerFilenames();

// Delete every file indexed from the computer when the "delete-all-files"
// control is clicked, then re-render the file list once the server confirms.
let deleteAllComputerFilesButton = document.getElementById("delete-all-files");
deleteAllComputerFilesButton.addEventListener("click", function(event) {
    event.preventDefault();
    fetch('/api/config/data/content-source/computer', {
        method: 'DELETE',
        headers: {
            'Content-Type': 'application/json',
        }
    })
    // Same response handling as removeFile: non-2xx responses are failures,
    // not payloads to be parsed as success.
    .then(response => response.ok ? response.json() : Promise.reject(response))
    .then(data => {
        if (data.status == "ok") {
            getAllComputerFilenames();
        }
    })
    .catch((error) => {
        console.error('Error:', error);
    });
});
</script>
{% endblock %}

View file

@ -125,7 +125,7 @@
}
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
fetch('/api/config/data/content_type/github', {
fetch('/api/config/data/content-source/github', {
method: 'POST',
headers: {
'Content-Type': 'application/json',

View file

@ -42,7 +42,7 @@
}
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
fetch('/api/config/data/content_type/notion', {
fetch('/api/config/data/content-source/notion', {
method: 'POST',
headers: {
'Content-Type': 'application/json',

View file

@ -1,159 +0,0 @@
{% extends "base_config.html" %}
{% block content %}
<div class="page">
<div class="section">
<h2 class="section-title">
<img class="card-icon" src="/static/assets/icons/{{ content_type }}.svg" alt="{{ content_type|capitalize }}">
<span class="card-title-text">{{ content_type|capitalize }}</span>
</h2>
<form id="config-form">
<table>
<tr>
<td>
<label for="input-files" title="Add a {{content_type}} file for Khoj to index">Files</label>
</td>
<td id="input-files-cell">
{% if current_config['input_files'] is none %}
<input type="text" id="input-files" name="input-files" placeholder="~\Documents\notes.{{content_type}}">
{% else %}
{% for input_file in current_config['input_files'] %}
<input type="text" id="input-files" name="input-files" value="{{ input_file }}" placeholder="~\Documents\notes.{{content_type}}">
{% endfor %}
{% endif %}
</td>
<td>
<button type="button" id="input-files-button">Add</button>
</td>
</tr>
<tr>
<td>
<label for="input-filter" title="Add a folder with {{content_type}} files for Khoj to index">Folders</label>
</td>
<td id="input-filter-cell">
{% if current_config['input_filter'] is none %}
<input type="text" id="input-filter" name="input-filter" placeholder="~/Documents/{{content_type}}">
{% else %}
{% for input_filter in current_config['input_filter'] %}
<input type="text" id="input-filter" name="input-filter" placeholder="~/Documents/{{content_type}}" value="{{ input_filter }}">
{% endfor %}
{% endif %}
</td>
<td>
<button type="button" id="input-filter-button">Add</button>
</td>
</tr>
</table>
<div class="section">
<div id="success" style="display: none;" ></div>
<button id="submit" type="submit">Save</button>
</div>
</form>
</div>
</div>
<script>
// Wire up the "<fieldName>-button" element so that every click appends a new
// empty text input, named after the field, to the "<fieldName>-cell" element.
function addButtonEventListener(fieldName) {
    const addButton = document.getElementById(`${fieldName}-button`);
    addButton.addEventListener("click", () => {
        const targetCell = document.getElementById(`${fieldName}-cell`);
        const extraInput = document.createElement("input");
        extraInput.setAttribute("type", "text");
        extraInput.setAttribute("name", fieldName);
        targetCell.appendChild(extraInput);
    });
}

// Both the files row and the folders row get an "Add" button
["input-files", "input-filter"].forEach((fieldName) => addButtonEventListener(fieldName));
// Return the nodes (in their original order) whose value holds real input.
// A value of "", null, undefined or the literal string "None" counts as empty.
// `nodes` may be any array-like collection, e.g. a NodeList.
function getValidInputNodes(nodes) {
    const isBlank = (value) =>
        value === "" || value === null || value === undefined || value === "None";
    return Array.from(nodes).filter((node) => !isBlank(node.value));
}
// Save handler for the content type form. Collects the file paths and folder
// filters entered by the user, expands plain folder paths into one glob per
// file suffix of this content type, and POSTs the result to the config API.
submit.addEventListener("click", function(event) {
    event.preventDefault();

    let globFormat = "**/*"
    let suffixes = [];
    if ('{{content_type}}' == "markdown")
        suffixes = [".md", ".markdown"]
    else if ('{{content_type}}' == "org")
        suffixes = [".org"]
    else if ('{{content_type}}' === "pdf")
        suffixes = [".pdf"]
    else if ('{{content_type}}' === "plaintext")
        suffixes = ['.*']
    let globs = suffixes.map(x => `${globFormat}${x}`)

    var inputFileNodes = document.getElementsByName("input-files");
    var inputFiles = getValidInputNodes(inputFileNodes).map(node => node.value);

    var inputFilterNodes = document.getElementsByName("input-filter");
    var inputFilter = [];
    var nodes = getValidInputNodes(inputFilterNodes);

    // A regex that checks for globs in the path. If they exist,
    // we are going to just not add our own globing. If they don't,
    // then we will assume globbing should be done.
    const glob_regex = /([*?\[\]])/;
    for (const node of nodes) {
        if (glob_regex.test(node.value)) {
            // Already a glob: keep it exactly once. Checking this inside the
            // per-suffix loop added one duplicate per extra suffix.
            inputFilter.push(node.value);
        } else {
            for (const glob of globs) {
                inputFilter.push(node.value + glob);
            }
        }
    }

    if (inputFiles.length === 0 && inputFilter.length === 0) {
        alert("You must specify at least one input file or input filter.");
        return;
    }

    // The API expects null, not an empty list, for an unset field
    if (inputFiles.length == 0) {
        inputFiles = null;
    }

    if (inputFilter.length == 0) {
        inputFilter = null;
    }

    // Heading entries are always indexed; the form has no control for this setting
    var index_heading_entries = true;

    const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
    fetch('/api/config/data/content_type/{{ content_type }}', {
        method: 'POST',
        headers: {
            'Content-Type': 'application/json',
            'X-CSRFToken': csrfToken
        },
        body: JSON.stringify({
            "input_files": inputFiles,
            "input_filter": inputFilter,
            "index_heading_entries": index_heading_entries
        })
    })
    .then(response => response.json())
    .then(data => {
        if (data["status"] == "ok") {
            document.getElementById("success").innerHTML = "✅ Successfully updated. Go to your <a href='/config'>settings page</a> to complete setup.";
            document.getElementById("success").style.display = "block";
        } else {
            document.getElementById("success").innerHTML = "⚠️ Failed to update settings.";
            document.getElementById("success").style.display = "block";
        }
    })
});
</script>
{% endblock %}

View file

@ -104,7 +104,12 @@ class GithubToEntries(TextToEntries):
# Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
current_entries, DbEntry.EntryType.GITHUB, key="compiled", logger=logger, user=user
current_entries,
DbEntry.EntryType.GITHUB,
DbEntry.EntrySource.GITHUB,
key="compiled",
logger=logger,
user=user,
)
return num_new_embeddings, num_deleted_embeddings

View file

@ -47,6 +47,7 @@ class MarkdownToEntries(TextToEntries):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
current_entries,
DbEntry.EntryType.MARKDOWN,
DbEntry.EntrySource.COMPUTER,
"compiled",
logger,
deletion_file_names,

View file

@ -250,7 +250,12 @@ class NotionToEntries(TextToEntries):
# Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
current_entries, DbEntry.EntryType.NOTION, key="compiled", logger=logger, user=user
current_entries,
DbEntry.EntryType.NOTION,
DbEntry.EntrySource.NOTION,
key="compiled",
logger=logger,
user=user,
)
return num_new_embeddings, num_deleted_embeddings

View file

@ -48,6 +48,7 @@ class OrgToEntries(TextToEntries):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
current_entries,
DbEntry.EntryType.ORG,
DbEntry.EntrySource.COMPUTER,
"compiled",
logger,
deletion_file_names,

View file

@ -46,6 +46,7 @@ class PdfToEntries(TextToEntries):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
current_entries,
DbEntry.EntryType.PDF,
DbEntry.EntrySource.COMPUTER,
"compiled",
logger,
deletion_file_names,

View file

@ -56,6 +56,7 @@ class PlaintextToEntries(TextToEntries):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
current_entries,
DbEntry.EntryType.PLAINTEXT,
DbEntry.EntrySource.COMPUTER,
key="compiled",
logger=logger,
deletion_filenames=deletion_file_names,

View file

@ -78,6 +78,7 @@ class TextToEntries(ABC):
self,
current_entries: List[Entry],
file_type: str,
file_source: str,
key="compiled",
logger: logging.Logger = None,
deletion_filenames: Set[str] = None,
@ -93,9 +94,9 @@ class TextToEntries(ABC):
num_deleted_entries = 0
if regenerate:
with timer("Prepared dataset for regeneration in", logger):
with timer("Cleared existing dataset for regeneration in", logger):
logger.debug(f"Deleting all entries for file type {file_type}")
num_deleted_entries = EntryAdapters.delete_all_entries(user, file_type)
num_deleted_entries = EntryAdapters.delete_all_entries_by_type(user, file_type)
hashes_to_process = set()
with timer("Identified entries to add to database in", logger):
@ -132,6 +133,7 @@ class TextToEntries(ABC):
compiled=entry.compiled,
heading=entry.heading[:1000], # Truncate to max chars of field allowed
file_path=entry.file,
file_source=file_source,
file_type=file_type,
hashed_value=entry_hash,
corpus_id=entry.corpus_id,

View file

@ -23,7 +23,6 @@ from khoj.utils.rawconfig import (
FullConfig,
SearchConfig,
SearchResponse,
TextContentConfig,
GithubContentConfig,
NotionContentConfig,
)
@ -51,6 +50,7 @@ from database.models import (
LocalPdfConfig,
LocalPlaintextConfig,
KhojUser,
Entry as DbEntry,
GithubConfig,
NotionConfig,
)
@ -61,11 +61,13 @@ api = APIRouter()
logger = logging.getLogger(__name__)
def map_config_to_object(content_type: str):
if content_type == "github":
def map_config_to_object(content_source: str):
if content_source == DbEntry.EntrySource.GITHUB:
return GithubConfig
if content_type == "notion":
if content_source == DbEntry.EntrySource.GITHUB:
return NotionConfig
if content_source == DbEntry.EntrySource.COMPUTER:
return "Computer"
async def map_config_to_db(config: FullConfig, user: KhojUser):
@ -164,7 +166,7 @@ async def set_config_data(
return state.config
@api.post("/config/data/content_type/github", status_code=200)
@api.post("/config/data/content-source/github", status_code=200)
@requires(["authenticated"])
async def set_content_config_github_data(
request: Request,
@ -192,7 +194,7 @@ async def set_content_config_github_data(
return {"status": "ok"}
@api.post("/config/data/content_type/notion", status_code=200)
@api.post("/config/data/content-source/notion", status_code=200)
@requires(["authenticated"])
async def set_content_config_notion_data(
request: Request,
@ -219,11 +221,11 @@ async def set_content_config_notion_data(
return {"status": "ok"}
@api.delete("/config/data/content_type/{content_type}", status_code=200)
@api.delete("/config/data/content-source/{content_source}", status_code=200)
@requires(["authenticated"])
async def remove_content_config_data(
async def remove_content_source_data(
request: Request,
content_type: str,
content_source: str,
client: Optional[str] = None,
):
user = request.user.object
@ -233,15 +235,15 @@ async def remove_content_config_data(
telemetry_type="api",
api="delete_content_config",
client=client,
metadata={"content_type": content_type},
metadata={"content_source": content_source},
)
content_object = map_config_to_object(content_type)
content_object = map_config_to_object(content_source)
if content_object is None:
raise ValueError(f"Invalid content type: {content_type}")
await content_object.objects.filter(user=user).adelete()
await sync_to_async(EntryAdapters.delete_all_entries)(user, content_type)
raise ValueError(f"Invalid content source: {content_source}")
elif content_object != "Computer":
await content_object.objects.filter(user=user).adelete()
await sync_to_async(EntryAdapters.delete_all_entries)(user, content_source)
enabled_content = await sync_to_async(EntryAdapters.get_unique_file_types)(user)
return {"status": "ok"}
@ -268,10 +270,11 @@ async def remove_file_data(
return {"status": "ok"}
@api.get("/config/data/all", response_model=List[str])
@api.get("/config/data/{content_source}", response_model=List[str])
@requires(["authenticated"])
async def get_all_filenames(
request: Request,
content_source: str,
client: Optional[str] = None,
):
user = request.user.object
@ -283,27 +286,7 @@ async def get_all_filenames(
client=client,
)
return await sync_to_async(list)(EntryAdapters.aget_all_filenames(user))
@api.delete("/config/data/all", status_code=200)
@requires(["authenticated"])
async def remove_all_config_data(
request: Request,
client: Optional[str] = None,
):
user = request.user.object
update_telemetry_state(
request=request,
telemetry_type="api",
api="delete_all_config",
client=client,
)
await EntryAdapters.adelete_all_entries(user)
return {"status": "ok"}
return await sync_to_async(list)(EntryAdapters.aget_all_filenames_by_source(user, content_source))
@api.post("/config/data/conversation/model", status_code=200)

View file

@ -24,7 +24,9 @@ logger = logging.getLogger(__name__)
auth_router = APIRouter()
if not state.anonymous_mode and not (os.environ.get("GOOGLE_CLIENT_ID") and os.environ.get("GOOGLE_CLIENT_SECRET")):
logger.info("Please set GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET environment variables to use Google OAuth")
logger.warning(
    "🚨 Use --anonymous-mode flag to disable Google OAuth or set GOOGLE_CLIENT_ID, GOOGLE_CLIENT_SECRET environment variables to enable it"
)
else:
config = Config(environ=os.environ)

View file

@ -126,7 +126,7 @@ async def update(
# Extract required fields from config
loop = asyncio.get_event_loop()
state.content_index = await loop.run_in_executor(
state.content_index, success = await loop.run_in_executor(
None,
configure_content,
state.content_index,
@ -138,6 +138,8 @@ async def update(
False,
user,
)
if not success:
raise RuntimeError("Failed to update content index")
logger.info(f"Finished processing batch indexing request")
except Exception as e:
logger.error(f"Failed to process batch indexing request: {e}", exc_info=True)
@ -145,6 +147,7 @@ async def update(
f"🚨 Failed to {force} update {t} content index triggered via API call by {client} client: {e}",
exc_info=True,
)
return Response(content="Failed", status_code=500)
update_telemetry_state(
request=request,
@ -182,18 +185,19 @@ def configure_content(
t: Optional[state.SearchType] = None,
full_corpus: bool = True,
user: KhojUser = None,
) -> Optional[ContentIndex]:
) -> tuple[Optional[ContentIndex], bool]:
content_index = ContentIndex()
success = True
if t is not None and not t.value in [type.value for type in state.SearchType]:
logger.warning(f"🚨 Invalid search type: {t}")
return None
return None, False
search_type = t.value if t else None
if files is None:
logger.warning(f"🚨 No files to process for {search_type} search.")
return None
return None, True
try:
# Initialize Org Notes Search
@ -209,6 +213,7 @@ def configure_content(
)
except Exception as e:
logger.error(f"🚨 Failed to setup org: {e}", exc_info=True)
success = False
try:
# Initialize Markdown Search
@ -225,6 +230,7 @@ def configure_content(
except Exception as e:
logger.error(f"🚨 Failed to setup markdown: {e}", exc_info=True)
success = False
try:
# Initialize PDF Search
@ -241,6 +247,7 @@ def configure_content(
except Exception as e:
logger.error(f"🚨 Failed to setup PDF: {e}", exc_info=True)
success = False
try:
# Initialize Plaintext Search
@ -257,6 +264,7 @@ def configure_content(
except Exception as e:
logger.error(f"🚨 Failed to setup plaintext: {e}", exc_info=True)
success = False
try:
# Initialize Image Search
@ -274,6 +282,7 @@ def configure_content(
except Exception as e:
logger.error(f"🚨 Failed to setup images: {e}", exc_info=True)
success = False
try:
github_config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first()
@ -291,6 +300,7 @@ def configure_content(
except Exception as e:
logger.error(f"🚨 Failed to setup GitHub: {e}", exc_info=True)
success = False
try:
# Initialize Notion Search
@ -308,12 +318,13 @@ def configure_content(
except Exception as e:
logger.error(f"🚨 Failed to setup Notion: {e}", exc_info=True)
success = False
# Invalidate Query Cache
if user:
state.query_cache[user.uuid] = LRU()
return content_index
return content_index, success
def load_content(

View file

@ -0,0 +1,106 @@
# Standard Packages
from datetime import datetime, timezone
import logging
import os
# External Packages
from asgiref.sync import sync_to_async
from fastapi import APIRouter, Request
from starlette.authentication import requires
import stripe
# Internal Packages
from database import adapters
# Stripe integration for Khoj Cloud Subscription
# API key for the Stripe client, read from the environment (None when unset)
stripe.api_key = os.getenv("STRIPE_API_KEY")
# Signing secret passed to stripe.Webhook.construct_event to verify incoming webhook events
endpoint_secret = os.getenv("STRIPE_SIGNING_SECRET")
logger = logging.getLogger(__name__)
# Router for the subscription webhook (POST) and subscription update (PATCH) endpoints
subscription_router = APIRouter()
@subscription_router.post("")
async def subscribe(request: Request):
    """Webhook for Stripe to send subscription events to Khoj Cloud.

    Verifies the Stripe signature on the raw request body, then updates the
    subscription of the user whose email matches the Stripe customer.

    Returns:
        {"success": bool}. False for unhandled event types, or when no
        subscription or user could be found to update.

    Raises:
        HTTPException: 400 when the signature header is missing, the payload
            is invalid or the signature fails verification.
    """
    # Imported here so the handler does not depend on the module import block
    from fastapi import HTTPException

    payload = await request.body()
    sig_header = request.headers.get("stripe-signature")
    if sig_header is None:
        raise HTTPException(status_code=400, detail="Missing Stripe signature header")

    # Reject unverifiable requests as client errors rather than letting them
    # surface as internal server errors
    try:
        event = stripe.Webhook.construct_event(payload, sig_header, endpoint_secret)
    except ValueError as e:
        logger.warning(f"Invalid Stripe webhook payload: {e}")
        raise HTTPException(status_code=400, detail="Invalid payload") from e
    except stripe.error.SignatureVerificationError as e:
        logger.warning(f"Invalid Stripe webhook signature: {e}")
        raise HTTPException(status_code=400, detail="Invalid signature") from e

    event_type = event["type"]
    if event_type not in {
        "invoice.paid",
        "customer.subscription.updated",
        "customer.subscription.deleted",
        "subscription_schedule.canceled",
    }:
        logger.warning(f"Unhandled Stripe event type: {event['type']}")
        return {"success": False}

    # Retrieve the customer's details
    subscription = event["data"]["object"]
    customer_id = subscription["customer"]
    customer = stripe.Customer.retrieve(customer_id)
    customer_email = customer["email"]

    # Handle valid stripe webhook events
    success = True
    if event_type == "invoice.paid":
        # Mark the user as subscribed and update the next renewal date on payment
        subscriptions = stripe.Subscription.list(customer=customer_id).data
        if not subscriptions:
            # A paid invoice without any subscription leaves nothing to renew
            logger.warning(f"No Stripe subscription found for paid invoice of {customer_email}")
            return {"success": False}
        subscription = subscriptions[0]
        renewal_date = datetime.fromtimestamp(subscription["current_period_end"], tz=timezone.utc)
        user = await adapters.set_user_subscription(customer_email, is_recurring=True, renewal_date=renewal_date)
        success = user is not None
    elif event_type == "customer.subscription.updated":
        user_subscription = await sync_to_async(adapters.get_user_subscription)(customer_email)
        # Allow updating subscription status if paid user
        if user_subscription and user_subscription.renewal_date:
            # Mark user as unsubscribed or resubscribed
            is_recurring = not subscription["cancel_at_period_end"]
            updated_user = await adapters.set_user_subscription(customer_email, is_recurring=is_recurring)
            success = updated_user is not None
    elif event_type == "customer.subscription.deleted":
        # Reset the user to trial state
        # NOTE(review): renewal_date=False (not None) is passed through unchanged;
        # presumably the adapter reads it as "clear the renewal date" - confirm.
        user = await adapters.set_user_subscription(
            customer_email, is_recurring=False, renewal_date=False, type="trial"
        )
        success = user is not None

    logger.info(f'Stripe subscription {event["type"]} for {customer["email"]}')
    return {"success": success}
@subscription_router.patch("")
@requires(["authenticated"])
async def update_subscription(request: Request, email: str, operation: str):
    """Cancel or resume the Stripe subscription of the authenticated user.

    Args:
        request: The authenticated request.
        email: Email of the Stripe customer to update. Must be the email of
            the user making the request.
        operation: "cancel" to stop renewal at the end of the current billing
            period, "resubscribe" to undo a pending cancellation.

    Returns:
        {"success": True} on success, otherwise {"success": False, "message": str}.
    """
    # Only let users manage their own subscription. Without this check any
    # authenticated user could cancel another customer's subscription by
    # passing that customer's email.
    if email != request.user.object.email:
        return {"success": False, "message": "Not authorized to update this subscription"}

    # Retrieve the customer's details
    customers = stripe.Customer.list(email=email).auto_paging_iter()
    customer = next(customers, None)
    if customer is None:
        return {"success": False, "message": "Customer not found"}

    if operation == "cancel":
        customer_id = customer.id
        for subscription in stripe.Subscription.list(customer=customer_id):
            stripe.Subscription.modify(subscription.id, cancel_at_period_end=True)
        return {"success": True}

    elif operation == "resubscribe":
        subscriptions = stripe.Subscription.list(customer=customer.id).auto_paging_iter()
        # Find the subscription that is set to cancel at the end of the period
        for subscription in subscriptions:
            if subscription.cancel_at_period_end:
                # Update the subscription to not cancel at the end of the period
                stripe.Subscription.modify(subscription.id, cancel_at_period_end=False)
                return {"success": True}
        return {"success": False, "message": "No subscription found that is set to cancel"}

    return {"success": False, "message": "Invalid operation"}

View file

@ -8,8 +8,9 @@ from fastapi import Request
from fastapi.responses import HTMLResponse, FileResponse, RedirectResponse
from fastapi.templating import Jinja2Templates
from starlette.authentication import requires
from database import adapters
from database.models import KhojUser
from khoj.utils.rawconfig import (
TextContentConfig,
GithubContentConfig,
GithubRepoConfig,
NotionContentConfig,
@ -17,15 +18,18 @@ from khoj.utils.rawconfig import (
# Internal Packages
from khoj.utils import constants, state
from database.adapters import EntryAdapters, get_user_github_config, get_user_notion_config, ConversationAdapters
from database.models import LocalOrgConfig, LocalMarkdownConfig, LocalPdfConfig, LocalPlaintextConfig
from database.adapters import (
EntryAdapters,
get_user_github_config,
get_user_notion_config,
ConversationAdapters,
get_user_subscription_state,
)
# Initialize Router
web_client = APIRouter()
templates = Jinja2Templates(directory=constants.web_directory)
VALID_TEXT_CONTENT_TYPES = ["org", "markdown", "pdf", "plaintext"]
# Create Routes
@web_client.get("/", response_class=FileResponse)
@ -109,41 +113,26 @@ def login_page(request: Request):
)
def map_config_to_object(content_type: str):
if content_type == "org":
return LocalOrgConfig
if content_type == "markdown":
return LocalMarkdownConfig
if content_type == "pdf":
return LocalPdfConfig
if content_type == "plaintext":
return LocalPlaintextConfig
@web_client.get("/config", response_class=HTMLResponse)
@requires(["authenticated"], redirect="login_page")
def config_page(request: Request):
user = request.user.object
user: KhojUser = request.user.object
user_picture = request.session.get("user", {}).get("picture")
enabled_content = set(EntryAdapters.get_unique_file_types(user).all())
user_subscription = adapters.get_user_subscription(user.email)
user_subscription_state = get_user_subscription_state(user_subscription)
subscription_renewal_date = (
user_subscription.renewal_date.strftime("%d %b %Y")
if user_subscription and user_subscription.renewal_date
else None
)
enabled_content_source = set(EntryAdapters.get_unique_file_source(user).all())
successfully_configured = {
"pdf": ("pdf" in enabled_content),
"markdown": ("markdown" in enabled_content),
"org": ("org" in enabled_content),
"image": False,
"github": ("github" in enabled_content),
"notion": ("notion" in enabled_content),
"plaintext": ("plaintext" in enabled_content),
"computer": ("computer" in enabled_content_source),
"github": ("github" in enabled_content_source),
"notion": ("notion" in enabled_content_source),
}
if state.content_index:
successfully_configured.update(
{
"image": state.content_index.image is not None,
}
)
conversation_options = ConversationAdapters.get_conversation_processor_options().all()
all_conversation_options = list()
for conversation_option in conversation_options:
@ -157,15 +146,19 @@ def config_page(request: Request):
"request": request,
"current_model_state": successfully_configured,
"anonymous_mode": state.anonymous_mode,
"username": user.username if user else None,
"username": user.username,
"conversation_options": all_conversation_options,
"selected_conversation_config": selected_conversation_config.id if selected_conversation_config else None,
"user_photo": user_picture,
"billing_enabled": state.billing_enabled,
"subscription_state": user_subscription_state,
"subscription_renewal_date": subscription_renewal_date,
"khoj_cloud_subscription_url": os.getenv("KHOJ_CLOUD_SUBSCRIPTION_URL"),
},
)
@web_client.get("/config/content_type/github", response_class=HTMLResponse)
@web_client.get("/config/content-source/github", response_class=HTMLResponse)
@requires(["authenticated"], redirect="login_page")
def github_config_page(request: Request):
user = request.user.object
@ -192,7 +185,7 @@ def github_config_page(request: Request):
current_config = {} # type: ignore
return templates.TemplateResponse(
"content_type_github_input.html",
"content_source_github_input.html",
context={
"request": request,
"current_config": current_config,
@ -202,7 +195,7 @@ def github_config_page(request: Request):
)
@web_client.get("/config/content_type/notion", response_class=HTMLResponse)
@web_client.get("/config/content-source/notion", response_class=HTMLResponse)
@requires(["authenticated"], redirect="login_page")
def notion_config_page(request: Request):
user = request.user.object
@ -216,7 +209,7 @@ def notion_config_page(request: Request):
current_config = json.loads(current_config.json())
return templates.TemplateResponse(
"content_type_notion_input.html",
"content_source_notion_input.html",
context={
"request": request,
"current_config": current_config,
@ -226,32 +219,16 @@ def notion_config_page(request: Request):
)
@web_client.get("/config/content_type/{content_type}", response_class=HTMLResponse)
@web_client.get("/config/content-source/computer", response_class=HTMLResponse)
@requires(["authenticated"], redirect="login_page")
def content_config_page(request: Request, content_type: str):
if content_type not in VALID_TEXT_CONTENT_TYPES:
return templates.TemplateResponse("config.html", context={"request": request})
object = map_config_to_object(content_type)
def computer_config_page(request: Request):
user = request.user.object
user_picture = request.session.get("user", {}).get("picture")
config = object.objects.filter(user=user).first()
if config == None:
config = object.objects.create(user=user)
current_config = TextContentConfig(
input_files=config.input_files,
input_filter=config.input_filter,
index_heading_entries=config.index_heading_entries,
)
current_config = json.loads(current_config.json())
return templates.TemplateResponse(
"content_type_input.html",
"content_source_computer_input.html",
context={
"request": request,
"current_config": current_config,
"content_type": content_type,
"username": user.username,
"user_photo": user_picture,
},

View file

@ -204,11 +204,12 @@ def setup(
files=files, full_corpus=full_corpus, user=user, regenerate=regenerate
)
file_names = [file_name for file_name in files]
if files:
file_names = [file_name for file_name in files]
logger.info(
f"Deleted {num_deleted_embeddings} entries. Created {num_new_embeddings} new entries for user {user} from files {file_names}"
)
logger.info(
f"Deleted {num_deleted_embeddings} entries. Created {num_new_embeddings} new entries for user {user} from files {file_names}"
)
def cross_encoder_score(query: str, hits: List[SearchResponse]) -> List[SearchResponse]:

View file

@ -1,4 +1,5 @@
# Standard Packages
import os
import threading
from typing import List, Dict
from collections import defaultdict
@ -35,3 +36,8 @@ khoj_version: str = None
device = get_device()
chat_on_gpu: bool = True
anonymous_mode: bool = False
# Billing is enabled only when every Stripe and Khoj Cloud subscription setting is present in the environment
billing_enabled: bool = all(
    os.getenv(setting) is not None
    for setting in ("STRIPE_API_KEY", "STRIPE_SIGNING_SECRET", "KHOJ_CLOUD_SUBSCRIPTION_URL")
)

View file

@ -196,7 +196,7 @@ def chat_client(search_config: SearchConfig, default_user2: KhojUser):
# Index Markdown Content for Search
all_files = fs_syncer.collect_files(user=default_user2)
state.content_index = configure_content(
state.content_index, _ = configure_content(
state.content_index, state.config.content_type, all_files, state.search_models, user=default_user2
)

View file

@ -64,6 +64,7 @@ def test_encode_docs_memory_leak():
batch_size = 20
embeddings_model = EmbeddingsModel()
memory_usage_trend = []
device = f"{helpers.get_device()}".upper()
# Act
# Encode random strings repeatedly and record memory usage trend
@ -76,8 +77,9 @@ def test_encode_docs_memory_leak():
# Calculate slope of line fitting memory usage history
memory_usage_trend = np.array(memory_usage_trend)
slope, _, _, _, _ = linregress(np.arange(len(memory_usage_trend)), memory_usage_trend)
print(f"Memory usage increased at ~{slope:.2f} MB per iteration on {device}")
# Assert
# If slope is positive memory utilization is increasing
# Positive threshold of 2, from observing memory usage trend on MPS vs CPU device
assert slope < 2, f"Memory usage increasing at ~{slope:.2f} MB per iteration"
assert slope < 2, f"Memory leak suspected on {device}. Memory usage increased at ~{slope:.2f} MB per iteration"

View file

@ -58,7 +58,7 @@ def test_get_org_files_with_org_suffixed_dir_doesnt_raise_error(tmp_path, defaul
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
def test_text_search_setup_with_empty_file_raises_error(
def test_text_search_setup_with_empty_file_creates_no_entries(
org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog
):
# Arrange