mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-28 01:45:07 +01:00
Create Billing integration. Improve Settings pages on Desktop, Web apps (#537)
### Major - Expose Billing via Stripe on Khoj Web app for Khoj Cloud subscription - Expose card on web app config page to manage subscription to Khoj cloud - Create API webhook, endpoints for subscription payments using Stripe - Put Computer files to index into Card under Content section - Show file type icons for each indexed file in config card of web app - Enable deleting all indexed desktop files from Khoj via Desktop app - Create config page on web app to manage computer files indexed by Khoj - Track data source (computer, github, notion) of each entry - Update content by source via API. Make web client use this API for config - Store the data source of each entry in database ### Cleanup - Set content enabled status on update via config buttons on web app - Delete deprecated content config pages for local files from web client - Rename Sync button, Force Sync toggle to Save, Save All buttons ### Fixes - Prevent Desktop app triggering multiple simultaneous syncs to server - Upgrade langchain version since adding support for OCR-ing PDFs - Bubble up content indexing errors to notify user on client apps
This commit is contained in:
commit
1d3bdf8fdb
40 changed files with 758 additions and 414 deletions
|
@ -9,6 +9,6 @@ The Github integration allows you to index as many repositories as you want. It'
|
|||
## Use the Github plugin
|
||||
|
||||
1. Generate a [classic PAT (personal access token)](https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens) from [Github](https://github.com/settings/tokens) with `repo` and `admin:org` scopes at least.
|
||||
2. Navigate to [http://localhost:42110/config/content_type/github](http://localhost:42110/config/content_type/github) to configure your Github settings. Enter in your PAT, along with details for each repository you want to index.
|
||||
2. Navigate to [http://localhost:42110/config/content-source/github](http://localhost:42110/config/content-source/github) to configure your Github settings. Enter in your PAT, along with details for each repository you want to index.
|
||||
3. Click `Save`. Go back to the settings page and click `Configure`.
|
||||
4. Go to [http://localhost:42110/](http://localhost:42110/) and start searching!
|
||||
|
|
|
@ -8,7 +8,7 @@ We haven't setup a fancy integration with OAuth yet, so this integration still r
|
|||
![setup_new_integration](https://github.com/khoj-ai/khoj/assets/65192171/b056e057-d4dc-47dc-aad3-57b59a22c68b)
|
||||
3. Share all the workspaces that you want to integrate with the Khoj integration you just made in the previous step
|
||||
![enable_workspace](https://github.com/khoj-ai/khoj/assets/65192171/98290303-b5b8-4cb0-b32c-f68c6923a3d0)
|
||||
4. In the first step, you generated an API key. Use the newly generated API Key in your Khoj settings, by default at http://localhost:42110/config/content_type/notion. Click `Save`.
|
||||
4. In the first step, you generated an API key. Use the newly generated API Key in your Khoj settings, by default at http://localhost:42110/config/content-source/notion. Click `Save`.
|
||||
5. Click `Configure` in http://localhost:42110/config to index your Notion workspace(s).
|
||||
|
||||
That's it! You should be ready to start searching and chatting. Make sure you've configured your OpenAI API Key for chat.
|
||||
|
|
|
@ -55,7 +55,7 @@ dependencies = [
|
|||
"torch == 2.0.1",
|
||||
"uvicorn == 0.17.6",
|
||||
"aiohttp == 3.8.5",
|
||||
"langchain >= 0.0.187",
|
||||
"langchain >= 0.0.331",
|
||||
"requests >= 2.26.0",
|
||||
"bs4 >= 0.0.1",
|
||||
"anyio == 3.7.1",
|
||||
|
@ -73,7 +73,8 @@ dependencies = [
|
|||
"gunicorn == 21.2.0",
|
||||
"lxml == 4.9.3",
|
||||
"tzdata == 2023.3",
|
||||
"rapidocr-onnxruntime == 1.3.8"
|
||||
"rapidocr-onnxruntime == 1.3.8",
|
||||
"stripe == 7.3.0",
|
||||
]
|
||||
dynamic = ["version"]
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ BASE_DIR = Path(__file__).resolve().parent.parent.parent
|
|||
# See https://docs.djangoproject.com/en/4.2/howto/deployment/checklist/
|
||||
|
||||
# SECURITY WARNING: keep the secret key used in production secret!
|
||||
SECRET_KEY = os.getenv("DJANGO_SECRET_KEY")
|
||||
SECRET_KEY = os.getenv("KHOJ_DJANGO_SECRET_KEY")
|
||||
|
||||
# SECURITY WARNING: don't run with debug turned on in production!
|
||||
DEBUG = os.getenv("DJANGO_DEBUG", "False") == "True"
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
from typing import Type, TypeVar, List
|
||||
from datetime import date
|
||||
from typing import Optional, Type, TypeVar, List
|
||||
from datetime import date, datetime, timedelta
|
||||
import secrets
|
||||
from typing import Type, TypeVar, List
|
||||
from datetime import date
|
||||
from datetime import date, timezone
|
||||
|
||||
from django.db import models
|
||||
from django.contrib.sessions.backends.db import SessionStore
|
||||
|
@ -30,6 +30,7 @@ from database.models import (
|
|||
GithubRepoConfig,
|
||||
Conversation,
|
||||
ChatModelOptions,
|
||||
Subscription,
|
||||
UserConversationConfig,
|
||||
OpenAIProcessorConversationConfig,
|
||||
OfflineChatProcessorConversationConfig,
|
||||
|
@ -103,6 +104,57 @@ async def create_google_user(token: dict) -> KhojUser:
|
|||
return user
|
||||
|
||||
|
||||
def get_user_subscription(email: str) -> Optional[Subscription]:
|
||||
return Subscription.objects.filter(user__email=email).first()
|
||||
|
||||
|
||||
async def set_user_subscription(
|
||||
email: str, is_recurring=None, renewal_date=None, type="standard"
|
||||
) -> Optional[Subscription]:
|
||||
user_subscription = await Subscription.objects.filter(user__email=email).afirst()
|
||||
if not user_subscription:
|
||||
user = await get_user_by_email(email)
|
||||
if not user:
|
||||
return None
|
||||
user_subscription = await Subscription.objects.acreate(
|
||||
user=user, type=type, is_recurring=is_recurring, renewal_date=renewal_date
|
||||
)
|
||||
return user_subscription
|
||||
elif user_subscription:
|
||||
user_subscription.type = type
|
||||
if is_recurring is not None:
|
||||
user_subscription.is_recurring = is_recurring
|
||||
if renewal_date is False:
|
||||
user_subscription.renewal_date = None
|
||||
elif renewal_date is not None:
|
||||
user_subscription.renewal_date = renewal_date
|
||||
await user_subscription.asave()
|
||||
return user_subscription
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def get_user_subscription_state(user_subscription: Subscription) -> str:
|
||||
"""Get subscription state of user
|
||||
Valid state transitions: trial -> subscribed <-> unsubscribed OR expired
|
||||
"""
|
||||
if not user_subscription:
|
||||
return "trial"
|
||||
elif user_subscription.type == Subscription.Type.TRIAL:
|
||||
return "trial"
|
||||
elif user_subscription.is_recurring and user_subscription.renewal_date >= datetime.now(tz=timezone.utc):
|
||||
return "subscribed"
|
||||
elif not user_subscription.is_recurring and user_subscription.renewal_date >= datetime.now(tz=timezone.utc):
|
||||
return "unsubscribed"
|
||||
elif not user_subscription.is_recurring and user_subscription.renewal_date < datetime.now(tz=timezone.utc):
|
||||
return "expired"
|
||||
return "invalid"
|
||||
|
||||
|
||||
async def get_user_by_email(email: str) -> KhojUser:
|
||||
return await KhojUser.objects.filter(email=email).afirst()
|
||||
|
||||
|
||||
async def get_user_by_token(token: dict) -> KhojUser:
|
||||
google_user = await GoogleUser.objects.filter(sub=token.get("sub")).select_related("user").afirst()
|
||||
if not google_user:
|
||||
|
@ -287,13 +339,21 @@ class EntryAdapters:
|
|||
return deleted_count
|
||||
|
||||
@staticmethod
|
||||
def delete_all_entries(user: KhojUser, file_type: str = None):
|
||||
def delete_all_entries_by_type(user: KhojUser, file_type: str = None):
|
||||
if file_type is None:
|
||||
deleted_count, _ = Entry.objects.filter(user=user).delete()
|
||||
else:
|
||||
deleted_count, _ = Entry.objects.filter(user=user, file_type=file_type).delete()
|
||||
return deleted_count
|
||||
|
||||
@staticmethod
|
||||
def delete_all_entries(user: KhojUser, file_source: str = None):
|
||||
if file_source is None:
|
||||
deleted_count, _ = Entry.objects.filter(user=user).delete()
|
||||
else:
|
||||
deleted_count, _ = Entry.objects.filter(user=user, file_source=file_source).delete()
|
||||
return deleted_count
|
||||
|
||||
@staticmethod
|
||||
def get_existing_entry_hashes_by_file(user: KhojUser, file_path: str):
|
||||
return Entry.objects.filter(user=user, file_path=file_path).values_list("hashed_value", flat=True)
|
||||
|
@ -318,8 +378,12 @@ class EntryAdapters:
|
|||
return await Entry.objects.filter(user=user, file_path=file_path).adelete()
|
||||
|
||||
@staticmethod
|
||||
def aget_all_filenames(user: KhojUser):
|
||||
return Entry.objects.filter(user=user).distinct("file_path").values_list("file_path", flat=True)
|
||||
def aget_all_filenames_by_source(user: KhojUser, file_source: str):
|
||||
return (
|
||||
Entry.objects.filter(user=user, file_source=file_source)
|
||||
.distinct("file_path")
|
||||
.values_list("file_path", flat=True)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def adelete_all_entries(user: KhojUser):
|
||||
|
@ -384,3 +448,7 @@ class EntryAdapters:
|
|||
@staticmethod
|
||||
def get_unique_file_types(user: KhojUser):
|
||||
return Entry.objects.filter(user=user).values_list("file_type", flat=True).distinct()
|
||||
|
||||
@staticmethod
|
||||
def get_unique_file_source(user: KhojUser):
|
||||
return Entry.objects.filter(user=user).values_list("file_source", flat=True).distinct()
|
||||
|
|
21
src/database/migrations/0012_entry_file_source.py
Normal file
21
src/database/migrations/0012_entry_file_source.py
Normal file
|
@ -0,0 +1,21 @@
|
|||
# Generated by Django 4.2.5 on 2023-11-07 07:24
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0011_merge_20231102_0138"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name="entry",
|
||||
name="file_source",
|
||||
field=models.CharField(
|
||||
choices=[("computer", "Computer"), ("notion", "Notion"), ("github", "Github")],
|
||||
default="computer",
|
||||
max_length=30,
|
||||
),
|
||||
),
|
||||
]
|
37
src/database/migrations/0013_subscription.py
Normal file
37
src/database/migrations/0013_subscription.py
Normal file
|
@ -0,0 +1,37 @@
|
|||
# Generated by Django 4.2.5 on 2023-11-09 01:27
|
||||
|
||||
from django.conf import settings
|
||||
from django.db import migrations, models
|
||||
import django.db.models.deletion
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0012_entry_file_source"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name="Subscription",
|
||||
fields=[
|
||||
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||
("updated_at", models.DateTimeField(auto_now=True)),
|
||||
(
|
||||
"type",
|
||||
models.CharField(
|
||||
choices=[("trial", "Trial"), ("standard", "Standard")], default="trial", max_length=20
|
||||
),
|
||||
),
|
||||
("is_recurring", models.BooleanField(default=False)),
|
||||
("renewal_date", models.DateTimeField(default=None, null=True)),
|
||||
(
|
||||
"user",
|
||||
models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL),
|
||||
),
|
||||
],
|
||||
options={
|
||||
"abstract": False,
|
||||
},
|
||||
),
|
||||
]
|
|
@ -46,6 +46,17 @@ class KhojApiUser(models.Model):
|
|||
accessed_at = models.DateTimeField(null=True, default=None)
|
||||
|
||||
|
||||
class Subscription(BaseModel):
|
||||
class Type(models.TextChoices):
|
||||
TRIAL = "trial"
|
||||
STANDARD = "standard"
|
||||
|
||||
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
|
||||
type = models.CharField(max_length=20, choices=Type.choices, default=Type.TRIAL)
|
||||
is_recurring = models.BooleanField(default=False)
|
||||
renewal_date = models.DateTimeField(null=True, default=None)
|
||||
|
||||
|
||||
class NotionConfig(BaseModel):
|
||||
token = models.CharField(max_length=200)
|
||||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||
|
@ -131,11 +142,17 @@ class Entry(BaseModel):
|
|||
GITHUB = "github"
|
||||
CONVERSATION = "conversation"
|
||||
|
||||
class EntrySource(models.TextChoices):
|
||||
COMPUTER = "computer"
|
||||
NOTION = "notion"
|
||||
GITHUB = "github"
|
||||
|
||||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
||||
embeddings = VectorField(dimensions=384)
|
||||
raw = models.TextField()
|
||||
compiled = models.TextField()
|
||||
heading = models.CharField(max_length=1000, default=None, null=True, blank=True)
|
||||
file_source = models.CharField(max_length=30, choices=EntrySource.choices, default=EntrySource.COMPUTER)
|
||||
file_type = models.CharField(max_length=30, choices=EntryType.choices, default=EntryType.PLAINTEXT)
|
||||
file_path = models.CharField(max_length=400, default=None, null=True, blank=True)
|
||||
file_name = models.CharField(max_length=400, default=None, null=True, blank=True)
|
||||
|
|
|
@ -192,9 +192,9 @@
|
|||
.then(response => {
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
let references = null;
|
||||
|
||||
function readStream() {
|
||||
let references = null;
|
||||
reader.read().then(({ done, value }) => {
|
||||
if (done) {
|
||||
// Evaluate the contents of new_response_text.innerHTML after all the data has been streamed
|
||||
|
|
|
@ -91,11 +91,13 @@
|
|||
</div>
|
||||
<div class="section-action-row">
|
||||
<div class="card-description-row">
|
||||
<button id="sync-data">Sync</button>
|
||||
<button id="sync-data" class="sync-data">💾 Save</button>
|
||||
</div>
|
||||
<div class="card-description-row sync-force-toggle">
|
||||
<input id="sync-force" type="checkbox" name="sync-force" value="force">
|
||||
<label for="sync-force">Force Sync</label>
|
||||
<div class="card-description-row">
|
||||
<button id="sync-force" class="sync-data">💾 Save All</button>
|
||||
</div>
|
||||
<div class="card-description-row">
|
||||
<button id="delete-all" class="sync-data">🗑️ Delete All</button>
|
||||
</div>
|
||||
</div>
|
||||
<div id="loading-bar" style="display: none;"></div>
|
||||
|
@ -336,7 +338,7 @@
|
|||
padding: 4px;
|
||||
cursor: pointer;
|
||||
}
|
||||
#sync-data {
|
||||
button.sync-data {
|
||||
background-color: var(--primary);
|
||||
border: none;
|
||||
color: var(--main-text-color);
|
||||
|
@ -351,7 +353,7 @@
|
|||
box-shadow: 0px 5px 0px var(--background-color);
|
||||
}
|
||||
|
||||
#sync-data:hover {
|
||||
button.sync-data:hover {
|
||||
background-color: var(--primary-hover);
|
||||
box-shadow: 0px 3px 0px var(--background-color);
|
||||
}
|
||||
|
|
|
@ -67,6 +67,7 @@ const schema = {
|
|||
}
|
||||
};
|
||||
|
||||
const syncing = false;
|
||||
var state = {}
|
||||
const store = new Store({ schema });
|
||||
|
||||
|
@ -110,6 +111,15 @@ function filenameToMimeType (filename) {
|
|||
}
|
||||
|
||||
function pushDataToKhoj (regenerate = false) {
|
||||
// Don't sync if token or hostURL is not set or if already syncing
|
||||
if (store.get('khojToken') === '' || store.get('hostURL') === '' || syncing === true) {
|
||||
const win = BrowserWindow.getAllWindows()[0];
|
||||
if (win) win.webContents.send('update-state', state);
|
||||
return;
|
||||
} else {
|
||||
syncing = true;
|
||||
}
|
||||
|
||||
let filesToPush = [];
|
||||
const files = store.get('files') || [];
|
||||
const folders = store.get('folders') || [];
|
||||
|
@ -192,11 +202,13 @@ function pushDataToKhoj (regenerate = false) {
|
|||
})
|
||||
.finally(() => {
|
||||
// Syncing complete
|
||||
syncing = false;
|
||||
const win = BrowserWindow.getAllWindows()[0];
|
||||
if (win) win.webContents.send('update-state', state);
|
||||
});
|
||||
} else {
|
||||
// Syncing complete
|
||||
syncing = false;
|
||||
const win = BrowserWindow.getAllWindows()[0];
|
||||
if (win) win.webContents.send('update-state', state);
|
||||
}
|
||||
|
@ -306,6 +318,19 @@ async function syncData (regenerate = false) {
|
|||
}
|
||||
}
|
||||
|
||||
async function deleteAllFiles () {
|
||||
try {
|
||||
store.set('files', []);
|
||||
store.set('folders', []);
|
||||
pushDataToKhoj(true);
|
||||
const date = new Date();
|
||||
console.log('Pushing data to Khoj at: ', date);
|
||||
} catch (err) {
|
||||
console.error(err);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
let firstRun = true;
|
||||
let win = null;
|
||||
const createWindow = (tab = 'chat.html') => {
|
||||
|
@ -386,6 +411,7 @@ app.whenReady().then(() => {
|
|||
ipcMain.handle('syncData', (event, regenerate) => {
|
||||
syncData(regenerate);
|
||||
});
|
||||
ipcMain.handle('deleteAllFiles', deleteAllFiles);
|
||||
|
||||
createWindow()
|
||||
|
||||
|
|
|
@ -45,7 +45,8 @@ contextBridge.exposeInMainWorld('hostURLAPI', {
|
|||
})
|
||||
|
||||
contextBridge.exposeInMainWorld('syncDataAPI', {
|
||||
syncData: (regenerate) => ipcRenderer.invoke('syncData', regenerate)
|
||||
syncData: (regenerate) => ipcRenderer.invoke('syncData', regenerate),
|
||||
deleteAllFiles: () => ipcRenderer.invoke('deleteAllFiles')
|
||||
})
|
||||
|
||||
contextBridge.exposeInMainWorld('tokenAPI', {
|
||||
|
|
|
@ -196,9 +196,19 @@ khojKeyInput.addEventListener('blur', async () => {
|
|||
});
|
||||
|
||||
const syncButton = document.getElementById('sync-data');
|
||||
const syncForceToggle = document.getElementById('sync-force');
|
||||
syncButton.addEventListener('click', async () => {
|
||||
loadingBar.style.display = 'block';
|
||||
const regenerate = syncForceToggle.checked;
|
||||
await window.syncDataAPI.syncData(regenerate);
|
||||
await window.syncDataAPI.syncData(false);
|
||||
});
|
||||
|
||||
const syncForceButton = document.getElementById('sync-force');
|
||||
syncForceButton.addEventListener('click', async () => {
|
||||
loadingBar.style.display = 'block';
|
||||
await window.syncDataAPI.syncData(true);
|
||||
});
|
||||
|
||||
const deleteAllButton = document.getElementById('delete-all');
|
||||
deleteAllButton.addEventListener('click', async () => {
|
||||
loadingBar.style.display = 'block';
|
||||
await window.syncDataAPI.deleteAllFiles();
|
||||
});
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
# Standard Packages
|
||||
import sys
|
||||
import logging
|
||||
import json
|
||||
from enum import Enum
|
||||
|
@ -109,7 +108,6 @@ def configure_server(
|
|||
state.search_models = configure_search(state.search_models, state.config.search_type)
|
||||
initialize_content(regenerate, search_type, init, user)
|
||||
except Exception as e:
|
||||
logger.error(f"🚨 Failed to configure search models", exc_info=True)
|
||||
raise e
|
||||
finally:
|
||||
state.config_lock.release()
|
||||
|
@ -125,7 +123,7 @@ def initialize_content(regenerate: bool, search_type: Optional[SearchType] = Non
|
|||
else:
|
||||
logger.info("📬 Updating content index...")
|
||||
all_files = collect_files(user=user)
|
||||
state.content_index = configure_content(
|
||||
state.content_index, status = configure_content(
|
||||
state.content_index,
|
||||
state.config.content_type,
|
||||
all_files,
|
||||
|
@ -134,8 +132,9 @@ def initialize_content(regenerate: bool, search_type: Optional[SearchType] = Non
|
|||
search_type,
|
||||
user=user,
|
||||
)
|
||||
if not status:
|
||||
raise RuntimeError("Failed to update content index")
|
||||
except Exception as e:
|
||||
logger.error(f"🚨 Failed to index content", exc_info=True)
|
||||
raise e
|
||||
|
||||
|
||||
|
@ -146,10 +145,14 @@ def configure_routes(app):
|
|||
from khoj.routers.web_client import web_client
|
||||
from khoj.routers.indexer import indexer
|
||||
from khoj.routers.auth import auth_router
|
||||
from khoj.routers.subscription import subscription_router
|
||||
|
||||
app.include_router(api, prefix="/api")
|
||||
app.include_router(api_beta, prefix="/api/beta")
|
||||
app.include_router(indexer, prefix="/api/v1/index")
|
||||
if state.billing_enabled:
|
||||
logger.info("💳 Enabled Billing")
|
||||
app.include_router(subscription_router, prefix="/api/subscription")
|
||||
app.include_router(web_client)
|
||||
app.include_router(auth_router, prefix="/auth")
|
||||
|
||||
|
@ -165,13 +168,15 @@ def update_search_index():
|
|||
logger.info("📬 Updating content index via Scheduler")
|
||||
for user in get_all_users():
|
||||
all_files = collect_files(user=user)
|
||||
state.content_index = configure_content(
|
||||
state.content_index, success = configure_content(
|
||||
state.content_index, state.config.content_type, all_files, state.search_models, user=user
|
||||
)
|
||||
all_files = collect_files(user=None)
|
||||
state.content_index = configure_content(
|
||||
state.content_index, success = configure_content(
|
||||
state.content_index, state.config.content_type, all_files, state.search_models, user=None
|
||||
)
|
||||
if not success:
|
||||
raise RuntimeError("Failed to update content index")
|
||||
logger.info("📪 Content index updated via Scheduler")
|
||||
except Exception as e:
|
||||
logger.error(f"🚨 Error updating content index via Scheduler: {e}", exc_info=True)
|
||||
|
|
BIN
src/khoj/interface/web/assets/icons/computer.png
Normal file
BIN
src/khoj/interface/web/assets/icons/computer.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 10 KiB |
BIN
src/khoj/interface/web/assets/icons/credit-card.png
Normal file
BIN
src/khoj/interface/web/assets/icons/credit-card.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 19 KiB |
|
@ -209,23 +209,27 @@
|
|||
border: none;
|
||||
color: var(--flower);
|
||||
padding: 4px;
|
||||
width: 32px;
|
||||
margin-bottom: 0px
|
||||
}
|
||||
|
||||
div.file-element {
|
||||
display: grid;
|
||||
grid-template-columns: 1fr auto;
|
||||
grid-template-columns: 1fr 5fr 1fr;
|
||||
border: 1px solid rgb(229, 229, 229);
|
||||
border-radius: 4px;
|
||||
box-shadow: 0px 1px 3px 0px rgba(0,0,0,0.1),0px 1px 2px -1px rgba(0,0,0,0.8);
|
||||
padding: 4px;
|
||||
padding: 4px 0;
|
||||
margin-bottom: 8px;
|
||||
justify-items: center;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
div.remove-button-container {
|
||||
text-align: right;
|
||||
}
|
||||
|
||||
button.card-button.happy {
|
||||
.card-button.happy {
|
||||
color: var(--leaf);
|
||||
}
|
||||
|
||||
|
|
|
@ -187,9 +187,9 @@
|
|||
.then(response => {
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
let references = null;
|
||||
|
||||
function readStream() {
|
||||
let references = null;
|
||||
reader.read().then(({ done, value }) => {
|
||||
if (done) {
|
||||
// Evaluate the contents of new_response_text.innerHTML after all the data has been streamed
|
||||
|
|
|
@ -3,23 +3,57 @@
|
|||
|
||||
<div class="page">
|
||||
<div class="section">
|
||||
<h2 class="section-title">Plugins</h2>
|
||||
<h2 class="section-title">Content</h2>
|
||||
<div class="section-cards">
|
||||
<div class="card">
|
||||
<div class="card-title-row">
|
||||
<img class="card-icon" src="/static/assets/icons/computer.png" alt="Computer">
|
||||
<h3 id="card-title-computer" class="card-title">
|
||||
Files
|
||||
<img id="configured-icon-computer"
|
||||
style="display: {% if not current_model_state.computer %}none{% endif %}"
|
||||
class="configured-icon"
|
||||
src="/static/assets/icons/confirm-icon.svg"
|
||||
alt="Configured">
|
||||
</h3>
|
||||
</div>
|
||||
<div class="card-description-row">
|
||||
<p class="card-description">Manage files from your computer</p>
|
||||
</div>
|
||||
<div class="card-action-row">
|
||||
<a class="card-button" href="/config/content-source/computer">
|
||||
{% if current_model_state.computer %}
|
||||
Update
|
||||
{% else %}
|
||||
Setup
|
||||
{% endif %}
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
|
||||
</a>
|
||||
</div>
|
||||
<div id="clear-computer" class="card-action-row"
|
||||
style="display: {% if not current_model_state.computer %}none{% endif %}">
|
||||
<button class="card-button" onclick="clearContentType('computer')">
|
||||
Disable
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
<div class="card">
|
||||
<div class="card-title-row">
|
||||
<img class="card-icon" src="/static/assets/icons/github.svg" alt="Github">
|
||||
<h3 class="card-title">
|
||||
Github
|
||||
{% if current_model_state.github == True %}
|
||||
<img id="configured-icon-github" class="configured-icon" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
|
||||
{% endif %}
|
||||
<img id="configured-icon-github"
|
||||
class="configured-icon"
|
||||
src="/static/assets/icons/confirm-icon.svg"
|
||||
alt="Configured"
|
||||
style="display: {% if not current_model_state.github %}none{% endif %}">
|
||||
</h3>
|
||||
</div>
|
||||
<div class="card-description-row">
|
||||
<p class="card-description">Set repositories to index</p>
|
||||
</div>
|
||||
<div class="card-action-row">
|
||||
<a class="card-button" href="/config/content_type/github">
|
||||
<a class="card-button" href="/config/content-source/github">
|
||||
{% if current_model_state.github %}
|
||||
Update
|
||||
{% else %}
|
||||
|
@ -28,30 +62,32 @@
|
|||
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
|
||||
</a>
|
||||
</div>
|
||||
{% if current_model_state.github %}
|
||||
<div id="clear-github" class="card-action-row">
|
||||
<button class="card-button" onclick="clearContentType('github')">
|
||||
Disable
|
||||
</button>
|
||||
</div>
|
||||
{% endif %}
|
||||
<div id="clear-github"
|
||||
class="card-action-row"
|
||||
style="display: {% if not current_model_state.github %}none{% endif %}">
|
||||
<button class="card-button" onclick="clearContentType('github')">
|
||||
Disable
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
<div class="card">
|
||||
<div class="card-title-row">
|
||||
<img class="card-icon" src="/static/assets/icons/notion.svg" alt="Notion">
|
||||
<h3 class="card-title">
|
||||
Notion
|
||||
{% if current_model_state.notion == True %}
|
||||
<img id="configured-icon-notion" class="configured-icon" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
|
||||
{% endif %}
|
||||
<img id="configured-icon-notion"
|
||||
class="configured-icon"
|
||||
src="/static/assets/icons/confirm-icon.svg"
|
||||
alt="Configured"
|
||||
style="display: {% if not current_model_state.notion %}none{% endif %}">
|
||||
</h3>
|
||||
</div>
|
||||
<div class="card-description-row">
|
||||
<p class="card-description">Configure your settings from Notion</p>
|
||||
<p class="card-description">Sync your Notion pages</p>
|
||||
</div>
|
||||
<div class="card-action-row">
|
||||
<a class="card-button" href="/config/content_type/notion">
|
||||
{% if current_model_state.content %}
|
||||
<a class="card-button" href="/config/content-source/notion">
|
||||
{% if current_model_state.notion %}
|
||||
Update
|
||||
{% else %}
|
||||
Setup
|
||||
|
@ -59,13 +95,13 @@
|
|||
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
|
||||
</a>
|
||||
</div>
|
||||
{% if current_model_state.notion %}
|
||||
<div id="clear-notion" class="card-action-row">
|
||||
<button class="card-button" onclick="clearContentType('notion')">
|
||||
Disable
|
||||
</button>
|
||||
</div>
|
||||
{% endif %}
|
||||
<div id="clear-notion"
|
||||
class="card-action-row"
|
||||
style="display: {% if not current_model_state.notion %}none{% endif %}">
|
||||
<button class="card-button" onclick="clearContentType('notion')">
|
||||
Disable
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
@ -77,7 +113,7 @@
|
|||
<div class="card-title-row">
|
||||
<img class="card-icon" src="/static/assets/icons/chat.svg" alt="Chat">
|
||||
<h3 class="card-title">
|
||||
Chat Model
|
||||
Chat
|
||||
</h3>
|
||||
</div>
|
||||
<div class="card-description-row">
|
||||
|
@ -122,16 +158,69 @@
|
|||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{% if billing_enabled %}
|
||||
<div class="section">
|
||||
<h2 class="section-title">Manage Data</h2>
|
||||
<div class="section-manage-files">
|
||||
<div id="delete-all-files" class="delete-all=files">
|
||||
<button id="delete-all-files" type="submit" title="Delete all indexed files">🗑️ Remove All</button>
|
||||
</div>
|
||||
<div class="indexed-files">
|
||||
<h2 class="section-title">Billing</h2>
|
||||
<div class="section-cards">
|
||||
<div class="card">
|
||||
<div class="card-title-row">
|
||||
<img class="card-icon" src="/static/assets/icons/credit-card.png" alt="Credit Card">
|
||||
<h3 class="card-title">
|
||||
<span>Subscription</span>
|
||||
<img id="configured-icon-subscription"
|
||||
style="display: {% if subscription_state == 'trial' or subscription_state == 'expired' %}none{% endif %}"
|
||||
class="configured-icon"
|
||||
src="/static/assets/icons/confirm-icon.svg"
|
||||
alt="Configured">
|
||||
</h3>
|
||||
</div>
|
||||
<div class="card-description-row">
|
||||
<p id="trial-description"
|
||||
class="card-description"
|
||||
style="display: {% if subscription_state != 'trial' %}none{% endif %}">
|
||||
Subscribe to Khoj Cloud
|
||||
</p>
|
||||
<p id="unsubscribe-description"
|
||||
class="card-description"
|
||||
style="display: {% if subscription_state != 'subscribed' %}none{% endif %}">
|
||||
You are <b>subscribed</b> to Khoj Cloud. Subscription will <b>renew</b> on <b>{{ subscription_renewal_date }}</b>
|
||||
</p>
|
||||
<p id="resubscribe-description"
|
||||
class="card-description"
|
||||
style="display: {% if subscription_state != 'unsubscribed' %}none{% endif %}">
|
||||
You are <b>subscribed</b> to Khoj Cloud. Subscription will <b>expire</b> on <b>{{ subscription_renewal_date }}</b>
|
||||
</p>
|
||||
<p id="expire-description"
|
||||
class="card-description"
|
||||
style="display: {% if subscription_state != 'expired' %}none{% endif %}">
|
||||
Subscribe to Khoj Cloud. Subscription <b>expired</b> on <b>{{ subscription_renewal_date }}</b>
|
||||
</p>
|
||||
</div>
|
||||
<div class="card-action-row">
|
||||
<button id="unsubscribe-button"
|
||||
class="card-button"
|
||||
onclick="unsubscribe()"
|
||||
style="display: {% if subscription_state != 'subscribed' %}none{% endif %};">
|
||||
Unsubscribe
|
||||
</button>
|
||||
<button id="resubscribe-button"
|
||||
class="card-button happy"
|
||||
onclick="resubscribe()"
|
||||
style="display: {% if subscription_state != 'unsubscribed' %}none{% endif %};">
|
||||
Resubscribe
|
||||
</button>
|
||||
<a id="subscribe-button"
|
||||
class="card-button happy"
|
||||
href="{{ khoj_cloud_subscription_url }}?prefilled_email={{ username }}"
|
||||
style="display: {% if subscription_state == 'subscribed' or subscription_state == 'unsubscribed' %}none{% endif %};">
|
||||
Subscribe
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{% endif %}
|
||||
<div class="section general-settings">
|
||||
<div id="results-count" title="Number of items to show in search and use for chat response">
|
||||
<label for="results-count-slider">Results Count: <span id="results-count-value">5</span></label>
|
||||
|
@ -176,8 +265,9 @@
|
|||
})
|
||||
};
|
||||
|
||||
function clearContentType(content_type) {
|
||||
fetch('/api/config/data/content_type/' + content_type, {
|
||||
function clearContentType(content_source) {
|
||||
|
||||
fetch('/api/config/data/content-source/' + content_source, {
|
||||
method: 'DELETE',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
|
@ -186,22 +276,54 @@
|
|||
.then(response => response.json())
|
||||
.then(data => {
|
||||
if (data.status == "ok") {
|
||||
var contentTypeClearButton = document.getElementById("clear-" + content_type);
|
||||
contentTypeClearButton.style.display = "none";
|
||||
|
||||
var configuredIcon = document.getElementById("configured-icon-" + content_type);
|
||||
if (configuredIcon) {
|
||||
configuredIcon.style.display = "none";
|
||||
}
|
||||
|
||||
var misconfiguredIcon = document.getElementById("misconfigured-icon-" + content_type);
|
||||
if (misconfiguredIcon) {
|
||||
misconfiguredIcon.style.display = "none";
|
||||
}
|
||||
document.getElementById("configured-icon-" + content_source).style.display = "none";
|
||||
document.getElementById("clear-" + content_source).style.display = "none";
|
||||
} else {
|
||||
document.getElementById("configured-icon-" + content_source).style.display = "";
|
||||
document.getElementById("clear-" + content_source).style.display = "";
|
||||
}
|
||||
})
|
||||
};
|
||||
|
||||
function unsubscribe() {
|
||||
fetch('/api/subscription?operation=cancel&email={{username}}', {
|
||||
method: 'PATCH',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
})
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
if (data.success) {
|
||||
document.getElementById("unsubscribe-description").style.display = "none";
|
||||
document.getElementById("unsubscribe-button").style.display = "none";
|
||||
|
||||
document.getElementById("resubscribe-description").style.display = "";
|
||||
document.getElementById("resubscribe-button").style.display = "";
|
||||
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
function resubscribe() {
|
||||
fetch('/api/subscription?operation=resubscribe&email={{username}}', {
|
||||
method: 'PATCH',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
})
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
if (data.success) {
|
||||
document.getElementById("resubscribe-description").style.display = "none";
|
||||
document.getElementById("resubscribe-button").style.display = "none";
|
||||
|
||||
document.getElementById("unsubscribe-description").style.display = "";
|
||||
document.getElementById("unsubscribe-button").style.display = "";
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
var configure = document.getElementById("configure");
|
||||
configure.addEventListener("click", function(event) {
|
||||
event.preventDefault();
|
||||
|
@ -243,6 +365,7 @@
|
|||
if (data.detail != null) {
|
||||
throw new Error(data.detail);
|
||||
}
|
||||
|
||||
document.getElementById("status").innerHTML = emoji + " " + successText;
|
||||
document.getElementById("status").style.display = "block";
|
||||
button.disabled = false;
|
||||
|
@ -255,6 +378,26 @@
|
|||
button.disabled = false;
|
||||
button.innerHTML = '⚠️ Unsuccessful';
|
||||
});
|
||||
|
||||
content_sources = ["computer", "github", "notion"];
|
||||
content_sources.forEach(content_source => {
|
||||
fetch(`/api/config/data/${content_source}`, {
|
||||
method: 'GET',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
})
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
if (data.length > 0) {
|
||||
document.getElementById("configured-icon-" + content_source).style.display = "";
|
||||
document.getElementById("clear-" + content_source).style.display = "";
|
||||
} else {
|
||||
document.getElementById("configured-icon-" + content_source).style.display = "none";
|
||||
document.getElementById("clear-" + content_source).style.display = "none";
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// Setup the results count slider
|
||||
|
@ -362,70 +505,5 @@
|
|||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Get all currently indexed files
|
||||
function getAllFilenames() {
|
||||
fetch('/api/config/data/all')
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
var indexedFiles = document.getElementsByClassName("indexed-files")[0];
|
||||
indexedFiles.innerHTML = "";
|
||||
|
||||
if (data.length == 0) {
|
||||
document.getElementById("delete-all-files").style.display = "none";
|
||||
indexedFiles.innerHTML = "<div>Use the <a href='https://download.khoj.dev'>Khoj Desktop client</a> to index files.</div>";
|
||||
} else {
|
||||
document.getElementById("delete-all-files").style.display = "block";
|
||||
}
|
||||
|
||||
for (var filename of data) {
|
||||
let fileElement = document.createElement("div");
|
||||
fileElement.classList.add("file-element");
|
||||
|
||||
let fileNameElement = document.createElement("div");
|
||||
fileNameElement.classList.add("content-name");
|
||||
fileNameElement.innerHTML = filename;
|
||||
fileElement.appendChild(fileNameElement);
|
||||
|
||||
let buttonContainer = document.createElement("div");
|
||||
buttonContainer.classList.add("remove-button-container");
|
||||
let removeFileButton = document.createElement("button");
|
||||
removeFileButton.classList.add("remove-file-button");
|
||||
removeFileButton.innerHTML = "🗑️";
|
||||
removeFileButton.addEventListener("click", ((filename) => {
|
||||
return () => {
|
||||
removeFile(filename);
|
||||
};
|
||||
})(filename));
|
||||
buttonContainer.appendChild(removeFileButton);
|
||||
fileElement.appendChild(buttonContainer);
|
||||
indexedFiles.appendChild(fileElement);
|
||||
}
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error('Error:', error);
|
||||
});
|
||||
}
|
||||
|
||||
// Get all currently indexed files on page load
|
||||
getAllFilenames();
|
||||
|
||||
let deleteAllFilesButton = document.getElementById("delete-all-files");
|
||||
deleteAllFilesButton.addEventListener("click", function(event) {
|
||||
event.preventDefault();
|
||||
fetch('/api/config/data/all', {
|
||||
method: 'DELETE',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
})
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
if (data.status == "ok") {
|
||||
getAllFilenames();
|
||||
}
|
||||
})
|
||||
});
|
||||
|
||||
</script>
|
||||
{% endblock %}
|
||||
|
|
129
src/khoj/interface/web/content_source_computer_input.html
Normal file
129
src/khoj/interface/web/content_source_computer_input.html
Normal file
|
@ -0,0 +1,129 @@
|
|||
{% extends "base_config.html" %}
|
||||
{% block content %}
|
||||
<div class="page">
|
||||
<div class="section">
|
||||
<h2 class="section-title">
|
||||
<img class="card-icon" src="/static/assets/icons/computer.png" alt="files">
|
||||
<span class="card-title-text">Files</span>
|
||||
<div class="instructions">
|
||||
<p class="card-description">Manage files from your computer</p>
|
||||
<p id="desktop-client" class="card-description">Download the <a href="https://download.khoj.dev">Khoj Desktop app</a> to sync files from your computer</p>
|
||||
</div>
|
||||
</h2>
|
||||
<div class="section-manage-files">
|
||||
<div id="delete-all-files" class="delete-all-files">
|
||||
<button id="delete-all-files" type="submit" title="Remove all computer files from Khoj">🗑️ Delete all</button>
|
||||
</div>
|
||||
<div class="indexed-files">
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<style>
|
||||
#desktop-client {
|
||||
font-weight: normal;
|
||||
}
|
||||
.indexed-files {
|
||||
width: 100%;
|
||||
}
|
||||
.content-name {
|
||||
font-size: smaller;
|
||||
}
|
||||
</style>
|
||||
<script>
|
||||
function removeFile(path) {
|
||||
fetch('/api/config/data/file?filename=' + path, {
|
||||
method: 'DELETE',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
})
|
||||
.then(response => response.ok ? response.json() : Promise.reject(response))
|
||||
.then(data => {
|
||||
if (data.status == "ok") {
|
||||
getAllComputerFilenames();
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Get all currently indexed files
|
||||
function getAllComputerFilenames() {
|
||||
fetch('/api/config/data/computer')
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
var indexedFiles = document.getElementsByClassName("indexed-files")[0];
|
||||
indexedFiles.innerHTML = "";
|
||||
|
||||
if (data.length == 0) {
|
||||
document.getElementById("delete-all-files").style.display = "none";
|
||||
indexedFiles.innerHTML = "<div class='card-description'>Use the <a href='https://download.khoj.dev'>Khoj Desktop client</a> to index files.</div>";
|
||||
} else {
|
||||
document.getElementById("delete-all-files").style.display = "block";
|
||||
}
|
||||
|
||||
for (var filename of data) {
|
||||
let fileElement = document.createElement("div");
|
||||
fileElement.classList.add("file-element");
|
||||
|
||||
let fileExtension = filename.split('.').pop();
|
||||
if (fileExtension === "org")
|
||||
image_name = "org.svg"
|
||||
else if (fileExtension === "pdf")
|
||||
image_name = "pdf.svg"
|
||||
else if (fileExtension === "markdown" || fileExtension === "md")
|
||||
image_name = "markdown.svg"
|
||||
else
|
||||
image_name = "plaintext.svg"
|
||||
|
||||
let fileIconElement = document.createElement("img");
|
||||
fileIconElement.classList.add("card-icon");
|
||||
fileIconElement.src = `/static/assets/icons/${image_name}`;
|
||||
fileIconElement.alt = "File";
|
||||
fileElement.appendChild(fileIconElement);
|
||||
|
||||
let fileNameElement = document.createElement("div");
|
||||
fileNameElement.classList.add("content-name");
|
||||
fileNameElement.innerHTML = filename;
|
||||
fileElement.appendChild(fileNameElement);
|
||||
|
||||
let buttonContainer = document.createElement("div");
|
||||
buttonContainer.classList.add("remove-button-container");
|
||||
let removeFileButton = document.createElement("button");
|
||||
removeFileButton.classList.add("remove-file-button");
|
||||
removeFileButton.innerHTML = "🗑️";
|
||||
removeFileButton.addEventListener("click", ((filename) => {
|
||||
return () => {
|
||||
removeFile(filename);
|
||||
};
|
||||
})(filename));
|
||||
buttonContainer.appendChild(removeFileButton);
|
||||
fileElement.appendChild(buttonContainer);
|
||||
indexedFiles.appendChild(fileElement);
|
||||
}
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error('Error:', error);
|
||||
});
|
||||
}
|
||||
|
||||
// Get all currently indexed files on page load
|
||||
getAllComputerFilenames();
|
||||
|
||||
let deleteAllComputerFilesButton = document.getElementById("delete-all-files");
|
||||
deleteAllComputerFilesButton.addEventListener("click", function(event) {
|
||||
event.preventDefault();
|
||||
fetch('/api/config/data/content-source/computer', {
|
||||
method: 'DELETE',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
})
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
if (data.status == "ok") {
|
||||
getAllComputerFilenames();
|
||||
}
|
||||
})
|
||||
});
|
||||
</script>
|
||||
{% endblock %}
|
|
@ -125,7 +125,7 @@
|
|||
}
|
||||
|
||||
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
|
||||
fetch('/api/config/data/content_type/github', {
|
||||
fetch('/api/config/data/content-source/github', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
|
@ -42,7 +42,7 @@
|
|||
}
|
||||
|
||||
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
|
||||
fetch('/api/config/data/content_type/notion', {
|
||||
fetch('/api/config/data/content-source/notion', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
|
@ -1,159 +0,0 @@
|
|||
{% extends "base_config.html" %}
|
||||
{% block content %}
|
||||
<div class="page">
|
||||
<div class="section">
|
||||
<h2 class="section-title">
|
||||
<img class="card-icon" src="/static/assets/icons/{{ content_type }}.svg" alt="{{ content_type|capitalize }}">
|
||||
<span class="card-title-text">{{ content_type|capitalize }}</span>
|
||||
</h2>
|
||||
<form id="config-form">
|
||||
<table>
|
||||
<tr>
|
||||
<td>
|
||||
<label for="input-files" title="Add a {{content_type}} file for Khoj to index">Files</label>
|
||||
</td>
|
||||
<td id="input-files-cell">
|
||||
{% if current_config['input_files'] is none %}
|
||||
<input type="text" id="input-files" name="input-files" placeholder="~\Documents\notes.{{content_type}}">
|
||||
{% else %}
|
||||
{% for input_file in current_config['input_files'] %}
|
||||
<input type="text" id="input-files" name="input-files" value="{{ input_file }}" placeholder="~\Documents\notes.{{content_type}}">
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
</td>
|
||||
<td>
|
||||
<button type="button" id="input-files-button">Add</button>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
<label for="input-filter" title="Add a folder with {{content_type}} files for Khoj to index">Folders</label>
|
||||
</td>
|
||||
<td id="input-filter-cell">
|
||||
{% if current_config['input_filter'] is none %}
|
||||
<input type="text" id="input-filter" name="input-filter" placeholder="~/Documents/{{content_type}}">
|
||||
{% else %}
|
||||
{% for input_filter in current_config['input_filter'] %}
|
||||
<input type="text" id="input-filter" name="input-filter" placeholder="~/Documents/{{content_type}}" value="{{ input_filter }}">
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
</td>
|
||||
<td>
|
||||
<button type="button" id="input-filter-button">Add</button>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
<div class="section">
|
||||
<div id="success" style="display: none;" ></div>
|
||||
<button id="submit" type="submit">Save</button>
|
||||
</div>
|
||||
</form>
|
||||
</div>
|
||||
</div>
|
||||
<script>
|
||||
function addButtonEventListener(fieldName) {
|
||||
var button = document.getElementById(fieldName + "-button");
|
||||
button.addEventListener("click", function(event) {
|
||||
var cell = document.getElementById(fieldName + "-cell");
|
||||
var newInput = document.createElement("input");
|
||||
newInput.setAttribute("type", "text");
|
||||
newInput.setAttribute("name", fieldName);
|
||||
cell.appendChild(newInput);
|
||||
})
|
||||
}
|
||||
|
||||
addButtonEventListener("input-files");
|
||||
addButtonEventListener("input-filter");
|
||||
|
||||
function getValidInputNodes(nodes) {
|
||||
var validNodes = [];
|
||||
for (var i = 0; i < nodes.length; i++) {
|
||||
const nodeValue = nodes[i].value;
|
||||
if (nodeValue === "" || nodeValue === null || nodeValue === undefined || nodeValue === "None") {
|
||||
continue;
|
||||
}
|
||||
validNodes.push(nodes[i]);
|
||||
}
|
||||
return validNodes;
|
||||
}
|
||||
|
||||
submit.addEventListener("click", function(event) {
|
||||
event.preventDefault();
|
||||
let globFormat = "**/*"
|
||||
let suffixes = [];
|
||||
if ('{{content_type}}' == "markdown")
|
||||
suffixes = [".md", ".markdown"]
|
||||
else if ('{{content_type}}' == "org")
|
||||
suffixes = [".org"]
|
||||
else if ('{{content_type}}' === "pdf")
|
||||
suffixes = [".pdf"]
|
||||
else if ('{{content_type}}' === "plaintext")
|
||||
suffixes = ['.*']
|
||||
|
||||
let globs = suffixes.map(x => `${globFormat}${x}`)
|
||||
var inputFileNodes = document.getElementsByName("input-files");
|
||||
var inputFiles = getValidInputNodes(inputFileNodes).map(node => node.value);
|
||||
|
||||
var inputFilterNodes = document.getElementsByName("input-filter");
|
||||
|
||||
var inputFilter = [];
|
||||
var nodes = getValidInputNodes(inputFilterNodes);
|
||||
|
||||
// A regex that checks for globs in the path. If they exist,
|
||||
// we are going to just not add our own globing. If they don't,
|
||||
// then we will assume globbing should be done.
|
||||
const glob_regex = /([*?\[\]])/;
|
||||
if (nodes.length > 0) {
|
||||
for (var i = 0; i < nodes.length; i++) {
|
||||
for (var j = 0; j < globs.length; j++) {
|
||||
if (glob_regex.test(nodes[i].value)) {
|
||||
inputFilter.push(nodes[i].value);
|
||||
} else {
|
||||
inputFilter.push(nodes[i].value + globs[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (inputFiles.length === 0 && inputFilter.length === 0) {
|
||||
alert("You must specify at least one input file or input filter.");
|
||||
return;
|
||||
}
|
||||
|
||||
if (inputFiles.length == 0) {
|
||||
inputFiles = null;
|
||||
}
|
||||
|
||||
if (inputFilter.length == 0) {
|
||||
inputFilter = null;
|
||||
}
|
||||
|
||||
// var index_heading_entries = document.getElementById("index-heading-entries").value;
|
||||
var index_heading_entries = true;
|
||||
|
||||
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
|
||||
fetch('/api/config/data/content_type/{{ content_type }}', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'X-CSRFToken': csrfToken
|
||||
},
|
||||
body: JSON.stringify({
|
||||
"input_files": inputFiles,
|
||||
"input_filter": inputFilter,
|
||||
"index_heading_entries": index_heading_entries
|
||||
})
|
||||
})
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
if (data["status"] == "ok") {
|
||||
document.getElementById("success").innerHTML = "✅ Successfully updated. Go to your <a href='/config'>settings page</a> to complete setup.";
|
||||
document.getElementById("success").style.display = "block";
|
||||
} else {
|
||||
document.getElementById("success").innerHTML = "⚠️ Failed to update settings.";
|
||||
document.getElementById("success").style.display = "block";
|
||||
}
|
||||
})
|
||||
});
|
||||
</script>
|
||||
{% endblock %}
|
|
@ -104,7 +104,12 @@ class GithubToEntries(TextToEntries):
|
|||
# Identify, mark and merge any new entries with previous entries
|
||||
with timer("Identify new or updated entries", logger):
|
||||
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||
current_entries, DbEntry.EntryType.GITHUB, key="compiled", logger=logger, user=user
|
||||
current_entries,
|
||||
DbEntry.EntryType.GITHUB,
|
||||
DbEntry.EntrySource.GITHUB,
|
||||
key="compiled",
|
||||
logger=logger,
|
||||
user=user,
|
||||
)
|
||||
|
||||
return num_new_embeddings, num_deleted_embeddings
|
||||
|
|
|
@ -47,6 +47,7 @@ class MarkdownToEntries(TextToEntries):
|
|||
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||
current_entries,
|
||||
DbEntry.EntryType.MARKDOWN,
|
||||
DbEntry.EntrySource.COMPUTER,
|
||||
"compiled",
|
||||
logger,
|
||||
deletion_file_names,
|
||||
|
|
|
@ -250,7 +250,12 @@ class NotionToEntries(TextToEntries):
|
|||
# Identify, mark and merge any new entries with previous entries
|
||||
with timer("Identify new or updated entries", logger):
|
||||
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||
current_entries, DbEntry.EntryType.NOTION, key="compiled", logger=logger, user=user
|
||||
current_entries,
|
||||
DbEntry.EntryType.NOTION,
|
||||
DbEntry.EntrySource.NOTION,
|
||||
key="compiled",
|
||||
logger=logger,
|
||||
user=user,
|
||||
)
|
||||
|
||||
return num_new_embeddings, num_deleted_embeddings
|
||||
|
|
|
@ -48,6 +48,7 @@ class OrgToEntries(TextToEntries):
|
|||
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||
current_entries,
|
||||
DbEntry.EntryType.ORG,
|
||||
DbEntry.EntrySource.COMPUTER,
|
||||
"compiled",
|
||||
logger,
|
||||
deletion_file_names,
|
||||
|
|
|
@ -46,6 +46,7 @@ class PdfToEntries(TextToEntries):
|
|||
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||
current_entries,
|
||||
DbEntry.EntryType.PDF,
|
||||
DbEntry.EntrySource.COMPUTER,
|
||||
"compiled",
|
||||
logger,
|
||||
deletion_file_names,
|
||||
|
|
|
@ -56,6 +56,7 @@ class PlaintextToEntries(TextToEntries):
|
|||
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||
current_entries,
|
||||
DbEntry.EntryType.PLAINTEXT,
|
||||
DbEntry.EntrySource.COMPUTER,
|
||||
key="compiled",
|
||||
logger=logger,
|
||||
deletion_filenames=deletion_file_names,
|
||||
|
|
|
@ -78,6 +78,7 @@ class TextToEntries(ABC):
|
|||
self,
|
||||
current_entries: List[Entry],
|
||||
file_type: str,
|
||||
file_source: str,
|
||||
key="compiled",
|
||||
logger: logging.Logger = None,
|
||||
deletion_filenames: Set[str] = None,
|
||||
|
@ -93,9 +94,9 @@ class TextToEntries(ABC):
|
|||
|
||||
num_deleted_entries = 0
|
||||
if regenerate:
|
||||
with timer("Prepared dataset for regeneration in", logger):
|
||||
with timer("Cleared existing dataset for regeneration in", logger):
|
||||
logger.debug(f"Deleting all entries for file type {file_type}")
|
||||
num_deleted_entries = EntryAdapters.delete_all_entries(user, file_type)
|
||||
num_deleted_entries = EntryAdapters.delete_all_entries_by_type(user, file_type)
|
||||
|
||||
hashes_to_process = set()
|
||||
with timer("Identified entries to add to database in", logger):
|
||||
|
@ -132,6 +133,7 @@ class TextToEntries(ABC):
|
|||
compiled=entry.compiled,
|
||||
heading=entry.heading[:1000], # Truncate to max chars of field allowed
|
||||
file_path=entry.file,
|
||||
file_source=file_source,
|
||||
file_type=file_type,
|
||||
hashed_value=entry_hash,
|
||||
corpus_id=entry.corpus_id,
|
||||
|
|
|
@ -23,7 +23,6 @@ from khoj.utils.rawconfig import (
|
|||
FullConfig,
|
||||
SearchConfig,
|
||||
SearchResponse,
|
||||
TextContentConfig,
|
||||
GithubContentConfig,
|
||||
NotionContentConfig,
|
||||
)
|
||||
|
@ -51,6 +50,7 @@ from database.models import (
|
|||
LocalPdfConfig,
|
||||
LocalPlaintextConfig,
|
||||
KhojUser,
|
||||
Entry as DbEntry,
|
||||
GithubConfig,
|
||||
NotionConfig,
|
||||
)
|
||||
|
@ -61,11 +61,13 @@ api = APIRouter()
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def map_config_to_object(content_type: str):
|
||||
if content_type == "github":
|
||||
def map_config_to_object(content_source: str):
|
||||
if content_source == DbEntry.EntrySource.GITHUB:
|
||||
return GithubConfig
|
||||
if content_type == "notion":
|
||||
if content_source == DbEntry.EntrySource.GITHUB:
|
||||
return NotionConfig
|
||||
if content_source == DbEntry.EntrySource.COMPUTER:
|
||||
return "Computer"
|
||||
|
||||
|
||||
async def map_config_to_db(config: FullConfig, user: KhojUser):
|
||||
|
@ -164,7 +166,7 @@ async def set_config_data(
|
|||
return state.config
|
||||
|
||||
|
||||
@api.post("/config/data/content_type/github", status_code=200)
|
||||
@api.post("/config/data/content-source/github", status_code=200)
|
||||
@requires(["authenticated"])
|
||||
async def set_content_config_github_data(
|
||||
request: Request,
|
||||
|
@ -192,7 +194,7 @@ async def set_content_config_github_data(
|
|||
return {"status": "ok"}
|
||||
|
||||
|
||||
@api.post("/config/data/content_type/notion", status_code=200)
|
||||
@api.post("/config/data/content-source/notion", status_code=200)
|
||||
@requires(["authenticated"])
|
||||
async def set_content_config_notion_data(
|
||||
request: Request,
|
||||
|
@ -219,11 +221,11 @@ async def set_content_config_notion_data(
|
|||
return {"status": "ok"}
|
||||
|
||||
|
||||
@api.delete("/config/data/content_type/{content_type}", status_code=200)
|
||||
@api.delete("/config/data/content-source/{content_source}", status_code=200)
|
||||
@requires(["authenticated"])
|
||||
async def remove_content_config_data(
|
||||
async def remove_content_source_data(
|
||||
request: Request,
|
||||
content_type: str,
|
||||
content_source: str,
|
||||
client: Optional[str] = None,
|
||||
):
|
||||
user = request.user.object
|
||||
|
@ -233,15 +235,15 @@ async def remove_content_config_data(
|
|||
telemetry_type="api",
|
||||
api="delete_content_config",
|
||||
client=client,
|
||||
metadata={"content_type": content_type},
|
||||
metadata={"content_source": content_source},
|
||||
)
|
||||
|
||||
content_object = map_config_to_object(content_type)
|
||||
content_object = map_config_to_object(content_source)
|
||||
if content_object is None:
|
||||
raise ValueError(f"Invalid content type: {content_type}")
|
||||
|
||||
await content_object.objects.filter(user=user).adelete()
|
||||
await sync_to_async(EntryAdapters.delete_all_entries)(user, content_type)
|
||||
raise ValueError(f"Invalid content source: {content_source}")
|
||||
elif content_object != "Computer":
|
||||
await content_object.objects.filter(user=user).adelete()
|
||||
await sync_to_async(EntryAdapters.delete_all_entries)(user, content_source)
|
||||
|
||||
enabled_content = await sync_to_async(EntryAdapters.get_unique_file_types)(user)
|
||||
return {"status": "ok"}
|
||||
|
@ -268,10 +270,11 @@ async def remove_file_data(
|
|||
return {"status": "ok"}
|
||||
|
||||
|
||||
@api.get("/config/data/all", response_model=List[str])
|
||||
@api.get("/config/data/{content_source}", response_model=List[str])
|
||||
@requires(["authenticated"])
|
||||
async def get_all_filenames(
|
||||
request: Request,
|
||||
content_source: str,
|
||||
client: Optional[str] = None,
|
||||
):
|
||||
user = request.user.object
|
||||
|
@ -283,27 +286,7 @@ async def get_all_filenames(
|
|||
client=client,
|
||||
)
|
||||
|
||||
return await sync_to_async(list)(EntryAdapters.aget_all_filenames(user))
|
||||
|
||||
|
||||
@api.delete("/config/data/all", status_code=200)
|
||||
@requires(["authenticated"])
|
||||
async def remove_all_config_data(
|
||||
request: Request,
|
||||
client: Optional[str] = None,
|
||||
):
|
||||
user = request.user.object
|
||||
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
telemetry_type="api",
|
||||
api="delete_all_config",
|
||||
client=client,
|
||||
)
|
||||
|
||||
await EntryAdapters.adelete_all_entries(user)
|
||||
|
||||
return {"status": "ok"}
|
||||
return await sync_to_async(list)(EntryAdapters.aget_all_filenames_by_source(user, content_source))
|
||||
|
||||
|
||||
@api.post("/config/data/conversation/model", status_code=200)
|
||||
|
|
|
@ -24,7 +24,9 @@ logger = logging.getLogger(__name__)
|
|||
auth_router = APIRouter()
|
||||
|
||||
if not state.anonymous_mode and not (os.environ.get("GOOGLE_CLIENT_ID") and os.environ.get("GOOGLE_CLIENT_SECRET")):
|
||||
logger.info("Please set GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET environment variables to use Google OAuth")
|
||||
logger.warn(
|
||||
"🚨 Use --anonymous-mode flag to disable Google OAuth or set GOOGLE_CLIENT_ID, GOOGLE_CLIENT_SECRET environment variables to enable it"
|
||||
)
|
||||
else:
|
||||
config = Config(environ=os.environ)
|
||||
|
||||
|
|
|
@ -126,7 +126,7 @@ async def update(
|
|||
|
||||
# Extract required fields from config
|
||||
loop = asyncio.get_event_loop()
|
||||
state.content_index = await loop.run_in_executor(
|
||||
state.content_index, success = await loop.run_in_executor(
|
||||
None,
|
||||
configure_content,
|
||||
state.content_index,
|
||||
|
@ -138,6 +138,8 @@ async def update(
|
|||
False,
|
||||
user,
|
||||
)
|
||||
if not success:
|
||||
raise RuntimeError("Failed to update content index")
|
||||
logger.info(f"Finished processing batch indexing request")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process batch indexing request: {e}", exc_info=True)
|
||||
|
@ -145,6 +147,7 @@ async def update(
|
|||
f"🚨 Failed to {force} update {t} content index triggered via API call by {client} client: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
return Response(content="Failed", status_code=500)
|
||||
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
|
@ -182,18 +185,19 @@ def configure_content(
|
|||
t: Optional[state.SearchType] = None,
|
||||
full_corpus: bool = True,
|
||||
user: KhojUser = None,
|
||||
) -> Optional[ContentIndex]:
|
||||
) -> tuple[Optional[ContentIndex], bool]:
|
||||
content_index = ContentIndex()
|
||||
|
||||
success = True
|
||||
if t is not None and not t.value in [type.value for type in state.SearchType]:
|
||||
logger.warning(f"🚨 Invalid search type: {t}")
|
||||
return None
|
||||
return None, False
|
||||
|
||||
search_type = t.value if t else None
|
||||
|
||||
if files is None:
|
||||
logger.warning(f"🚨 No files to process for {search_type} search.")
|
||||
return None
|
||||
return None, True
|
||||
|
||||
try:
|
||||
# Initialize Org Notes Search
|
||||
|
@ -209,6 +213,7 @@ def configure_content(
|
|||
)
|
||||
except Exception as e:
|
||||
logger.error(f"🚨 Failed to setup org: {e}", exc_info=True)
|
||||
success = False
|
||||
|
||||
try:
|
||||
# Initialize Markdown Search
|
||||
|
@ -225,6 +230,7 @@ def configure_content(
|
|||
|
||||
except Exception as e:
|
||||
logger.error(f"🚨 Failed to setup markdown: {e}", exc_info=True)
|
||||
success = False
|
||||
|
||||
try:
|
||||
# Initialize PDF Search
|
||||
|
@ -241,6 +247,7 @@ def configure_content(
|
|||
|
||||
except Exception as e:
|
||||
logger.error(f"🚨 Failed to setup PDF: {e}", exc_info=True)
|
||||
success = False
|
||||
|
||||
try:
|
||||
# Initialize Plaintext Search
|
||||
|
@ -257,6 +264,7 @@ def configure_content(
|
|||
|
||||
except Exception as e:
|
||||
logger.error(f"🚨 Failed to setup plaintext: {e}", exc_info=True)
|
||||
success = False
|
||||
|
||||
try:
|
||||
# Initialize Image Search
|
||||
|
@ -274,6 +282,7 @@ def configure_content(
|
|||
|
||||
except Exception as e:
|
||||
logger.error(f"🚨 Failed to setup images: {e}", exc_info=True)
|
||||
success = False
|
||||
|
||||
try:
|
||||
github_config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first()
|
||||
|
@ -291,6 +300,7 @@ def configure_content(
|
|||
|
||||
except Exception as e:
|
||||
logger.error(f"🚨 Failed to setup GitHub: {e}", exc_info=True)
|
||||
success = False
|
||||
|
||||
try:
|
||||
# Initialize Notion Search
|
||||
|
@ -308,12 +318,13 @@ def configure_content(
|
|||
|
||||
except Exception as e:
|
||||
logger.error(f"🚨 Failed to setup GitHub: {e}", exc_info=True)
|
||||
success = False
|
||||
|
||||
# Invalidate Query Cache
|
||||
if user:
|
||||
state.query_cache[user.uuid] = LRU()
|
||||
|
||||
return content_index
|
||||
return content_index, success
|
||||
|
||||
|
||||
def load_content(
|
||||
|
|
106
src/khoj/routers/subscription.py
Normal file
106
src/khoj/routers/subscription.py
Normal file
|
@ -0,0 +1,106 @@
|
|||
# Standard Packages
|
||||
from datetime import datetime, timezone
|
||||
import logging
|
||||
import os
|
||||
|
||||
# External Packages
|
||||
from asgiref.sync import sync_to_async
|
||||
from fastapi import APIRouter, Request
|
||||
from starlette.authentication import requires
|
||||
import stripe
|
||||
|
||||
# Internal Packages
|
||||
from database import adapters
|
||||
|
||||
|
||||
# Stripe integration for Khoj Cloud Subscription
|
||||
stripe.api_key = os.getenv("STRIPE_API_KEY")
|
||||
endpoint_secret = os.getenv("STRIPE_SIGNING_SECRET")
|
||||
logger = logging.getLogger(__name__)
|
||||
subscription_router = APIRouter()
|
||||
|
||||
|
||||
@subscription_router.post("")
|
||||
async def subscribe(request: Request):
|
||||
"""Webhook for Stripe to send subscription events to Khoj Cloud"""
|
||||
event = None
|
||||
try:
|
||||
payload = await request.body()
|
||||
sig_header = request.headers["stripe-signature"]
|
||||
event = stripe.Webhook.construct_event(payload, sig_header, endpoint_secret)
|
||||
except ValueError as e:
|
||||
# Invalid payload
|
||||
raise e
|
||||
except stripe.error.SignatureVerificationError as e:
|
||||
# Invalid signature
|
||||
raise e
|
||||
|
||||
event_type = event["type"]
|
||||
if event_type not in {
|
||||
"invoice.paid",
|
||||
"customer.subscription.updated",
|
||||
"customer.subscription.deleted",
|
||||
"subscription_schedule.canceled",
|
||||
}:
|
||||
logger.warn(f"Unhandled Stripe event type: {event['type']}")
|
||||
return {"success": False}
|
||||
|
||||
# Retrieve the customer's details
|
||||
subscription = event["data"]["object"]
|
||||
customer_id = subscription["customer"]
|
||||
customer = stripe.Customer.retrieve(customer_id)
|
||||
customer_email = customer["email"]
|
||||
|
||||
# Handle valid stripe webhook events
|
||||
success = True
|
||||
if event_type in {"invoice.paid"}:
|
||||
# Mark the user as subscribed and update the next renewal date on payment
|
||||
subscription = stripe.Subscription.list(customer=customer_id).data[0]
|
||||
renewal_date = datetime.fromtimestamp(subscription["current_period_end"], tz=timezone.utc)
|
||||
user = await adapters.set_user_subscription(customer_email, is_recurring=True, renewal_date=renewal_date)
|
||||
success = user is not None
|
||||
elif event_type in {"customer.subscription.updated"}:
|
||||
user_subscription = await sync_to_async(adapters.get_user_subscription)(customer_email)
|
||||
# Allow updating subscription status if paid user
|
||||
if user_subscription and user_subscription.renewal_date:
|
||||
# Mark user as unsubscribed or resubscribed
|
||||
is_recurring = not subscription["cancel_at_period_end"]
|
||||
updated_user = await adapters.set_user_subscription(customer_email, is_recurring=is_recurring)
|
||||
success = updated_user is not None
|
||||
elif event_type in {"customer.subscription.deleted"}:
|
||||
# Reset the user to trial state
|
||||
user = await adapters.set_user_subscription(
|
||||
customer_email, is_recurring=False, renewal_date=False, type="trial"
|
||||
)
|
||||
success = user is not None
|
||||
|
||||
logger.info(f'Stripe subscription {event["type"]} for {customer["email"]}')
|
||||
return {"success": success}
|
||||
|
||||
|
||||
@subscription_router.patch("")
|
||||
@requires(["authenticated"])
|
||||
async def update_subscription(request: Request, email: str, operation: str):
|
||||
# Retrieve the customer's details
|
||||
customers = stripe.Customer.list(email=email).auto_paging_iter()
|
||||
customer = next(customers, None)
|
||||
if customer is None:
|
||||
return {"success": False, "message": "Customer not found"}
|
||||
|
||||
if operation == "cancel":
|
||||
customer_id = customer.id
|
||||
for subscription in stripe.Subscription.list(customer=customer_id):
|
||||
stripe.Subscription.modify(subscription.id, cancel_at_period_end=True)
|
||||
return {"success": True}
|
||||
|
||||
elif operation == "resubscribe":
|
||||
subscriptions = stripe.Subscription.list(customer=customer.id).auto_paging_iter()
|
||||
# Find the subscription that is set to cancel at the end of the period
|
||||
for subscription in subscriptions:
|
||||
if subscription.cancel_at_period_end:
|
||||
# Update the subscription to not cancel at the end of the period
|
||||
stripe.Subscription.modify(subscription.id, cancel_at_period_end=False)
|
||||
return {"success": True}
|
||||
return {"success": False, "message": "No subscription found that is set to cancel"}
|
||||
|
||||
return {"success": False, "message": "Invalid operation"}
|
|
@ -8,8 +8,9 @@ from fastapi import Request
|
|||
from fastapi.responses import HTMLResponse, FileResponse, RedirectResponse
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from starlette.authentication import requires
|
||||
from database import adapters
|
||||
from database.models import KhojUser
|
||||
from khoj.utils.rawconfig import (
|
||||
TextContentConfig,
|
||||
GithubContentConfig,
|
||||
GithubRepoConfig,
|
||||
NotionContentConfig,
|
||||
|
@ -17,15 +18,18 @@ from khoj.utils.rawconfig import (
|
|||
|
||||
# Internal Packages
|
||||
from khoj.utils import constants, state
|
||||
from database.adapters import EntryAdapters, get_user_github_config, get_user_notion_config, ConversationAdapters
|
||||
from database.models import LocalOrgConfig, LocalMarkdownConfig, LocalPdfConfig, LocalPlaintextConfig
|
||||
from database.adapters import (
|
||||
EntryAdapters,
|
||||
get_user_github_config,
|
||||
get_user_notion_config,
|
||||
ConversationAdapters,
|
||||
get_user_subscription_state,
|
||||
)
|
||||
|
||||
# Initialize Router
|
||||
web_client = APIRouter()
|
||||
templates = Jinja2Templates(directory=constants.web_directory)
|
||||
|
||||
VALID_TEXT_CONTENT_TYPES = ["org", "markdown", "pdf", "plaintext"]
|
||||
|
||||
|
||||
# Create Routes
|
||||
@web_client.get("/", response_class=FileResponse)
|
||||
|
@ -109,41 +113,26 @@ def login_page(request: Request):
|
|||
)
|
||||
|
||||
|
||||
def map_config_to_object(content_type: str):
|
||||
if content_type == "org":
|
||||
return LocalOrgConfig
|
||||
if content_type == "markdown":
|
||||
return LocalMarkdownConfig
|
||||
if content_type == "pdf":
|
||||
return LocalPdfConfig
|
||||
if content_type == "plaintext":
|
||||
return LocalPlaintextConfig
|
||||
|
||||
|
||||
@web_client.get("/config", response_class=HTMLResponse)
|
||||
@requires(["authenticated"], redirect="login_page")
|
||||
def config_page(request: Request):
|
||||
user = request.user.object
|
||||
user: KhojUser = request.user.object
|
||||
user_picture = request.session.get("user", {}).get("picture")
|
||||
enabled_content = set(EntryAdapters.get_unique_file_types(user).all())
|
||||
user_subscription = adapters.get_user_subscription(user.email)
|
||||
user_subscription_state = get_user_subscription_state(user_subscription)
|
||||
subscription_renewal_date = (
|
||||
user_subscription.renewal_date.strftime("%d %b %Y")
|
||||
if user_subscription and user_subscription.renewal_date
|
||||
else None
|
||||
)
|
||||
enabled_content_source = set(EntryAdapters.get_unique_file_source(user).all())
|
||||
|
||||
successfully_configured = {
|
||||
"pdf": ("pdf" in enabled_content),
|
||||
"markdown": ("markdown" in enabled_content),
|
||||
"org": ("org" in enabled_content),
|
||||
"image": False,
|
||||
"github": ("github" in enabled_content),
|
||||
"notion": ("notion" in enabled_content),
|
||||
"plaintext": ("plaintext" in enabled_content),
|
||||
"computer": ("computer" in enabled_content_source),
|
||||
"github": ("github" in enabled_content_source),
|
||||
"notion": ("notion" in enabled_content_source),
|
||||
}
|
||||
|
||||
if state.content_index:
|
||||
successfully_configured.update(
|
||||
{
|
||||
"image": state.content_index.image is not None,
|
||||
}
|
||||
)
|
||||
|
||||
conversation_options = ConversationAdapters.get_conversation_processor_options().all()
|
||||
all_conversation_options = list()
|
||||
for conversation_option in conversation_options:
|
||||
|
@ -157,15 +146,19 @@ def config_page(request: Request):
|
|||
"request": request,
|
||||
"current_model_state": successfully_configured,
|
||||
"anonymous_mode": state.anonymous_mode,
|
||||
"username": user.username if user else None,
|
||||
"username": user.username,
|
||||
"conversation_options": all_conversation_options,
|
||||
"selected_conversation_config": selected_conversation_config.id if selected_conversation_config else None,
|
||||
"user_photo": user_picture,
|
||||
"billing_enabled": state.billing_enabled,
|
||||
"subscription_state": user_subscription_state,
|
||||
"subscription_renewal_date": subscription_renewal_date,
|
||||
"khoj_cloud_subscription_url": os.getenv("KHOJ_CLOUD_SUBSCRIPTION_URL"),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@web_client.get("/config/content_type/github", response_class=HTMLResponse)
|
||||
@web_client.get("/config/content-source/github", response_class=HTMLResponse)
|
||||
@requires(["authenticated"], redirect="login_page")
|
||||
def github_config_page(request: Request):
|
||||
user = request.user.object
|
||||
|
@ -192,7 +185,7 @@ def github_config_page(request: Request):
|
|||
current_config = {} # type: ignore
|
||||
|
||||
return templates.TemplateResponse(
|
||||
"content_type_github_input.html",
|
||||
"content_source_github_input.html",
|
||||
context={
|
||||
"request": request,
|
||||
"current_config": current_config,
|
||||
|
@ -202,7 +195,7 @@ def github_config_page(request: Request):
|
|||
)
|
||||
|
||||
|
||||
@web_client.get("/config/content_type/notion", response_class=HTMLResponse)
|
||||
@web_client.get("/config/content-source/notion", response_class=HTMLResponse)
|
||||
@requires(["authenticated"], redirect="login_page")
|
||||
def notion_config_page(request: Request):
|
||||
user = request.user.object
|
||||
|
@ -216,7 +209,7 @@ def notion_config_page(request: Request):
|
|||
current_config = json.loads(current_config.json())
|
||||
|
||||
return templates.TemplateResponse(
|
||||
"content_type_notion_input.html",
|
||||
"content_source_notion_input.html",
|
||||
context={
|
||||
"request": request,
|
||||
"current_config": current_config,
|
||||
|
@ -226,32 +219,16 @@ def notion_config_page(request: Request):
|
|||
)
|
||||
|
||||
|
||||
@web_client.get("/config/content_type/{content_type}", response_class=HTMLResponse)
|
||||
@web_client.get("/config/content-source/computer", response_class=HTMLResponse)
|
||||
@requires(["authenticated"], redirect="login_page")
|
||||
def content_config_page(request: Request, content_type: str):
|
||||
if content_type not in VALID_TEXT_CONTENT_TYPES:
|
||||
return templates.TemplateResponse("config.html", context={"request": request})
|
||||
|
||||
object = map_config_to_object(content_type)
|
||||
def computer_config_page(request: Request):
|
||||
user = request.user.object
|
||||
user_picture = request.session.get("user", {}).get("picture")
|
||||
config = object.objects.filter(user=user).first()
|
||||
if config == None:
|
||||
config = object.objects.create(user=user)
|
||||
|
||||
current_config = TextContentConfig(
|
||||
input_files=config.input_files,
|
||||
input_filter=config.input_filter,
|
||||
index_heading_entries=config.index_heading_entries,
|
||||
)
|
||||
current_config = json.loads(current_config.json())
|
||||
|
||||
return templates.TemplateResponse(
|
||||
"content_type_input.html",
|
||||
"content_source_computer_input.html",
|
||||
context={
|
||||
"request": request,
|
||||
"current_config": current_config,
|
||||
"content_type": content_type,
|
||||
"username": user.username,
|
||||
"user_photo": user_picture,
|
||||
},
|
||||
|
|
|
@ -204,11 +204,12 @@ def setup(
|
|||
files=files, full_corpus=full_corpus, user=user, regenerate=regenerate
|
||||
)
|
||||
|
||||
file_names = [file_name for file_name in files]
|
||||
if files:
|
||||
file_names = [file_name for file_name in files]
|
||||
|
||||
logger.info(
|
||||
f"Deleted {num_deleted_embeddings} entries. Created {num_new_embeddings} new entries for user {user} from files {file_names}"
|
||||
)
|
||||
logger.info(
|
||||
f"Deleted {num_deleted_embeddings} entries. Created {num_new_embeddings} new entries for user {user} from files {file_names}"
|
||||
)
|
||||
|
||||
|
||||
def cross_encoder_score(query: str, hits: List[SearchResponse]) -> List[SearchResponse]:
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
# Standard Packages
|
||||
import os
|
||||
import threading
|
||||
from typing import List, Dict
|
||||
from collections import defaultdict
|
||||
|
@ -35,3 +36,8 @@ khoj_version: str = None
|
|||
device = get_device()
|
||||
chat_on_gpu: bool = True
|
||||
anonymous_mode: bool = False
|
||||
billing_enabled: bool = (
|
||||
os.getenv("STRIPE_API_KEY") is not None
|
||||
and os.getenv("STRIPE_SIGNING_SECRET") is not None
|
||||
and os.getenv("KHOJ_CLOUD_SUBSCRIPTION_URL") is not None
|
||||
)
|
||||
|
|
|
@ -196,7 +196,7 @@ def chat_client(search_config: SearchConfig, default_user2: KhojUser):
|
|||
|
||||
# Index Markdown Content for Search
|
||||
all_files = fs_syncer.collect_files(user=default_user2)
|
||||
state.content_index = configure_content(
|
||||
state.content_index, _ = configure_content(
|
||||
state.content_index, state.config.content_type, all_files, state.search_models, user=default_user2
|
||||
)
|
||||
|
||||
|
|
|
@ -64,6 +64,7 @@ def test_encode_docs_memory_leak():
|
|||
batch_size = 20
|
||||
embeddings_model = EmbeddingsModel()
|
||||
memory_usage_trend = []
|
||||
device = f"{helpers.get_device()}".upper()
|
||||
|
||||
# Act
|
||||
# Encode random strings repeatedly and record memory usage trend
|
||||
|
@ -76,8 +77,9 @@ def test_encode_docs_memory_leak():
|
|||
# Calculate slope of line fitting memory usage history
|
||||
memory_usage_trend = np.array(memory_usage_trend)
|
||||
slope, _, _, _, _ = linregress(np.arange(len(memory_usage_trend)), memory_usage_trend)
|
||||
print(f"Memory usage increased at ~{slope:.2f} MB per iteration on {device}")
|
||||
|
||||
# Assert
|
||||
# If slope is positive memory utilization is increasing
|
||||
# Positive threshold of 2, from observing memory usage trend on MPS vs CPU device
|
||||
assert slope < 2, f"Memory usage increasing at ~{slope:.2f} MB per iteration"
|
||||
assert slope < 2, f"Memory leak suspected on {device}. Memory usage increased at ~{slope:.2f} MB per iteration"
|
||||
|
|
|
@ -58,7 +58,7 @@ def test_get_org_files_with_org_suffixed_dir_doesnt_raise_error(tmp_path, defaul
|
|||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db
|
||||
def test_text_search_setup_with_empty_file_raises_error(
|
||||
def test_text_search_setup_with_empty_file_creates_no_entries(
|
||||
org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog
|
||||
):
|
||||
# Arrange
|
||||
|
|
Loading…
Reference in a new issue