@@ -176,8 +265,9 @@
})
};
- function clearContentType(content_type) {
- fetch('/api/config/data/content_type/' + content_type, {
+ function clearContentType(content_source) {
+
+ fetch('/api/config/data/content-source/' + content_source, {
method: 'DELETE',
headers: {
'Content-Type': 'application/json',
@@ -186,22 +276,54 @@
.then(response => response.json())
.then(data => {
if (data.status == "ok") {
- var contentTypeClearButton = document.getElementById("clear-" + content_type);
- contentTypeClearButton.style.display = "none";
-
- var configuredIcon = document.getElementById("configured-icon-" + content_type);
- if (configuredIcon) {
- configuredIcon.style.display = "none";
- }
-
- var misconfiguredIcon = document.getElementById("misconfigured-icon-" + content_type);
- if (misconfiguredIcon) {
- misconfiguredIcon.style.display = "none";
- }
+ document.getElementById("configured-icon-" + content_source).style.display = "none";
+ document.getElementById("clear-" + content_source).style.display = "none";
+ } else {
+ document.getElementById("configured-icon-" + content_source).style.display = "";
+ document.getElementById("clear-" + content_source).style.display = "";
}
})
};
+ function unsubscribe() {
+ fetch('/api/subscription?operation=cancel&email={{username}}', {
+ method: 'PATCH',
+ headers: {
+ 'Content-Type': 'application/json',
+ },
+ })
+ .then(response => response.json())
+ .then(data => {
+ if (data.success) {
+ document.getElementById("unsubscribe-description").style.display = "none";
+ document.getElementById("unsubscribe-button").style.display = "none";
+
+ document.getElementById("resubscribe-description").style.display = "";
+ document.getElementById("resubscribe-button").style.display = "";
+
+ }
+ })
+ }
+
+ function resubscribe() {
+ fetch('/api/subscription?operation=resubscribe&email={{username}}', {
+ method: 'PATCH',
+ headers: {
+ 'Content-Type': 'application/json',
+ },
+ })
+ .then(response => response.json())
+ .then(data => {
+ if (data.success) {
+ document.getElementById("resubscribe-description").style.display = "none";
+ document.getElementById("resubscribe-button").style.display = "none";
+
+ document.getElementById("unsubscribe-description").style.display = "";
+ document.getElementById("unsubscribe-button").style.display = "";
+ }
+ })
+ }
+
var configure = document.getElementById("configure");
configure.addEventListener("click", function(event) {
event.preventDefault();
@@ -243,6 +365,7 @@
if (data.detail != null) {
throw new Error(data.detail);
}
+
document.getElementById("status").innerHTML = emoji + " " + successText;
document.getElementById("status").style.display = "block";
button.disabled = false;
@@ -255,6 +378,26 @@
button.disabled = false;
button.innerHTML = '⚠️ Unsuccessful';
});
+
+ content_sources = ["computer", "github", "notion"];
+ content_sources.forEach(content_source => {
+ fetch(`/api/config/data/${content_source}`, {
+ method: 'GET',
+ headers: {
+ 'Content-Type': 'application/json',
+ }
+ })
+ .then(response => response.json())
+ .then(data => {
+ if (data.length > 0) {
+ document.getElementById("configured-icon-" + content_source).style.display = "";
+ document.getElementById("clear-" + content_source).style.display = "";
+ } else {
+ document.getElementById("configured-icon-" + content_source).style.display = "none";
+ document.getElementById("clear-" + content_source).style.display = "none";
+ }
+ });
+ });
}
// Setup the results count slider
@@ -362,70 +505,5 @@
}
})
}
-
- // Get all currently indexed files
- function getAllFilenames() {
- fetch('/api/config/data/all')
- .then(response => response.json())
- .then(data => {
- var indexedFiles = document.getElementsByClassName("indexed-files")[0];
- indexedFiles.innerHTML = "";
-
- if (data.length == 0) {
- document.getElementById("delete-all-files").style.display = "none";
- indexedFiles.innerHTML = "
";
- } else {
- document.getElementById("delete-all-files").style.display = "block";
- }
-
- for (var filename of data) {
- let fileElement = document.createElement("div");
- fileElement.classList.add("file-element");
-
- let fileNameElement = document.createElement("div");
- fileNameElement.classList.add("content-name");
- fileNameElement.innerHTML = filename;
- fileElement.appendChild(fileNameElement);
-
- let buttonContainer = document.createElement("div");
- buttonContainer.classList.add("remove-button-container");
- let removeFileButton = document.createElement("button");
- removeFileButton.classList.add("remove-file-button");
- removeFileButton.innerHTML = "🗑️";
- removeFileButton.addEventListener("click", ((filename) => {
- return () => {
- removeFile(filename);
- };
- })(filename));
- buttonContainer.appendChild(removeFileButton);
- fileElement.appendChild(buttonContainer);
- indexedFiles.appendChild(fileElement);
- }
- })
- .catch((error) => {
- console.error('Error:', error);
- });
- }
-
- // Get all currently indexed files on page load
- getAllFilenames();
-
- let deleteAllFilesButton = document.getElementById("delete-all-files");
- deleteAllFilesButton.addEventListener("click", function(event) {
- event.preventDefault();
- fetch('/api/config/data/all', {
- method: 'DELETE',
- headers: {
- 'Content-Type': 'application/json',
- }
- })
- .then(response => response.json())
- .then(data => {
- if (data.status == "ok") {
- getAllFilenames();
- }
- })
- });
-
{% endblock %}
diff --git a/src/khoj/interface/web/content_source_computer_input.html b/src/khoj/interface/web/content_source_computer_input.html
new file mode 100644
index 00000000..aba3d8ee
--- /dev/null
+++ b/src/khoj/interface/web/content_source_computer_input.html
@@ -0,0 +1,129 @@
+{% extends "base_config.html" %}
+{% block content %}
+
+
+
+
+ Files
+
+
Manage files from your computer
+
Download the Khoj Desktop app to sync files from your computer
+
+
+
+
+
+
+
+
+
+
+
+
+
+{% endblock %}
diff --git a/src/khoj/interface/web/content_type_github_input.html b/src/khoj/interface/web/content_source_github_input.html
similarity index 99%
rename from src/khoj/interface/web/content_type_github_input.html
rename to src/khoj/interface/web/content_source_github_input.html
index 0e41645a..ff82b1f2 100644
--- a/src/khoj/interface/web/content_type_github_input.html
+++ b/src/khoj/interface/web/content_source_github_input.html
@@ -125,7 +125,7 @@
}
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
- fetch('/api/config/data/content_type/github', {
+ fetch('/api/config/data/content-source/github', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
diff --git a/src/khoj/interface/web/content_type_notion_input.html b/src/khoj/interface/web/content_source_notion_input.html
similarity index 97%
rename from src/khoj/interface/web/content_type_notion_input.html
rename to src/khoj/interface/web/content_source_notion_input.html
index 965c1ef5..18eb5a7f 100644
--- a/src/khoj/interface/web/content_type_notion_input.html
+++ b/src/khoj/interface/web/content_source_notion_input.html
@@ -42,7 +42,7 @@
}
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
- fetch('/api/config/data/content_type/notion', {
+ fetch('/api/config/data/content-source/notion', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
diff --git a/src/khoj/interface/web/content_type_input.html b/src/khoj/interface/web/content_type_input.html
deleted file mode 100644
index f8751ddc..00000000
--- a/src/khoj/interface/web/content_type_input.html
+++ /dev/null
@@ -1,159 +0,0 @@
-{% extends "base_config.html" %}
-{% block content %}
-
-
-
-
- {{ content_type|capitalize }}
-
-
-
-
-
-{% endblock %}
diff --git a/src/khoj/processor/github/github_to_entries.py b/src/khoj/processor/github/github_to_entries.py
index 14e9b696..56279453 100644
--- a/src/khoj/processor/github/github_to_entries.py
+++ b/src/khoj/processor/github/github_to_entries.py
@@ -104,7 +104,12 @@ class GithubToEntries(TextToEntries):
# Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
- current_entries, DbEntry.EntryType.GITHUB, key="compiled", logger=logger, user=user
+ current_entries,
+ DbEntry.EntryType.GITHUB,
+ DbEntry.EntrySource.GITHUB,
+ key="compiled",
+ logger=logger,
+ user=user,
)
return num_new_embeddings, num_deleted_embeddings
diff --git a/src/khoj/processor/markdown/markdown_to_entries.py b/src/khoj/processor/markdown/markdown_to_entries.py
index e0b76368..0dd71740 100644
--- a/src/khoj/processor/markdown/markdown_to_entries.py
+++ b/src/khoj/processor/markdown/markdown_to_entries.py
@@ -47,6 +47,7 @@ class MarkdownToEntries(TextToEntries):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
current_entries,
DbEntry.EntryType.MARKDOWN,
+ DbEntry.EntrySource.COMPUTER,
"compiled",
logger,
deletion_file_names,
diff --git a/src/khoj/processor/notion/notion_to_entries.py b/src/khoj/processor/notion/notion_to_entries.py
index a4b15d4e..7a88e2a1 100644
--- a/src/khoj/processor/notion/notion_to_entries.py
+++ b/src/khoj/processor/notion/notion_to_entries.py
@@ -250,7 +250,12 @@ class NotionToEntries(TextToEntries):
# Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
- current_entries, DbEntry.EntryType.NOTION, key="compiled", logger=logger, user=user
+ current_entries,
+ DbEntry.EntryType.NOTION,
+ DbEntry.EntrySource.NOTION,
+ key="compiled",
+ logger=logger,
+ user=user,
)
return num_new_embeddings, num_deleted_embeddings
diff --git a/src/khoj/processor/org_mode/org_to_entries.py b/src/khoj/processor/org_mode/org_to_entries.py
index bf6df6dc..04ce97e4 100644
--- a/src/khoj/processor/org_mode/org_to_entries.py
+++ b/src/khoj/processor/org_mode/org_to_entries.py
@@ -48,6 +48,7 @@ class OrgToEntries(TextToEntries):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
current_entries,
DbEntry.EntryType.ORG,
+ DbEntry.EntrySource.COMPUTER,
"compiled",
logger,
deletion_file_names,
diff --git a/src/khoj/processor/pdf/pdf_to_entries.py b/src/khoj/processor/pdf/pdf_to_entries.py
index 81c2250f..3a47096a 100644
--- a/src/khoj/processor/pdf/pdf_to_entries.py
+++ b/src/khoj/processor/pdf/pdf_to_entries.py
@@ -46,6 +46,7 @@ class PdfToEntries(TextToEntries):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
current_entries,
DbEntry.EntryType.PDF,
+ DbEntry.EntrySource.COMPUTER,
"compiled",
logger,
deletion_file_names,
diff --git a/src/khoj/processor/plaintext/plaintext_to_entries.py b/src/khoj/processor/plaintext/plaintext_to_entries.py
index fd5e1de2..d42dae30 100644
--- a/src/khoj/processor/plaintext/plaintext_to_entries.py
+++ b/src/khoj/processor/plaintext/plaintext_to_entries.py
@@ -56,6 +56,7 @@ class PlaintextToEntries(TextToEntries):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
current_entries,
DbEntry.EntryType.PLAINTEXT,
+ DbEntry.EntrySource.COMPUTER,
key="compiled",
logger=logger,
deletion_filenames=deletion_file_names,
diff --git a/src/khoj/processor/text_to_entries.py b/src/khoj/processor/text_to_entries.py
index 501ef5d3..3d79e02e 100644
--- a/src/khoj/processor/text_to_entries.py
+++ b/src/khoj/processor/text_to_entries.py
@@ -78,6 +78,7 @@ class TextToEntries(ABC):
self,
current_entries: List[Entry],
file_type: str,
+ file_source: str,
key="compiled",
logger: logging.Logger = None,
deletion_filenames: Set[str] = None,
@@ -93,9 +94,9 @@ class TextToEntries(ABC):
num_deleted_entries = 0
if regenerate:
- with timer("Prepared dataset for regeneration in", logger):
+ with timer("Cleared existing dataset for regeneration in", logger):
logger.debug(f"Deleting all entries for file type {file_type}")
- num_deleted_entries = EntryAdapters.delete_all_entries(user, file_type)
+ num_deleted_entries = EntryAdapters.delete_all_entries_by_type(user, file_type)
hashes_to_process = set()
with timer("Identified entries to add to database in", logger):
@@ -132,6 +133,7 @@ class TextToEntries(ABC):
compiled=entry.compiled,
heading=entry.heading[:1000], # Truncate to max chars of field allowed
file_path=entry.file,
+ file_source=file_source,
file_type=file_type,
hashed_value=entry_hash,
corpus_id=entry.corpus_id,
diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index 84e63b09..81e805c6 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -23,7 +23,6 @@ from khoj.utils.rawconfig import (
FullConfig,
SearchConfig,
SearchResponse,
- TextContentConfig,
GithubContentConfig,
NotionContentConfig,
)
@@ -51,6 +50,7 @@ from database.models import (
LocalPdfConfig,
LocalPlaintextConfig,
KhojUser,
+ Entry as DbEntry,
GithubConfig,
NotionConfig,
)
@@ -61,11 +61,13 @@ api = APIRouter()
logger = logging.getLogger(__name__)
-def map_config_to_object(content_type: str):
- if content_type == "github":
+def map_config_to_object(content_source: str):
+ if content_source == DbEntry.EntrySource.GITHUB:
return GithubConfig
- if content_type == "notion":
+ if content_source == DbEntry.EntrySource.GITHUB:
return NotionConfig
+ if content_source == DbEntry.EntrySource.COMPUTER:
+ return "Computer"
async def map_config_to_db(config: FullConfig, user: KhojUser):
@@ -164,7 +166,7 @@ async def set_config_data(
return state.config
-@api.post("/config/data/content_type/github", status_code=200)
+@api.post("/config/data/content-source/github", status_code=200)
@requires(["authenticated"])
async def set_content_config_github_data(
request: Request,
@@ -192,7 +194,7 @@ async def set_content_config_github_data(
return {"status": "ok"}
-@api.post("/config/data/content_type/notion", status_code=200)
+@api.post("/config/data/content-source/notion", status_code=200)
@requires(["authenticated"])
async def set_content_config_notion_data(
request: Request,
@@ -219,11 +221,11 @@ async def set_content_config_notion_data(
return {"status": "ok"}
-@api.delete("/config/data/content_type/{content_type}", status_code=200)
+@api.delete("/config/data/content-source/{content_source}", status_code=200)
@requires(["authenticated"])
-async def remove_content_config_data(
+async def remove_content_source_data(
request: Request,
- content_type: str,
+ content_source: str,
client: Optional[str] = None,
):
user = request.user.object
@@ -233,15 +235,15 @@ async def remove_content_config_data(
telemetry_type="api",
api="delete_content_config",
client=client,
- metadata={"content_type": content_type},
+ metadata={"content_source": content_source},
)
- content_object = map_config_to_object(content_type)
+ content_object = map_config_to_object(content_source)
if content_object is None:
- raise ValueError(f"Invalid content type: {content_type}")
-
- await content_object.objects.filter(user=user).adelete()
- await sync_to_async(EntryAdapters.delete_all_entries)(user, content_type)
+ raise ValueError(f"Invalid content source: {content_source}")
+ elif content_object != "Computer":
+ await content_object.objects.filter(user=user).adelete()
+ await sync_to_async(EntryAdapters.delete_all_entries)(user, content_source)
enabled_content = await sync_to_async(EntryAdapters.get_unique_file_types)(user)
return {"status": "ok"}
@@ -268,10 +270,11 @@ async def remove_file_data(
return {"status": "ok"}
-@api.get("/config/data/all", response_model=List[str])
+@api.get("/config/data/{content_source}", response_model=List[str])
@requires(["authenticated"])
async def get_all_filenames(
request: Request,
+ content_source: str,
client: Optional[str] = None,
):
user = request.user.object
@@ -283,27 +286,7 @@ async def get_all_filenames(
client=client,
)
- return await sync_to_async(list)(EntryAdapters.aget_all_filenames(user))
-
-
-@api.delete("/config/data/all", status_code=200)
-@requires(["authenticated"])
-async def remove_all_config_data(
- request: Request,
- client: Optional[str] = None,
-):
- user = request.user.object
-
- update_telemetry_state(
- request=request,
- telemetry_type="api",
- api="delete_all_config",
- client=client,
- )
-
- await EntryAdapters.adelete_all_entries(user)
-
- return {"status": "ok"}
+ return await sync_to_async(list)(EntryAdapters.aget_all_filenames_by_source(user, content_source))
@api.post("/config/data/conversation/model", status_code=200)
diff --git a/src/khoj/routers/auth.py b/src/khoj/routers/auth.py
index ebabeb8e..4a3cbcef 100644
--- a/src/khoj/routers/auth.py
+++ b/src/khoj/routers/auth.py
@@ -24,7 +24,9 @@ logger = logging.getLogger(__name__)
auth_router = APIRouter()
if not state.anonymous_mode and not (os.environ.get("GOOGLE_CLIENT_ID") and os.environ.get("GOOGLE_CLIENT_SECRET")):
- logger.info("Please set GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET environment variables to use Google OAuth")
+ logger.warn(
+ "🚨 Use --anonymous-mode flag to disable Google OAuth or set GOOGLE_CLIENT_ID, GOOGLE_CLIENT_SECRET environment variables to enable it"
+ )
else:
config = Config(environ=os.environ)
diff --git a/src/khoj/routers/indexer.py b/src/khoj/routers/indexer.py
index 1bbf53c2..a7a1249d 100644
--- a/src/khoj/routers/indexer.py
+++ b/src/khoj/routers/indexer.py
@@ -126,7 +126,7 @@ async def update(
# Extract required fields from config
loop = asyncio.get_event_loop()
- state.content_index = await loop.run_in_executor(
+ state.content_index, success = await loop.run_in_executor(
None,
configure_content,
state.content_index,
@@ -138,6 +138,8 @@ async def update(
False,
user,
)
+ if not success:
+ raise RuntimeError("Failed to update content index")
logger.info(f"Finished processing batch indexing request")
except Exception as e:
logger.error(f"Failed to process batch indexing request: {e}", exc_info=True)
@@ -145,6 +147,7 @@ async def update(
f"🚨 Failed to {force} update {t} content index triggered via API call by {client} client: {e}",
exc_info=True,
)
+ return Response(content="Failed", status_code=500)
update_telemetry_state(
request=request,
@@ -182,18 +185,19 @@ def configure_content(
t: Optional[state.SearchType] = None,
full_corpus: bool = True,
user: KhojUser = None,
-) -> Optional[ContentIndex]:
+) -> tuple[Optional[ContentIndex], bool]:
content_index = ContentIndex()
+ success = True
if t is not None and not t.value in [type.value for type in state.SearchType]:
logger.warning(f"🚨 Invalid search type: {t}")
- return None
+ return None, False
search_type = t.value if t else None
if files is None:
logger.warning(f"🚨 No files to process for {search_type} search.")
- return None
+ return None, True
try:
# Initialize Org Notes Search
@@ -209,6 +213,7 @@ def configure_content(
)
except Exception as e:
logger.error(f"🚨 Failed to setup org: {e}", exc_info=True)
+ success = False
try:
# Initialize Markdown Search
@@ -225,6 +230,7 @@ def configure_content(
except Exception as e:
logger.error(f"🚨 Failed to setup markdown: {e}", exc_info=True)
+ success = False
try:
# Initialize PDF Search
@@ -241,6 +247,7 @@ def configure_content(
except Exception as e:
logger.error(f"🚨 Failed to setup PDF: {e}", exc_info=True)
+ success = False
try:
# Initialize Plaintext Search
@@ -257,6 +264,7 @@ def configure_content(
except Exception as e:
logger.error(f"🚨 Failed to setup plaintext: {e}", exc_info=True)
+ success = False
try:
# Initialize Image Search
@@ -274,6 +282,7 @@ def configure_content(
except Exception as e:
logger.error(f"🚨 Failed to setup images: {e}", exc_info=True)
+ success = False
try:
github_config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first()
@@ -291,6 +300,7 @@ def configure_content(
except Exception as e:
logger.error(f"🚨 Failed to setup GitHub: {e}", exc_info=True)
+ success = False
try:
# Initialize Notion Search
@@ -308,12 +318,13 @@ def configure_content(
except Exception as e:
logger.error(f"🚨 Failed to setup GitHub: {e}", exc_info=True)
+ success = False
# Invalidate Query Cache
if user:
state.query_cache[user.uuid] = LRU()
- return content_index
+ return content_index, success
def load_content(
diff --git a/src/khoj/routers/subscription.py b/src/khoj/routers/subscription.py
new file mode 100644
index 00000000..3457b671
--- /dev/null
+++ b/src/khoj/routers/subscription.py
@@ -0,0 +1,106 @@
+# Standard Packages
+from datetime import datetime, timezone
+import logging
+import os
+
+# External Packages
+from asgiref.sync import sync_to_async
+from fastapi import APIRouter, Request
+from starlette.authentication import requires
+import stripe
+
+# Internal Packages
+from database import adapters
+
+
+# Stripe integration for Khoj Cloud Subscription
+stripe.api_key = os.getenv("STRIPE_API_KEY")
+endpoint_secret = os.getenv("STRIPE_SIGNING_SECRET")
+logger = logging.getLogger(__name__)
+subscription_router = APIRouter()
+
+
+@subscription_router.post("")
+async def subscribe(request: Request):
+ """Webhook for Stripe to send subscription events to Khoj Cloud"""
+ event = None
+ try:
+ payload = await request.body()
+ sig_header = request.headers["stripe-signature"]
+ event = stripe.Webhook.construct_event(payload, sig_header, endpoint_secret)
+ except ValueError as e:
+ # Invalid payload
+ raise e
+ except stripe.error.SignatureVerificationError as e:
+ # Invalid signature
+ raise e
+
+ event_type = event["type"]
+ if event_type not in {
+ "invoice.paid",
+ "customer.subscription.updated",
+ "customer.subscription.deleted",
+ "subscription_schedule.canceled",
+ }:
+ logger.warn(f"Unhandled Stripe event type: {event['type']}")
+ return {"success": False}
+
+ # Retrieve the customer's details
+ subscription = event["data"]["object"]
+ customer_id = subscription["customer"]
+ customer = stripe.Customer.retrieve(customer_id)
+ customer_email = customer["email"]
+
+ # Handle valid stripe webhook events
+ success = True
+ if event_type in {"invoice.paid"}:
+ # Mark the user as subscribed and update the next renewal date on payment
+ subscription = stripe.Subscription.list(customer=customer_id).data[0]
+ renewal_date = datetime.fromtimestamp(subscription["current_period_end"], tz=timezone.utc)
+ user = await adapters.set_user_subscription(customer_email, is_recurring=True, renewal_date=renewal_date)
+ success = user is not None
+ elif event_type in {"customer.subscription.updated"}:
+ user_subscription = await sync_to_async(adapters.get_user_subscription)(customer_email)
+ # Allow updating subscription status if paid user
+ if user_subscription and user_subscription.renewal_date:
+ # Mark user as unsubscribed or resubscribed
+ is_recurring = not subscription["cancel_at_period_end"]
+ updated_user = await adapters.set_user_subscription(customer_email, is_recurring=is_recurring)
+ success = updated_user is not None
+ elif event_type in {"customer.subscription.deleted"}:
+ # Reset the user to trial state
+ user = await adapters.set_user_subscription(
+ customer_email, is_recurring=False, renewal_date=False, type="trial"
+ )
+ success = user is not None
+
+ logger.info(f'Stripe subscription {event["type"]} for {customer["email"]}')
+ return {"success": success}
+
+
+@subscription_router.patch("")
+@requires(["authenticated"])
+async def update_subscription(request: Request, email: str, operation: str):
+ # Retrieve the customer's details
+ customers = stripe.Customer.list(email=email).auto_paging_iter()
+ customer = next(customers, None)
+ if customer is None:
+ return {"success": False, "message": "Customer not found"}
+
+ if operation == "cancel":
+ customer_id = customer.id
+ for subscription in stripe.Subscription.list(customer=customer_id):
+ stripe.Subscription.modify(subscription.id, cancel_at_period_end=True)
+ return {"success": True}
+
+ elif operation == "resubscribe":
+ subscriptions = stripe.Subscription.list(customer=customer.id).auto_paging_iter()
+ # Find the subscription that is set to cancel at the end of the period
+ for subscription in subscriptions:
+ if subscription.cancel_at_period_end:
+ # Update the subscription to not cancel at the end of the period
+ stripe.Subscription.modify(subscription.id, cancel_at_period_end=False)
+ return {"success": True}
+ return {"success": False, "message": "No subscription found that is set to cancel"}
+
+ return {"success": False, "message": "Invalid operation"}
diff --git a/src/khoj/routers/web_client.py b/src/khoj/routers/web_client.py
index 35603e18..229cee64 100644
--- a/src/khoj/routers/web_client.py
+++ b/src/khoj/routers/web_client.py
@@ -8,8 +8,9 @@ from fastapi import Request
from fastapi.responses import HTMLResponse, FileResponse, RedirectResponse
from fastapi.templating import Jinja2Templates
from starlette.authentication import requires
+from database import adapters
+from database.models import KhojUser
from khoj.utils.rawconfig import (
- TextContentConfig,
GithubContentConfig,
GithubRepoConfig,
NotionContentConfig,
@@ -17,15 +18,18 @@ from khoj.utils.rawconfig import (
# Internal Packages
from khoj.utils import constants, state
-from database.adapters import EntryAdapters, get_user_github_config, get_user_notion_config, ConversationAdapters
-from database.models import LocalOrgConfig, LocalMarkdownConfig, LocalPdfConfig, LocalPlaintextConfig
+from database.adapters import (
+ EntryAdapters,
+ get_user_github_config,
+ get_user_notion_config,
+ ConversationAdapters,
+ get_user_subscription_state,
+)
# Initialize Router
web_client = APIRouter()
templates = Jinja2Templates(directory=constants.web_directory)
-VALID_TEXT_CONTENT_TYPES = ["org", "markdown", "pdf", "plaintext"]
-
# Create Routes
@web_client.get("/", response_class=FileResponse)
@@ -109,41 +113,26 @@ def login_page(request: Request):
)
-def map_config_to_object(content_type: str):
- if content_type == "org":
- return LocalOrgConfig
- if content_type == "markdown":
- return LocalMarkdownConfig
- if content_type == "pdf":
- return LocalPdfConfig
- if content_type == "plaintext":
- return LocalPlaintextConfig
-
-
@web_client.get("/config", response_class=HTMLResponse)
@requires(["authenticated"], redirect="login_page")
def config_page(request: Request):
- user = request.user.object
+ user: KhojUser = request.user.object
user_picture = request.session.get("user", {}).get("picture")
- enabled_content = set(EntryAdapters.get_unique_file_types(user).all())
+ user_subscription = adapters.get_user_subscription(user.email)
+ user_subscription_state = get_user_subscription_state(user_subscription)
+ subscription_renewal_date = (
+ user_subscription.renewal_date.strftime("%d %b %Y")
+ if user_subscription and user_subscription.renewal_date
+ else None
+ )
+ enabled_content_source = set(EntryAdapters.get_unique_file_source(user).all())
successfully_configured = {
- "pdf": ("pdf" in enabled_content),
- "markdown": ("markdown" in enabled_content),
- "org": ("org" in enabled_content),
- "image": False,
- "github": ("github" in enabled_content),
- "notion": ("notion" in enabled_content),
- "plaintext": ("plaintext" in enabled_content),
+ "computer": ("computer" in enabled_content_source),
+ "github": ("github" in enabled_content_source),
+ "notion": ("notion" in enabled_content_source),
}
- if state.content_index:
- successfully_configured.update(
- {
- "image": state.content_index.image is not None,
- }
- )
-
conversation_options = ConversationAdapters.get_conversation_processor_options().all()
all_conversation_options = list()
for conversation_option in conversation_options:
@@ -157,15 +146,19 @@ def config_page(request: Request):
"request": request,
"current_model_state": successfully_configured,
"anonymous_mode": state.anonymous_mode,
- "username": user.username if user else None,
+ "username": user.username,
"conversation_options": all_conversation_options,
"selected_conversation_config": selected_conversation_config.id if selected_conversation_config else None,
"user_photo": user_picture,
+ "billing_enabled": state.billing_enabled,
+ "subscription_state": user_subscription_state,
+ "subscription_renewal_date": subscription_renewal_date,
+ "khoj_cloud_subscription_url": os.getenv("KHOJ_CLOUD_SUBSCRIPTION_URL"),
},
)
-@web_client.get("/config/content_type/github", response_class=HTMLResponse)
+@web_client.get("/config/content-source/github", response_class=HTMLResponse)
@requires(["authenticated"], redirect="login_page")
def github_config_page(request: Request):
user = request.user.object
@@ -192,7 +185,7 @@ def github_config_page(request: Request):
current_config = {} # type: ignore
return templates.TemplateResponse(
- "content_type_github_input.html",
+ "content_source_github_input.html",
context={
"request": request,
"current_config": current_config,
@@ -202,7 +195,7 @@ def github_config_page(request: Request):
)
-@web_client.get("/config/content_type/notion", response_class=HTMLResponse)
+@web_client.get("/config/content-source/notion", response_class=HTMLResponse)
@requires(["authenticated"], redirect="login_page")
def notion_config_page(request: Request):
user = request.user.object
@@ -216,7 +209,7 @@ def notion_config_page(request: Request):
current_config = json.loads(current_config.json())
return templates.TemplateResponse(
- "content_type_notion_input.html",
+ "content_source_notion_input.html",
context={
"request": request,
"current_config": current_config,
@@ -226,32 +219,16 @@ def notion_config_page(request: Request):
)
-@web_client.get("/config/content_type/{content_type}", response_class=HTMLResponse)
+@web_client.get("/config/content-source/computer", response_class=HTMLResponse)
@requires(["authenticated"], redirect="login_page")
-def content_config_page(request: Request, content_type: str):
- if content_type not in VALID_TEXT_CONTENT_TYPES:
- return templates.TemplateResponse("config.html", context={"request": request})
-
- object = map_config_to_object(content_type)
+def computer_config_page(request: Request):
user = request.user.object
user_picture = request.session.get("user", {}).get("picture")
- config = object.objects.filter(user=user).first()
- if config == None:
- config = object.objects.create(user=user)
-
- current_config = TextContentConfig(
- input_files=config.input_files,
- input_filter=config.input_filter,
- index_heading_entries=config.index_heading_entries,
- )
- current_config = json.loads(current_config.json())
return templates.TemplateResponse(
- "content_type_input.html",
+ "content_source_computer_input.html",
context={
"request": request,
- "current_config": current_config,
- "content_type": content_type,
"username": user.username,
"user_photo": user_picture,
},
diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py
index 14f5b770..ba2fc9ec 100644
--- a/src/khoj/search_type/text_search.py
+++ b/src/khoj/search_type/text_search.py
@@ -204,11 +204,12 @@ def setup(
files=files, full_corpus=full_corpus, user=user, regenerate=regenerate
)
- file_names = [file_name for file_name in files]
+ if files:
+ file_names = [file_name for file_name in files]
- logger.info(
- f"Deleted {num_deleted_embeddings} entries. Created {num_new_embeddings} new entries for user {user} from files {file_names}"
- )
+ logger.info(
+ f"Deleted {num_deleted_embeddings} entries. Created {num_new_embeddings} new entries for user {user} from files {file_names}"
+ )
def cross_encoder_score(query: str, hits: List[SearchResponse]) -> List[SearchResponse]:
diff --git a/src/khoj/utils/state.py b/src/khoj/utils/state.py
index 748ca15a..098ae35e 100644
--- a/src/khoj/utils/state.py
+++ b/src/khoj/utils/state.py
@@ -1,4 +1,5 @@
# Standard Packages
+import os
import threading
from typing import List, Dict
from collections import defaultdict
@@ -35,3 +36,8 @@ khoj_version: str = None
device = get_device()
chat_on_gpu: bool = True
anonymous_mode: bool = False
+billing_enabled: bool = (
+ os.getenv("STRIPE_API_KEY") is not None
+ and os.getenv("STRIPE_SIGNING_SECRET") is not None
+ and os.getenv("KHOJ_CLOUD_SUBSCRIPTION_URL") is not None
+)
diff --git a/tests/conftest.py b/tests/conftest.py
index fbb98476..59104123 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -196,7 +196,7 @@ def chat_client(search_config: SearchConfig, default_user2: KhojUser):
# Index Markdown Content for Search
all_files = fs_syncer.collect_files(user=default_user2)
- state.content_index = configure_content(
+ state.content_index, _ = configure_content(
state.content_index, state.config.content_type, all_files, state.search_models, user=default_user2
)
diff --git a/tests/test_helpers.py b/tests/test_helpers.py
index 30499049..fdd29b02 100644
--- a/tests/test_helpers.py
+++ b/tests/test_helpers.py
@@ -64,6 +64,7 @@ def test_encode_docs_memory_leak():
batch_size = 20
embeddings_model = EmbeddingsModel()
memory_usage_trend = []
+ device = f"{helpers.get_device()}".upper()
# Act
# Encode random strings repeatedly and record memory usage trend
@@ -76,8 +77,9 @@ def test_encode_docs_memory_leak():
# Calculate slope of line fitting memory usage history
memory_usage_trend = np.array(memory_usage_trend)
slope, _, _, _, _ = linregress(np.arange(len(memory_usage_trend)), memory_usage_trend)
+ print(f"Memory usage increased at ~{slope:.2f} MB per iteration on {device}")
# Assert
# If slope is positive memory utilization is increasing
# Positive threshold of 2, from observing memory usage trend on MPS vs CPU device
- assert slope < 2, f"Memory usage increasing at ~{slope:.2f} MB per iteration"
+ assert slope < 2, f"Memory leak suspected on {device}. Memory usage increased at ~{slope:.2f} MB per iteration"
diff --git a/tests/test_text_search.py b/tests/test_text_search.py
index 7d8c30fb..3d729ab5 100644
--- a/tests/test_text_search.py
+++ b/tests/test_text_search.py
@@ -58,7 +58,7 @@ def test_get_org_files_with_org_suffixed_dir_doesnt_raise_error(tmp_path, defaul
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
-def test_text_search_setup_with_empty_file_raises_error(
+def test_text_search_setup_with_empty_file_creates_no_entries(
org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog
):
# Arrange