diff --git a/docs/github_integration.md b/docs/github_integration.md index 6b8dce48..413dd41e 100644 --- a/docs/github_integration.md +++ b/docs/github_integration.md @@ -9,6 +9,6 @@ The Github integration allows you to index as many repositories as you want. It' ## Use the Github plugin 1. Generate a [classic PAT (personal access token)](https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens) from [Github](https://github.com/settings/tokens) with `repo` and `admin:org` scopes at least. -2. Navigate to [http://localhost:42110/config/content_type/github](http://localhost:42110/config/content_type/github) to configure your Github settings. Enter in your PAT, along with details for each repository you want to index. +2. Navigate to [http://localhost:42110/config/content-source/github](http://localhost:42110/config/content-source/github) to configure your Github settings. Enter in your PAT, along with details for each repository you want to index. 3. Click `Save`. Go back to the settings page and click `Configure`. 4. Go to [http://localhost:42110/](http://localhost:42110/) and start searching! diff --git a/docs/notion_integration.md b/docs/notion_integration.md index 5fee7ff6..d3b645ca 100644 --- a/docs/notion_integration.md +++ b/docs/notion_integration.md @@ -8,7 +8,7 @@ We haven't setup a fancy integration with OAuth yet, so this integration still r ![setup_new_integration](https://github.com/khoj-ai/khoj/assets/65192171/b056e057-d4dc-47dc-aad3-57b59a22c68b) 3. Share all the workspaces that you want to integrate with the Khoj integration you just made in the previous step ![enable_workspace](https://github.com/khoj-ai/khoj/assets/65192171/98290303-b5b8-4cb0-b32c-f68c6923a3d0) -4. In the first step, you generated an API key. Use the newly generated API Key in your Khoj settings, by default at http://localhost:42110/config/content_type/notion. Click `Save`. +4. In the first step, you generated an API key. Use the newly generated API Key in your Khoj settings, by default at http://localhost:42110/config/content-source/notion. Click `Save`. 5. Click `Configure` in http://localhost:42110/config to index your Notion workspace(s). That's it! You should be ready to start searching and chatting. Make sure you've configured your OpenAI API Key for chat. diff --git a/pyproject.toml b/pyproject.toml index e87d205e..10e44ac0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,7 +55,7 @@ dependencies = [ "torch == 2.0.1", "uvicorn == 0.17.6", "aiohttp == 3.8.5", - "langchain >= 0.0.187", + "langchain >= 0.0.331", "requests >= 2.26.0", "bs4 >= 0.0.1", "anyio == 3.7.1", @@ -73,7 +73,8 @@ dependencies = [ "gunicorn == 21.2.0", "lxml == 4.9.3", "tzdata == 2023.3", - "rapidocr-onnxruntime == 1.3.8" + "rapidocr-onnxruntime == 1.3.8", + "stripe == 7.3.0", ] dynamic = ["version"] diff --git a/src/app/settings.py b/src/app/settings.py index 44d1d3d6..1cef3c88 100644 --- a/src/app/settings.py +++ b/src/app/settings.py @@ -21,7 +21,7 @@ BASE_DIR = Path(__file__).resolve().parent.parent.parent # See https://docs.djangoproject.com/en/4.2/howto/deployment/checklist/ # SECURITY WARNING: keep the secret key used in production secret! -SECRET_KEY = os.getenv("DJANGO_SECRET_KEY") +SECRET_KEY = os.getenv("KHOJ_DJANGO_SECRET_KEY") # SECURITY WARNING: don't run with debug turned on in production! DEBUG = os.getenv("DJANGO_DEBUG", "False") == "True" diff --git a/src/database/adapters/__init__.py b/src/database/adapters/__init__.py index fa37aa99..4f71c7aa 100644 --- a/src/database/adapters/__init__.py +++ b/src/database/adapters/__init__.py @@ -1,8 +1,8 @@ -from typing import Type, TypeVar, List -from datetime import date +from typing import Optional, Type, TypeVar, List +from datetime import date, datetime, timedelta import secrets from typing import Type, TypeVar, List -from datetime import date +from datetime import date, timezone from django.db import models from django.contrib.sessions.backends.db import SessionStore @@ -30,6 +30,7 @@ from database.models import ( GithubRepoConfig, Conversation, ChatModelOptions, + Subscription, UserConversationConfig, OpenAIProcessorConversationConfig, OfflineChatProcessorConversationConfig, @@ -103,6 +104,57 @@ async def create_google_user(token: dict) -> KhojUser: return user +def get_user_subscription(email: str) -> Optional[Subscription]: + return Subscription.objects.filter(user__email=email).first() + + +async def set_user_subscription( + email: str, is_recurring=None, renewal_date=None, type="standard" +) -> Optional[Subscription]: + user_subscription = await Subscription.objects.filter(user__email=email).afirst() + if not user_subscription: + user = await get_user_by_email(email) + if not user: + return None + user_subscription = await Subscription.objects.acreate( + user=user, type=type, is_recurring=is_recurring, renewal_date=renewal_date + ) + return user_subscription + elif user_subscription: + user_subscription.type = type + if is_recurring is not None: + user_subscription.is_recurring = is_recurring + if renewal_date is False: + user_subscription.renewal_date = None + elif renewal_date is not None: + user_subscription.renewal_date = renewal_date + await user_subscription.asave() + return user_subscription + else: + return None + + +def get_user_subscription_state(user_subscription: Subscription) -> str: + """Get subscription state of user + Valid state transitions: trial -> subscribed <-> unsubscribed OR expired + """ + if not user_subscription: + return "trial" + elif user_subscription.type == Subscription.Type.TRIAL: + return "trial" + elif user_subscription.is_recurring and user_subscription.renewal_date >= datetime.now(tz=timezone.utc): + return "subscribed" + elif not user_subscription.is_recurring and user_subscription.renewal_date >= datetime.now(tz=timezone.utc): + return "unsubscribed" + elif not user_subscription.is_recurring and user_subscription.renewal_date < datetime.now(tz=timezone.utc): + return "expired" + return "invalid" + + +async def get_user_by_email(email: str) -> KhojUser: + return await KhojUser.objects.filter(email=email).afirst() + + async def get_user_by_token(token: dict) -> KhojUser: google_user = await GoogleUser.objects.filter(sub=token.get("sub")).select_related("user").afirst() if not google_user: @@ -287,13 +339,21 @@ class EntryAdapters: return deleted_count @staticmethod - def delete_all_entries(user: KhojUser, file_type: str = None): + def delete_all_entries_by_type(user: KhojUser, file_type: str = None): if file_type is None: deleted_count, _ = Entry.objects.filter(user=user).delete() else: deleted_count, _ = Entry.objects.filter(user=user, file_type=file_type).delete() return deleted_count + @staticmethod + def delete_all_entries(user: KhojUser, file_source: str = None): + if file_source is None: + deleted_count, _ = Entry.objects.filter(user=user).delete() + else: + deleted_count, _ = Entry.objects.filter(user=user, file_source=file_source).delete() + return deleted_count + @staticmethod def get_existing_entry_hashes_by_file(user: KhojUser, file_path: str): return Entry.objects.filter(user=user, file_path=file_path).values_list("hashed_value", flat=True) @@ -318,8 +378,12 @@ class EntryAdapters: return await Entry.objects.filter(user=user, file_path=file_path).adelete() @staticmethod - def aget_all_filenames(user: KhojUser): - return Entry.objects.filter(user=user).distinct("file_path").values_list("file_path", flat=True) + def aget_all_filenames_by_source(user: KhojUser, file_source: str): + return ( + Entry.objects.filter(user=user, file_source=file_source) + .distinct("file_path") + .values_list("file_path", flat=True) + ) @staticmethod async def adelete_all_entries(user: KhojUser): @@ -384,3 +448,7 @@ class EntryAdapters: @staticmethod def get_unique_file_types(user: KhojUser): return Entry.objects.filter(user=user).values_list("file_type", flat=True).distinct() + + @staticmethod + def get_unique_file_source(user: KhojUser): + return Entry.objects.filter(user=user).values_list("file_source", flat=True).distinct() diff --git a/src/database/migrations/0012_entry_file_source.py b/src/database/migrations/0012_entry_file_source.py new file mode 100644 index 00000000..187136ae --- /dev/null +++ b/src/database/migrations/0012_entry_file_source.py @@ -0,0 +1,21 @@ +# Generated by Django 4.2.5 on 2023-11-07 07:24 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0011_merge_20231102_0138"), + ] + + operations = [ + migrations.AddField( + model_name="entry", + name="file_source", + field=models.CharField( + choices=[("computer", "Computer"), ("notion", "Notion"), ("github", "Github")], + default="computer", + max_length=30, + ), + ), + ] diff --git a/src/database/migrations/0013_subscription.py b/src/database/migrations/0013_subscription.py new file mode 100644 index 00000000..931cea12 --- /dev/null +++ b/src/database/migrations/0013_subscription.py @@ -0,0 +1,37 @@ +# Generated by Django 4.2.5 on 2023-11-09 01:27 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0012_entry_file_source"), + ] + + operations = [ + migrations.CreateModel( + name="Subscription", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ( + "type", + models.CharField( + choices=[("trial", "Trial"), ("standard", "Standard")], default="trial", max_length=20 + ), + ), + ("is_recurring", models.BooleanField(default=False)), + ("renewal_date", models.DateTimeField(default=None, null=True)), + ( + "user", + models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL), + ), + ], + options={ + "abstract": False, + }, + ), + ] diff --git a/src/database/models/__init__.py b/src/database/models/__init__.py index 5dd9622b..28f8cd2a 100644 --- a/src/database/models/__init__.py +++ b/src/database/models/__init__.py @@ -46,6 +46,17 @@ class KhojApiUser(models.Model): accessed_at = models.DateTimeField(null=True, default=None) +class Subscription(BaseModel): + class Type(models.TextChoices): + TRIAL = "trial" + STANDARD = "standard" + + user = models.OneToOneField(KhojUser, on_delete=models.CASCADE) + type = models.CharField(max_length=20, choices=Type.choices, default=Type.TRIAL) + is_recurring = models.BooleanField(default=False) + renewal_date = models.DateTimeField(null=True, default=None) + + class NotionConfig(BaseModel): token = models.CharField(max_length=200) user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) @@ -131,11 +142,17 @@ class Entry(BaseModel): GITHUB = "github" CONVERSATION = "conversation" + class EntrySource(models.TextChoices): + COMPUTER = "computer" + NOTION = "notion" + GITHUB = "github" + user = models.ForeignKey(KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True) embeddings = VectorField(dimensions=384) raw = models.TextField() compiled = models.TextField() heading = models.CharField(max_length=1000, default=None, null=True, blank=True) + file_source = models.CharField(max_length=30, choices=EntrySource.choices, default=EntrySource.COMPUTER) file_type = models.CharField(max_length=30, choices=EntryType.choices, default=EntryType.PLAINTEXT) file_path = models.CharField(max_length=400, default=None, null=True, blank=True) file_name = models.CharField(max_length=400, default=None, null=True, blank=True) diff --git a/src/interface/desktop/chat.html b/src/interface/desktop/chat.html index 8666b340..302d4a54 100644 --- a/src/interface/desktop/chat.html +++ b/src/interface/desktop/chat.html @@ -192,9 +192,9 @@ .then(response => { const reader = response.body.getReader(); const decoder = new TextDecoder(); + let references = null; function readStream() { - let references = null; reader.read().then(({ done, value }) => { if (done) { // Evaluate the contents of new_response_text.innerHTML after all the data has been streamed diff --git a/src/interface/desktop/config.html b/src/interface/desktop/config.html index 0629a5f7..c63a2a5c 100644 --- a/src/interface/desktop/config.html +++ b/src/interface/desktop/config.html @@ -91,11 +91,13 @@
- +
-
- - +
+ +
+
+
@@ -336,7 +338,7 @@ padding: 4px; cursor: pointer; } - #sync-data { + button.sync-data { background-color: var(--primary); border: none; color: var(--main-text-color); @@ -351,7 +353,7 @@ box-shadow: 0px 5px 0px var(--background-color); } - #sync-data:hover { + button.sync-data:hover { background-color: var(--primary-hover); box-shadow: 0px 3px 0px var(--background-color); } diff --git a/src/interface/desktop/main.js b/src/interface/desktop/main.js index 1d5c4be2..ef46ce00 100644 --- a/src/interface/desktop/main.js +++ b/src/interface/desktop/main.js @@ -67,6 +67,7 @@ const schema = { } }; +const syncing = false; var state = {} const store = new Store({ schema }); @@ -110,6 +111,15 @@ function filenameToMimeType (filename) { } function pushDataToKhoj (regenerate = false) { + // Don't sync if token or hostURL is not set or if already syncing + if (store.get('khojToken') === '' || store.get('hostURL') === '' || syncing === true) { + const win = BrowserWindow.getAllWindows()[0]; + if (win) win.webContents.send('update-state', state); + return; + } else { + syncing = true; + } + let filesToPush = []; const files = store.get('files') || []; const folders = store.get('folders') || []; @@ -192,11 +202,13 @@ function pushDataToKhoj (regenerate = false) { }) .finally(() => { // Syncing complete + syncing = false; const win = BrowserWindow.getAllWindows()[0]; if (win) win.webContents.send('update-state', state); }); } else { // Syncing complete + syncing = false; const win = BrowserWindow.getAllWindows()[0]; if (win) win.webContents.send('update-state', state); } @@ -306,6 +318,19 @@ async function syncData (regenerate = false) { } } +async function deleteAllFiles () { + try { + store.set('files', []); + store.set('folders', []); + pushDataToKhoj(true); + const date = new Date(); + console.log('Pushing data to Khoj at: ', date); + } catch (err) { + console.error(err); + } +} + + let firstRun = true; let win = null; const createWindow = (tab = 'chat.html') => { @@ -386,6 +411,7 @@ app.whenReady().then(() => { ipcMain.handle('syncData', (event, regenerate) => { syncData(regenerate); }); + ipcMain.handle('deleteAllFiles', deleteAllFiles); createWindow() diff --git a/src/interface/desktop/preload.js b/src/interface/desktop/preload.js index 3228fdb0..eb5a6cc2 100644 --- a/src/interface/desktop/preload.js +++ b/src/interface/desktop/preload.js @@ -45,7 +45,8 @@ contextBridge.exposeInMainWorld('hostURLAPI', { }) contextBridge.exposeInMainWorld('syncDataAPI', { - syncData: (regenerate) => ipcRenderer.invoke('syncData', regenerate) + syncData: (regenerate) => ipcRenderer.invoke('syncData', regenerate), + deleteAllFiles: () => ipcRenderer.invoke('deleteAllFiles') }) contextBridge.exposeInMainWorld('tokenAPI', { diff --git a/src/interface/desktop/renderer.js b/src/interface/desktop/renderer.js index 1e1fae32..849a8293 100644 --- a/src/interface/desktop/renderer.js +++ b/src/interface/desktop/renderer.js @@ -196,9 +196,19 @@ khojKeyInput.addEventListener('blur', async () => { }); const syncButton = document.getElementById('sync-data'); -const syncForceToggle = document.getElementById('sync-force'); syncButton.addEventListener('click', async () => { loadingBar.style.display = 'block'; - const regenerate = syncForceToggle.checked; - await window.syncDataAPI.syncData(regenerate); + await window.syncDataAPI.syncData(false); +}); + +const syncForceButton = document.getElementById('sync-force'); +syncForceButton.addEventListener('click', async () => { + loadingBar.style.display = 'block'; + await window.syncDataAPI.syncData(true); +}); + +const deleteAllButton = document.getElementById('delete-all'); +deleteAllButton.addEventListener('click', async () => { + loadingBar.style.display = 'block'; + await window.syncDataAPI.deleteAllFiles(); }); diff --git a/src/khoj/configure.py b/src/khoj/configure.py index bc9e9bf8..6f0589a8 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -1,5 +1,4 @@ # Standard Packages -import sys import logging import json from enum import Enum @@ -109,7 +108,6 @@ def configure_server( state.search_models = configure_search(state.search_models, state.config.search_type) initialize_content(regenerate, search_type, init, user) except Exception as e: - logger.error(f"🚨 Failed to configure search models", exc_info=True) raise e finally: state.config_lock.release() @@ -125,7 +123,7 @@ def initialize_content(regenerate: bool, search_type: Optional[SearchType] = Non else: logger.info("📬 Updating content index...") all_files = collect_files(user=user) - state.content_index = configure_content( + state.content_index, status = configure_content( state.content_index, state.config.content_type, all_files, @@ -134,8 +132,9 @@ def initialize_content(regenerate: bool, search_type: Optional[SearchType] = Non search_type, user=user, ) + if not status: + raise RuntimeError("Failed to update content index") except Exception as e: - logger.error(f"🚨 Failed to index content", exc_info=True) raise e @@ -146,10 +145,14 @@ def configure_routes(app): from khoj.routers.web_client import web_client from khoj.routers.indexer import indexer from khoj.routers.auth import auth_router + from khoj.routers.subscription import subscription_router app.include_router(api, prefix="/api") app.include_router(api_beta, prefix="/api/beta") app.include_router(indexer, prefix="/api/v1/index") + if state.billing_enabled: + logger.info("💳 Enabled Billing") + app.include_router(subscription_router, prefix="/api/subscription") app.include_router(web_client) app.include_router(auth_router, prefix="/auth") @@ -165,13 +168,15 @@ def update_search_index(): logger.info("📬 Updating content index via Scheduler") for user in get_all_users(): all_files = collect_files(user=user) - state.content_index = configure_content( + state.content_index, success = configure_content( state.content_index, state.config.content_type, all_files, state.search_models, user=user ) all_files = collect_files(user=None) - state.content_index = configure_content( + state.content_index, success = configure_content( state.content_index, state.config.content_type, all_files, state.search_models, user=None ) + if not success: + raise RuntimeError("Failed to update content index") logger.info("📪 Content index updated via Scheduler") except Exception as e: logger.error(f"🚨 Error updating content index via Scheduler: {e}", exc_info=True) diff --git a/src/khoj/interface/web/assets/icons/computer.png b/src/khoj/interface/web/assets/icons/computer.png new file mode 100644 index 00000000..12473485 Binary files /dev/null and b/src/khoj/interface/web/assets/icons/computer.png differ diff --git a/src/khoj/interface/web/assets/icons/credit-card.png b/src/khoj/interface/web/assets/icons/credit-card.png new file mode 100644 index 00000000..487dba5c Binary files /dev/null and b/src/khoj/interface/web/assets/icons/credit-card.png differ diff --git a/src/khoj/interface/web/base_config.html b/src/khoj/interface/web/base_config.html index 8e33677c..001ebef8 100644 --- a/src/khoj/interface/web/base_config.html +++ b/src/khoj/interface/web/base_config.html @@ -209,23 +209,27 @@ border: none; color: var(--flower); padding: 4px; + width: 32px; + margin-bottom: 0px } div.file-element { display: grid; - grid-template-columns: 1fr auto; + grid-template-columns: 1fr 5fr 1fr; border: 1px solid rgb(229, 229, 229); border-radius: 4px; box-shadow: 0px 1px 3px 0px rgba(0,0,0,0.1),0px 1px 2px -1px rgba(0,0,0,0.8); - padding: 4px; + padding: 4px 0; margin-bottom: 8px; + justify-items: center; + align-items: center; } div.remove-button-container { text-align: right; } - button.card-button.happy { + .card-button.happy { color: var(--leaf); } diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html index 61f176c7..9712580d 100644 --- a/src/khoj/interface/web/chat.html +++ b/src/khoj/interface/web/chat.html @@ -187,9 +187,9 @@ .then(response => { const reader = response.body.getReader(); const decoder = new TextDecoder(); + let references = null; function readStream() { - let references = null; reader.read().then(({ done, value }) => { if (done) { // Evaluate the contents of new_response_text.innerHTML after all the data has been streamed diff --git a/src/khoj/interface/web/config.html b/src/khoj/interface/web/config.html index b19bbff6..b2b7fbb3 100644 --- a/src/khoj/interface/web/config.html +++ b/src/khoj/interface/web/config.html @@ -3,23 +3,57 @@
-

Plugins

+

Content

+
+ Computer +

+ Files + Configured +

+
+
+

Manage files from your computer

+
+ +
+ +
+
+
Github

Github - {% if current_model_state.github == True %} - Configured - {% endif %} + Configured

Set repositories to index

- {% if current_model_state.github %} -
- -
- {% endif %} +
+ +
Notion

Notion - {% if current_model_state.notion == True %} - Configured - {% endif %} + Configured

-

Configure your settings from Notion

+

Sync your Notion pages

- {% if current_model_state.notion %} -
- -
- {% endif %} +
+ +
@@ -77,7 +113,7 @@
Chat

- Chat Model + Chat

@@ -122,16 +158,69 @@
+ {% if billing_enabled %}
-

Manage Data

-
-
- -
-
+

Billing

+
+
+
+ Credit Card +

+ Subscription + Configured +

+
+
+

+ Subscribe to Khoj Cloud +

+

+ You are subscribed to Khoj Cloud. Subscription will renew on {{ subscription_renewal_date }} +

+

+ You are subscribed to Khoj Cloud. Subscription will expire on {{ subscription_renewal_date }} +

+

+ Subscribe to Khoj Cloud. Subscription expired on {{ subscription_renewal_date }} +

+
+
+ + + + Subscribe + + +
+ {% endif %}
@@ -176,8 +265,9 @@ }) }; - function clearContentType(content_type) { - fetch('/api/config/data/content_type/' + content_type, { + function clearContentType(content_source) { + + fetch('/api/config/data/content-source/' + content_source, { method: 'DELETE', headers: { 'Content-Type': 'application/json', @@ -186,22 +276,54 @@ .then(response => response.json()) .then(data => { if (data.status == "ok") { - var contentTypeClearButton = document.getElementById("clear-" + content_type); - contentTypeClearButton.style.display = "none"; - - var configuredIcon = document.getElementById("configured-icon-" + content_type); - if (configuredIcon) { - configuredIcon.style.display = "none"; - } - - var misconfiguredIcon = document.getElementById("misconfigured-icon-" + content_type); - if (misconfiguredIcon) { - misconfiguredIcon.style.display = "none"; - } + document.getElementById("configured-icon-" + content_source).style.display = "none"; + document.getElementById("clear-" + content_source).style.display = "none"; + } else { + document.getElementById("configured-icon-" + content_source).style.display = ""; + document.getElementById("clear-" + content_source).style.display = ""; } }) }; + function unsubscribe() { + fetch('/api/subscription?operation=cancel&email={{username}}', { + method: 'PATCH', + headers: { + 'Content-Type': 'application/json', + }, + }) + .then(response => response.json()) + .then(data => { + if (data.success) { + document.getElementById("unsubscribe-description").style.display = "none"; + document.getElementById("unsubscribe-button").style.display = "none"; + + document.getElementById("resubscribe-description").style.display = ""; + document.getElementById("resubscribe-button").style.display = ""; + + } + }) + } + + function resubscribe() { + fetch('/api/subscription?operation=resubscribe&email={{username}}', { + method: 'PATCH', + headers: { + 'Content-Type': 'application/json', + }, + }) + .then(response => response.json()) + .then(data => { + if (data.success) { + document.getElementById("resubscribe-description").style.display = "none"; + document.getElementById("resubscribe-button").style.display = "none"; + + document.getElementById("unsubscribe-description").style.display = ""; + document.getElementById("unsubscribe-button").style.display = ""; + } + }) + } + var configure = document.getElementById("configure"); configure.addEventListener("click", function(event) { event.preventDefault(); @@ -243,6 +365,7 @@ if (data.detail != null) { throw new Error(data.detail); } + document.getElementById("status").innerHTML = emoji + " " + successText; document.getElementById("status").style.display = "block"; button.disabled = false; @@ -255,6 +378,26 @@ button.disabled = false; button.innerHTML = '⚠️ Unsuccessful'; }); + + content_sources = ["computer", "github", "notion"]; + content_sources.forEach(content_source => { + fetch(`/api/config/data/${content_source}`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + } + }) + .then(response => response.json()) + .then(data => { + if (data.length > 0) { + document.getElementById("configured-icon-" + content_source).style.display = ""; + document.getElementById("clear-" + content_source).style.display = ""; + } else { + document.getElementById("configured-icon-" + content_source).style.display = "none"; + document.getElementById("clear-" + content_source).style.display = "none"; + } + }); + }); } // Setup the results count slider @@ -362,70 +505,5 @@ } }) } - - // Get all currently indexed files - function getAllFilenames() { - fetch('/api/config/data/all') - .then(response => response.json()) - .then(data => { - var indexedFiles = document.getElementsByClassName("indexed-files")[0]; - indexedFiles.innerHTML = ""; - - if (data.length == 0) { - document.getElementById("delete-all-files").style.display = "none"; - indexedFiles.innerHTML = "
Use the Khoj Desktop client to index files.
"; - } else { - document.getElementById("delete-all-files").style.display = "block"; - } - - for (var filename of data) { - let fileElement = document.createElement("div"); - fileElement.classList.add("file-element"); - - let fileNameElement = document.createElement("div"); - fileNameElement.classList.add("content-name"); - fileNameElement.innerHTML = filename; - fileElement.appendChild(fileNameElement); - - let buttonContainer = document.createElement("div"); - buttonContainer.classList.add("remove-button-container"); - let removeFileButton = document.createElement("button"); - removeFileButton.classList.add("remove-file-button"); - removeFileButton.innerHTML = "🗑️"; - removeFileButton.addEventListener("click", ((filename) => { - return () => { - removeFile(filename); - }; - })(filename)); - buttonContainer.appendChild(removeFileButton); - fileElement.appendChild(buttonContainer); - indexedFiles.appendChild(fileElement); - } - }) - .catch((error) => { - console.error('Error:', error); - }); - } - - // Get all currently indexed files on page load - getAllFilenames(); - - let deleteAllFilesButton = document.getElementById("delete-all-files"); - deleteAllFilesButton.addEventListener("click", function(event) { - event.preventDefault(); - fetch('/api/config/data/all', { - method: 'DELETE', - headers: { - 'Content-Type': 'application/json', - } - }) - .then(response => response.json()) - .then(data => { - if (data.status == "ok") { - getAllFilenames(); - } - }) - }); - {% endblock %} diff --git a/src/khoj/interface/web/content_source_computer_input.html b/src/khoj/interface/web/content_source_computer_input.html new file mode 100644 index 00000000..aba3d8ee --- /dev/null +++ b/src/khoj/interface/web/content_source_computer_input.html @@ -0,0 +1,129 @@ +{% extends "base_config.html" %} +{% block content %} +
+
+

+ files + Files +
+

Manage files from your computer

+

Download the Khoj Desktop app to sync files from your computer

+
+

+
+
+ +
+
+
+
+
+
+ + +{% endblock %} diff --git a/src/khoj/interface/web/content_type_github_input.html b/src/khoj/interface/web/content_source_github_input.html similarity index 99% rename from src/khoj/interface/web/content_type_github_input.html rename to src/khoj/interface/web/content_source_github_input.html index 0e41645a..ff82b1f2 100644 --- a/src/khoj/interface/web/content_type_github_input.html +++ b/src/khoj/interface/web/content_source_github_input.html @@ -125,7 +125,7 @@ } const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1]; - fetch('/api/config/data/content_type/github', { + fetch('/api/config/data/content-source/github', { method: 'POST', headers: { 'Content-Type': 'application/json', diff --git a/src/khoj/interface/web/content_type_notion_input.html b/src/khoj/interface/web/content_source_notion_input.html similarity index 97% rename from src/khoj/interface/web/content_type_notion_input.html rename to src/khoj/interface/web/content_source_notion_input.html index 965c1ef5..18eb5a7f 100644 --- a/src/khoj/interface/web/content_type_notion_input.html +++ b/src/khoj/interface/web/content_source_notion_input.html @@ -42,7 +42,7 @@ } const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1]; - fetch('/api/config/data/content_type/notion', { + fetch('/api/config/data/content-source/notion', { method: 'POST', headers: { 'Content-Type': 'application/json', diff --git a/src/khoj/interface/web/content_type_input.html b/src/khoj/interface/web/content_type_input.html deleted file mode 100644 index f8751ddc..00000000 --- a/src/khoj/interface/web/content_type_input.html +++ /dev/null @@ -1,159 +0,0 @@ -{% extends "base_config.html" %} -{% block content %} -
-
-

- {{ content_type|capitalize }} - {{ content_type|capitalize }} -

-
- - - - - - - - - - - -
- - - {% if current_config['input_files'] is none %} - - {% else %} - {% for input_file in current_config['input_files'] %} - - {% endfor %} - {% endif %} - - -
- - - {% if current_config['input_filter'] is none %} - - {% else %} - {% for input_filter in current_config['input_filter'] %} - - {% endfor %} - {% endif %} - - -
-
- - -
-
-
-
- -{% endblock %} diff --git a/src/khoj/processor/github/github_to_entries.py b/src/khoj/processor/github/github_to_entries.py index 14e9b696..56279453 100644 --- a/src/khoj/processor/github/github_to_entries.py +++ b/src/khoj/processor/github/github_to_entries.py @@ -104,7 +104,12 @@ class GithubToEntries(TextToEntries): # Identify, mark and merge any new entries with previous entries with timer("Identify new or updated entries", logger): num_new_embeddings, num_deleted_embeddings = self.update_embeddings( - current_entries, DbEntry.EntryType.GITHUB, key="compiled", logger=logger, user=user + current_entries, + DbEntry.EntryType.GITHUB, + DbEntry.EntrySource.GITHUB, + key="compiled", + logger=logger, + user=user, ) return num_new_embeddings, num_deleted_embeddings diff --git a/src/khoj/processor/markdown/markdown_to_entries.py b/src/khoj/processor/markdown/markdown_to_entries.py index e0b76368..0dd71740 100644 --- a/src/khoj/processor/markdown/markdown_to_entries.py +++ b/src/khoj/processor/markdown/markdown_to_entries.py @@ -47,6 +47,7 @@ class MarkdownToEntries(TextToEntries): num_new_embeddings, num_deleted_embeddings = self.update_embeddings( current_entries, DbEntry.EntryType.MARKDOWN, + DbEntry.EntrySource.COMPUTER, "compiled", logger, deletion_file_names, diff --git a/src/khoj/processor/notion/notion_to_entries.py b/src/khoj/processor/notion/notion_to_entries.py index a4b15d4e..7a88e2a1 100644 --- a/src/khoj/processor/notion/notion_to_entries.py +++ b/src/khoj/processor/notion/notion_to_entries.py @@ -250,7 +250,12 @@ class NotionToEntries(TextToEntries): # Identify, mark and merge any new entries with previous entries with timer("Identify new or updated entries", logger): num_new_embeddings, num_deleted_embeddings = self.update_embeddings( - current_entries, DbEntry.EntryType.NOTION, key="compiled", logger=logger, user=user + current_entries, + DbEntry.EntryType.NOTION, + DbEntry.EntrySource.NOTION, + key="compiled", + logger=logger, + user=user, ) return num_new_embeddings, num_deleted_embeddings diff --git a/src/khoj/processor/org_mode/org_to_entries.py b/src/khoj/processor/org_mode/org_to_entries.py index bf6df6dc..04ce97e4 100644 --- a/src/khoj/processor/org_mode/org_to_entries.py +++ b/src/khoj/processor/org_mode/org_to_entries.py @@ -48,6 +48,7 @@ class OrgToEntries(TextToEntries): num_new_embeddings, num_deleted_embeddings = self.update_embeddings( current_entries, DbEntry.EntryType.ORG, + DbEntry.EntrySource.COMPUTER, "compiled", logger, deletion_file_names, diff --git a/src/khoj/processor/pdf/pdf_to_entries.py b/src/khoj/processor/pdf/pdf_to_entries.py index 81c2250f..3a47096a 100644 --- a/src/khoj/processor/pdf/pdf_to_entries.py +++ b/src/khoj/processor/pdf/pdf_to_entries.py @@ -46,6 +46,7 @@ class PdfToEntries(TextToEntries): num_new_embeddings, num_deleted_embeddings = self.update_embeddings( current_entries, DbEntry.EntryType.PDF, + DbEntry.EntrySource.COMPUTER, "compiled", logger, deletion_file_names, diff --git a/src/khoj/processor/plaintext/plaintext_to_entries.py b/src/khoj/processor/plaintext/plaintext_to_entries.py index fd5e1de2..d42dae30 100644 --- a/src/khoj/processor/plaintext/plaintext_to_entries.py +++ b/src/khoj/processor/plaintext/plaintext_to_entries.py @@ -56,6 +56,7 @@ class PlaintextToEntries(TextToEntries): num_new_embeddings, num_deleted_embeddings = self.update_embeddings( current_entries, DbEntry.EntryType.PLAINTEXT, + DbEntry.EntrySource.COMPUTER, key="compiled", logger=logger, deletion_filenames=deletion_file_names, diff --git a/src/khoj/processor/text_to_entries.py b/src/khoj/processor/text_to_entries.py index 501ef5d3..3d79e02e 100644 --- a/src/khoj/processor/text_to_entries.py +++ b/src/khoj/processor/text_to_entries.py @@ -78,6 +78,7 @@ class TextToEntries(ABC): self, current_entries: List[Entry], file_type: str, + file_source: str, key="compiled", logger: logging.Logger = None, deletion_filenames: Set[str] = None, @@ -93,9 +94,9 @@ class TextToEntries(ABC): num_deleted_entries = 0 if regenerate: - with timer("Prepared dataset for regeneration in", logger): + with timer("Cleared existing dataset for regeneration in", logger): logger.debug(f"Deleting all entries for file type {file_type}") - num_deleted_entries = EntryAdapters.delete_all_entries(user, file_type) + num_deleted_entries = EntryAdapters.delete_all_entries_by_type(user, file_type) hashes_to_process = set() with timer("Identified entries to add to database in", logger): @@ -132,6 +133,7 @@ class TextToEntries(ABC): compiled=entry.compiled, heading=entry.heading[:1000], # Truncate to max chars of field allowed file_path=entry.file, + file_source=file_source, file_type=file_type, hashed_value=entry_hash, corpus_id=entry.corpus_id, diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 84e63b09..81e805c6 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -23,7 +23,6 @@ from khoj.utils.rawconfig import ( FullConfig, SearchConfig, SearchResponse, - TextContentConfig, GithubContentConfig, NotionContentConfig, ) @@ -51,6 +50,7 @@ from database.models import ( LocalPdfConfig, LocalPlaintextConfig, KhojUser, + Entry as DbEntry, GithubConfig, NotionConfig, ) @@ -61,11 +61,13 @@ api = APIRouter() logger = logging.getLogger(__name__) -def map_config_to_object(content_type: str): - if content_type == "github": +def map_config_to_object(content_source: str): + if content_source == DbEntry.EntrySource.GITHUB: return GithubConfig - if content_type == "notion": + if content_source == DbEntry.EntrySource.GITHUB: return NotionConfig + if content_source == DbEntry.EntrySource.COMPUTER: + return "Computer" async def map_config_to_db(config: FullConfig, user: KhojUser): @@ -164,7 +166,7 @@ async def set_config_data( return state.config -@api.post("/config/data/content_type/github", status_code=200) +@api.post("/config/data/content-source/github", status_code=200) @requires(["authenticated"]) async def set_content_config_github_data( request: Request, @@ -192,7 +194,7 @@ async def set_content_config_github_data( return {"status": "ok"} -@api.post("/config/data/content_type/notion", status_code=200) +@api.post("/config/data/content-source/notion", status_code=200) @requires(["authenticated"]) async def set_content_config_notion_data( request: Request, @@ -219,11 +221,11 @@ async def set_content_config_notion_data( return {"status": "ok"} -@api.delete("/config/data/content_type/{content_type}", status_code=200) +@api.delete("/config/data/content-source/{content_source}", status_code=200) @requires(["authenticated"]) -async def remove_content_config_data( +async def remove_content_source_data( request: Request, - content_type: str, + content_source: str, client: Optional[str] = None, ): user = request.user.object @@ -233,15 +235,15 @@ async def remove_content_config_data( telemetry_type="api", api="delete_content_config", client=client, - metadata={"content_type": content_type}, + metadata={"content_source": content_source}, ) - content_object = map_config_to_object(content_type) + content_object = map_config_to_object(content_source) if content_object is None: - raise ValueError(f"Invalid content type: {content_type}") - - await content_object.objects.filter(user=user).adelete() - await sync_to_async(EntryAdapters.delete_all_entries)(user, content_type) + raise ValueError(f"Invalid content source: {content_source}") + elif content_object != "Computer": + await content_object.objects.filter(user=user).adelete() + await sync_to_async(EntryAdapters.delete_all_entries)(user, content_source) enabled_content = await sync_to_async(EntryAdapters.get_unique_file_types)(user) return {"status": "ok"} @@ -268,10 +270,11 @@ async def remove_file_data( return {"status": "ok"} -@api.get("/config/data/all", response_model=List[str]) +@api.get("/config/data/{content_source}", response_model=List[str]) @requires(["authenticated"]) async def get_all_filenames( request: Request, + content_source: str, client: Optional[str] = None, ): user = request.user.object @@ -283,27 +286,7 @@ async def get_all_filenames( client=client, ) - return await sync_to_async(list)(EntryAdapters.aget_all_filenames(user)) - - -@api.delete("/config/data/all", status_code=200) -@requires(["authenticated"]) -async def remove_all_config_data( - request: Request, - client: Optional[str] = None, -): - user = request.user.object - - update_telemetry_state( - request=request, - telemetry_type="api", - api="delete_all_config", - client=client, - ) - - await EntryAdapters.adelete_all_entries(user) - - return {"status": "ok"} + return await sync_to_async(list)(EntryAdapters.aget_all_filenames_by_source(user, content_source)) @api.post("/config/data/conversation/model", status_code=200) diff --git a/src/khoj/routers/auth.py b/src/khoj/routers/auth.py index ebabeb8e..4a3cbcef 100644 --- a/src/khoj/routers/auth.py +++ b/src/khoj/routers/auth.py @@ -24,7 +24,9 @@ logger = logging.getLogger(__name__) auth_router = APIRouter() if not state.anonymous_mode and not (os.environ.get("GOOGLE_CLIENT_ID") and os.environ.get("GOOGLE_CLIENT_SECRET")): - logger.info("Please set GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET environment variables to use Google OAuth") + logger.warn( + "🚨 Use --anonymous-mode flag to disable Google OAuth or set GOOGLE_CLIENT_ID, GOOGLE_CLIENT_SECRET environment variables to enable it" + ) else: config = Config(environ=os.environ) diff --git a/src/khoj/routers/indexer.py b/src/khoj/routers/indexer.py index 1bbf53c2..a7a1249d 100644 --- a/src/khoj/routers/indexer.py +++ b/src/khoj/routers/indexer.py @@ -126,7 +126,7 @@ async def update( # Extract required fields from config loop = asyncio.get_event_loop() - state.content_index = await loop.run_in_executor( + state.content_index, success = await loop.run_in_executor( None, configure_content, state.content_index, @@ -138,6 +138,8 @@ async def update( False, user, ) + if not success: + raise RuntimeError("Failed to update content index") logger.info(f"Finished processing batch indexing request") except Exception as e: logger.error(f"Failed to process batch indexing request: {e}", exc_info=True) @@ -145,6 +147,7 @@ async def update( f"🚨 Failed to {force} update {t} content index triggered via API call by {client} client: {e}", exc_info=True, ) + return Response(content="Failed", status_code=500) update_telemetry_state( request=request, @@ -182,18 +185,19 @@ def configure_content( t: Optional[state.SearchType] = None, full_corpus: bool = True, user: KhojUser = None, -) -> Optional[ContentIndex]: +) -> tuple[Optional[ContentIndex], bool]: content_index = ContentIndex() + success = True if t is not None and not t.value in [type.value for type in state.SearchType]: logger.warning(f"🚨 Invalid search type: {t}") - return None + return None, False search_type = t.value if t else None if files is None: logger.warning(f"🚨 No files to process for {search_type} search.") - return None + return None, True try: # Initialize Org Notes Search @@ -209,6 +213,7 @@ def configure_content( ) except Exception as e: logger.error(f"🚨 Failed to setup org: {e}", exc_info=True) + success = False try: # Initialize Markdown Search @@ -225,6 +230,7 @@ def configure_content( except Exception as e: logger.error(f"🚨 Failed to setup markdown: {e}", exc_info=True) + success = False try: # Initialize PDF Search @@ -241,6 +247,7 @@ def configure_content( except Exception as e: logger.error(f"🚨 Failed to setup PDF: {e}", exc_info=True) + success = False try: # Initialize Plaintext Search @@ -257,6 +264,7 @@ def configure_content( except Exception as e: logger.error(f"🚨 Failed to setup plaintext: {e}", exc_info=True) + success = False try: # Initialize Image Search @@ -274,6 +282,7 @@ def configure_content( except Exception as e: logger.error(f"🚨 Failed to setup images: {e}", exc_info=True) + success = False try: github_config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first() @@ -291,6 +300,7 @@ def configure_content( except Exception as e: logger.error(f"🚨 Failed to setup GitHub: {e}", exc_info=True) + success = False try: # Initialize Notion Search @@ -308,12 +318,13 @@ def configure_content( except Exception as e: logger.error(f"🚨 Failed to setup GitHub: {e}", exc_info=True) + success = False # Invalidate Query Cache if user: state.query_cache[user.uuid] = LRU() - return content_index + return content_index, success def load_content( diff --git a/src/khoj/routers/subscription.py b/src/khoj/routers/subscription.py new file mode 100644 index 00000000..3457b671 --- /dev/null +++ b/src/khoj/routers/subscription.py @@ -0,0 +1,106 @@ +# Standard Packages +from datetime import datetime, timezone +import logging +import os + +# External Packages +from asgiref.sync import sync_to_async +from fastapi import APIRouter, Request +from starlette.authentication import requires +import stripe + +# Internal Packages +from database import adapters + + +# Stripe integration for Khoj Cloud Subscription +stripe.api_key = os.getenv("STRIPE_API_KEY") +endpoint_secret = os.getenv("STRIPE_SIGNING_SECRET") +logger = logging.getLogger(__name__) +subscription_router = APIRouter() + + +@subscription_router.post("") +async def subscribe(request: Request): + """Webhook for Stripe to send subscription events to Khoj Cloud""" + event = None + try: + payload = await request.body() + sig_header = request.headers["stripe-signature"] + event = stripe.Webhook.construct_event(payload, sig_header, endpoint_secret) + except ValueError as e: + # Invalid payload + raise e + except stripe.error.SignatureVerificationError as e: + # Invalid signature + raise e + + event_type = event["type"] + if event_type not in { + "invoice.paid", + "customer.subscription.updated", + "customer.subscription.deleted", + "subscription_schedule.canceled", + }: + logger.warn(f"Unhandled Stripe event type: {event['type']}") + return {"success": False} + + # Retrieve the customer's details + subscription = event["data"]["object"] + customer_id = subscription["customer"] + customer = stripe.Customer.retrieve(customer_id) + customer_email = customer["email"] + + # Handle valid stripe webhook events + success = True + if event_type in {"invoice.paid"}: + # Mark the user as subscribed and update the next renewal date on payment + subscription = stripe.Subscription.list(customer=customer_id).data[0] + renewal_date = datetime.fromtimestamp(subscription["current_period_end"], tz=timezone.utc) + user = await adapters.set_user_subscription(customer_email, is_recurring=True, renewal_date=renewal_date) + success = user is not None + elif event_type in {"customer.subscription.updated"}: + user_subscription = await sync_to_async(adapters.get_user_subscription)(customer_email) + # Allow updating subscription status if paid user + if user_subscription and user_subscription.renewal_date: + # Mark user as unsubscribed or resubscribed + is_recurring = not subscription["cancel_at_period_end"] + updated_user = await adapters.set_user_subscription(customer_email, is_recurring=is_recurring) + success = updated_user is not None + elif event_type in {"customer.subscription.deleted"}: + # Reset the user to trial state + user = await adapters.set_user_subscription( + customer_email, is_recurring=False, renewal_date=False, type="trial" + ) + success = user is not None + + logger.info(f'Stripe subscription {event["type"]} for {customer["email"]}') + return {"success": success} + + +@subscription_router.patch("") +@requires(["authenticated"]) +async def update_subscription(request: Request, email: str, operation: str): + # Retrieve the customer's details + customers = stripe.Customer.list(email=email).auto_paging_iter() + customer = next(customers, None) + if customer is None: + return {"success": False, "message": "Customer not found"} + + if operation == "cancel": + customer_id = customer.id + for subscription in stripe.Subscription.list(customer=customer_id): + stripe.Subscription.modify(subscription.id, cancel_at_period_end=True) + return {"success": True} + + elif operation == "resubscribe": + subscriptions = stripe.Subscription.list(customer=customer.id).auto_paging_iter() + # Find the subscription that is set to cancel at the end of the period + for subscription in subscriptions: + if subscription.cancel_at_period_end: + # Update the subscription to not cancel at the end of the period + stripe.Subscription.modify(subscription.id, cancel_at_period_end=False) + return {"success": True} + return {"success": False, "message": "No subscription found that is set to cancel"} + + return {"success": False, "message": "Invalid operation"} diff --git a/src/khoj/routers/web_client.py b/src/khoj/routers/web_client.py index 35603e18..229cee64 100644 --- a/src/khoj/routers/web_client.py +++ b/src/khoj/routers/web_client.py @@ -8,8 +8,9 @@ from fastapi import Request from fastapi.responses import HTMLResponse, FileResponse, RedirectResponse from fastapi.templating import Jinja2Templates from starlette.authentication import requires +from database import adapters +from database.models import KhojUser from khoj.utils.rawconfig import ( - TextContentConfig, GithubContentConfig, GithubRepoConfig, NotionContentConfig, @@ -17,15 +18,18 @@ from khoj.utils.rawconfig import ( # Internal Packages from khoj.utils import constants, state -from database.adapters import EntryAdapters, get_user_github_config, get_user_notion_config, ConversationAdapters -from database.models import LocalOrgConfig, LocalMarkdownConfig, LocalPdfConfig, LocalPlaintextConfig +from database.adapters import ( + EntryAdapters, + get_user_github_config, + get_user_notion_config, + ConversationAdapters, + get_user_subscription_state, +) # Initialize Router web_client = APIRouter() templates = Jinja2Templates(directory=constants.web_directory) -VALID_TEXT_CONTENT_TYPES = ["org", "markdown", "pdf", "plaintext"] - # Create Routes @web_client.get("/", response_class=FileResponse) @@ -109,41 +113,26 @@ def login_page(request: Request): ) -def map_config_to_object(content_type: str): - if content_type == "org": - return LocalOrgConfig - if content_type == "markdown": - return LocalMarkdownConfig - if content_type == "pdf": - return LocalPdfConfig - if content_type == "plaintext": - return LocalPlaintextConfig - - @web_client.get("/config", response_class=HTMLResponse) @requires(["authenticated"], redirect="login_page") def config_page(request: Request): - user = request.user.object + user: KhojUser = request.user.object user_picture = request.session.get("user", {}).get("picture") - enabled_content = set(EntryAdapters.get_unique_file_types(user).all()) + user_subscription = adapters.get_user_subscription(user.email) + user_subscription_state = get_user_subscription_state(user_subscription) + subscription_renewal_date = ( + user_subscription.renewal_date.strftime("%d %b %Y") + if user_subscription and user_subscription.renewal_date + else None + ) + enabled_content_source = set(EntryAdapters.get_unique_file_source(user).all()) successfully_configured = { - "pdf": ("pdf" in enabled_content), - "markdown": ("markdown" in enabled_content), - "org": ("org" in enabled_content), - "image": False, - "github": ("github" in enabled_content), - "notion": ("notion" in enabled_content), - "plaintext": ("plaintext" in enabled_content), + "computer": ("computer" in enabled_content_source), + "github": ("github" in enabled_content_source), + "notion": ("notion" in enabled_content_source), } - if state.content_index: - successfully_configured.update( - { - "image": state.content_index.image is not None, - } - ) - conversation_options = ConversationAdapters.get_conversation_processor_options().all() all_conversation_options = list() for conversation_option in conversation_options: @@ -157,15 +146,19 @@ def config_page(request: Request): "request": request, "current_model_state": successfully_configured, "anonymous_mode": state.anonymous_mode, - "username": user.username if user else None, + "username": user.username, "conversation_options": all_conversation_options, "selected_conversation_config": selected_conversation_config.id if selected_conversation_config else None, "user_photo": user_picture, + "billing_enabled": state.billing_enabled, + "subscription_state": user_subscription_state, + "subscription_renewal_date": subscription_renewal_date, + "khoj_cloud_subscription_url": os.getenv("KHOJ_CLOUD_SUBSCRIPTION_URL"), }, ) -@web_client.get("/config/content_type/github", response_class=HTMLResponse) +@web_client.get("/config/content-source/github", response_class=HTMLResponse) @requires(["authenticated"], redirect="login_page") def github_config_page(request: Request): user = request.user.object @@ -192,7 +185,7 @@ def github_config_page(request: Request): current_config = {} # type: ignore return templates.TemplateResponse( - "content_type_github_input.html", + "content_source_github_input.html", context={ "request": request, "current_config": current_config, @@ -202,7 +195,7 @@ def github_config_page(request: Request): ) -@web_client.get("/config/content_type/notion", response_class=HTMLResponse) +@web_client.get("/config/content-source/notion", response_class=HTMLResponse) @requires(["authenticated"], redirect="login_page") def notion_config_page(request: Request): user = request.user.object @@ -216,7 +209,7 @@ def notion_config_page(request: Request): current_config = json.loads(current_config.json()) return templates.TemplateResponse( - "content_type_notion_input.html", + "content_source_notion_input.html", context={ "request": request, "current_config": current_config, @@ -226,32 +219,16 @@ def notion_config_page(request: Request): ) -@web_client.get("/config/content_type/{content_type}", response_class=HTMLResponse) +@web_client.get("/config/content-source/computer", response_class=HTMLResponse) @requires(["authenticated"], redirect="login_page") -def content_config_page(request: Request, content_type: str): - if content_type not in VALID_TEXT_CONTENT_TYPES: - return templates.TemplateResponse("config.html", context={"request": request}) - - object = map_config_to_object(content_type) +def computer_config_page(request: Request): user = request.user.object user_picture = request.session.get("user", {}).get("picture") - config = object.objects.filter(user=user).first() - if config == None: - config = object.objects.create(user=user) - - current_config = TextContentConfig( - input_files=config.input_files, - input_filter=config.input_filter, - index_heading_entries=config.index_heading_entries, - ) - current_config = json.loads(current_config.json()) return templates.TemplateResponse( - "content_type_input.html", + "content_source_computer_input.html", context={ "request": request, - "current_config": current_config, - "content_type": content_type, "username": user.username, "user_photo": user_picture, }, diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py index 14f5b770..ba2fc9ec 100644 --- a/src/khoj/search_type/text_search.py +++ b/src/khoj/search_type/text_search.py @@ -204,11 +204,12 @@ def setup( files=files, full_corpus=full_corpus, user=user, regenerate=regenerate ) - file_names = [file_name for file_name in files] + if files: + file_names = [file_name for file_name in files] - logger.info( - f"Deleted {num_deleted_embeddings} entries. Created {num_new_embeddings} new entries for user {user} from files {file_names}" - ) + logger.info( + f"Deleted {num_deleted_embeddings} entries. Created {num_new_embeddings} new entries for user {user} from files {file_names}" + ) def cross_encoder_score(query: str, hits: List[SearchResponse]) -> List[SearchResponse]: diff --git a/src/khoj/utils/state.py b/src/khoj/utils/state.py index 748ca15a..098ae35e 100644 --- a/src/khoj/utils/state.py +++ b/src/khoj/utils/state.py @@ -1,4 +1,5 @@ # Standard Packages +import os import threading from typing import List, Dict from collections import defaultdict @@ -35,3 +36,8 @@ khoj_version: str = None device = get_device() chat_on_gpu: bool = True anonymous_mode: bool = False +billing_enabled: bool = ( + os.getenv("STRIPE_API_KEY") is not None + and os.getenv("STRIPE_SIGNING_SECRET") is not None + and os.getenv("KHOJ_CLOUD_SUBSCRIPTION_URL") is not None +) diff --git a/tests/conftest.py b/tests/conftest.py index fbb98476..59104123 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -196,7 +196,7 @@ def chat_client(search_config: SearchConfig, default_user2: KhojUser): # Index Markdown Content for Search all_files = fs_syncer.collect_files(user=default_user2) - state.content_index = configure_content( + state.content_index, _ = configure_content( state.content_index, state.config.content_type, all_files, state.search_models, user=default_user2 ) diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 30499049..fdd29b02 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -64,6 +64,7 @@ def test_encode_docs_memory_leak(): batch_size = 20 embeddings_model = EmbeddingsModel() memory_usage_trend = [] + device = f"{helpers.get_device()}".upper() # Act # Encode random strings repeatedly and record memory usage trend @@ -76,8 +77,9 @@ def test_encode_docs_memory_leak(): # Calculate slope of line fitting memory usage history memory_usage_trend = np.array(memory_usage_trend) slope, _, _, _, _ = linregress(np.arange(len(memory_usage_trend)), memory_usage_trend) + print(f"Memory usage increased at ~{slope:.2f} MB per iteration on {device}") # Assert # If slope is positive memory utilization is increasing # Positive threshold of 2, from observing memory usage trend on MPS vs CPU device - assert slope < 2, f"Memory usage increasing at ~{slope:.2f} MB per iteration" + assert slope < 2, f"Memory leak suspected on {device}. Memory usage increased at ~{slope:.2f} MB per iteration" diff --git a/tests/test_text_search.py b/tests/test_text_search.py index 7d8c30fb..3d729ab5 100644 --- a/tests/test_text_search.py +++ b/tests/test_text_search.py @@ -58,7 +58,7 @@ def test_get_org_files_with_org_suffixed_dir_doesnt_raise_error(tmp_path, defaul # ---------------------------------------------------------------------------------------------------- @pytest.mark.django_db -def test_text_search_setup_with_empty_file_raises_error( +def test_text_search_setup_with_empty_file_creates_no_entries( org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog ): # Arrange