From 54a387326cef113f407e7678dec5be30cc2bb4bf Mon Sep 17 00:00:00 2001 From: sabaimran <65192171+sabaimran@users.noreply.github.com> Date: Tue, 31 Oct 2023 17:59:53 -0700 Subject: [PATCH] [Multi-User Part 6]: Address small bugs and upstream PR comments (#518) - 08654163cb227edb8991ad7f77c99b560819f4f9: Add better parsing for XML files - f3acfac7fbec0a3876e7586607c376c9dfac6a4d: Add a try/catch around the dateparser in order to avoid internal server errors in app - 7d43cd62c0d51889978413ca411cec1bd37b024a: Chunk embeddings generation in order to avoid large memory load - e02d751eb3cb9a005f1b529fde3b34ce3c026f1b: Addresses comments from PR #498 - a3f393edb49842bedb4f42fba2dfc831ee51db7f: Addresses comments from PR #503 - 66eb0782867b201a878c7fb13ba662be1258037c: Addresses comments from PR #511 - Address various items in https://github.com/khoj-ai/khoj/issues/527 --- pyproject.toml | 1 + src/database/adapters/__init__.py | 24 +++--- src/interface/desktop/config.html | 2 +- src/interface/desktop/index.html | 11 +++ src/interface/desktop/main.js | 6 ++ src/interface/desktop/renderer.js | 1 + src/khoj/configure.py | 9 +-- .../interface/web/assets/icons/copy-solid.svg | 1 + .../web/assets/icons/trash-solid.svg | 1 + src/khoj/interface/web/base_config.html | 22 +++++- src/khoj/interface/web/config.html | 68 +++++++++++++--- src/khoj/interface/web/index.html | 11 +++ src/khoj/interface/web/login.html | 2 +- .../processor/plaintext/plaintext_to_jsonl.py | 28 ++++++- src/khoj/processor/text_to_jsonl.py | 78 ++++++++++--------- src/khoj/routers/api.py | 2 +- src/khoj/routers/auth.py | 1 - src/khoj/routers/helpers.py | 2 +- src/khoj/routers/indexer.py | 7 +- src/khoj/search_filter/date_filter.py | 20 +++-- src/khoj/utils/helpers.py | 11 +++ tests/test_text_search.py | 70 +++++++++++------ 22 files changed, 264 insertions(+), 114 deletions(-) create mode 100644 src/khoj/interface/web/assets/icons/copy-solid.svg create mode 100644 src/khoj/interface/web/assets/icons/trash-solid.svg diff --git a/pyproject.toml b/pyproject.toml index d5b7f0ce..f4ae57f4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,7 @@ dependencies = [ "google-auth == 2.23.3", "python-multipart == 0.0.6", "gunicorn == 21.2.0", + "lxml == 4.9.3", ] dynamic = ["version"] diff --git a/src/database/adapters/__init__.py b/src/database/adapters/__init__.py index 52debdc4..362398d8 100644 --- a/src/database/adapters/__init__.py +++ b/src/database/adapters/__init__.py @@ -1,6 +1,9 @@ import secrets from typing import Type, TypeVar, List from datetime import date +import secrets +from typing import Type, TypeVar, List +from datetime import date from django.db import models from django.contrib.sessions.backends.db import SessionStore @@ -8,6 +11,10 @@ from pgvector.django import CosineDistance from django.db.models.manager import BaseManager from django.db.models import Q from torch import Tensor +from pgvector.django import CosineDistance +from django.db.models.manager import BaseManager +from django.db.models import Q +from torch import Tensor # Import sync_to_async from Django Channels from asgiref.sync import sync_to_async @@ -58,7 +65,7 @@ async def set_notion_config(token: str, user: KhojUser): async def create_khoj_token(user: KhojUser, name=None): "Create Khoj API key for user" token = f"kk-{secrets.token_urlsafe(32)}" - name = name or f"{generate_random_name().title()}'s Secret Key" + name = name or f"{generate_random_name().title()}" api_config = await KhojApiUser.objects.acreate(token=token, user=user, name=name) await api_config.asave() return api_config @@ -123,15 +130,11 @@ def get_all_users() -> BaseManager[KhojUser]: def get_user_github_config(user: KhojUser): config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first() - if not config: - return None return config def get_user_notion_config(user: KhojUser): config = NotionConfig.objects.filter(user=user).first() - if not config: - return None return config @@ -240,13 +243,10 @@ class ConversationAdapters: @staticmethod def get_enabled_conversation_settings(user: KhojUser): openai_config = ConversationAdapters.get_openai_conversation_config(user) - offline_chat_config = ConversationAdapters.get_offline_chat_conversation_config(user) return { "openai": True if openai_config is not None else False, - "offline_chat": True - if (offline_chat_config is not None and offline_chat_config.enable_offline_chat) - else False, + "offline_chat": ConversationAdapters.has_offline_chat(user), } @staticmethod @@ -264,7 +264,11 @@ class ConversationAdapters: OfflineChatProcessorConversationConfig.objects.filter(user=user).delete() @staticmethod - async def has_offline_chat(user: KhojUser): + def has_offline_chat(user: KhojUser): + return OfflineChatProcessorConversationConfig.objects.filter(user=user, enable_offline_chat=True).exists() + + @staticmethod + async def ahas_offline_chat(user: KhojUser): return await OfflineChatProcessorConversationConfig.objects.filter( user=user, enable_offline_chat=True ).aexists() diff --git a/src/interface/desktop/config.html b/src/interface/desktop/config.html index 4b79f1a1..b781af26 100644 --- a/src/interface/desktop/config.html +++ b/src/interface/desktop/config.html @@ -2,7 +2,7 @@ - Khoj - Search + Khoj - Settings diff --git a/src/interface/desktop/index.html b/src/interface/desktop/index.html index 283f2477..ce930cec 100644 --- a/src/interface/desktop/index.html +++ b/src/interface/desktop/index.html @@ -94,6 +94,15 @@ }).join("\n"); } + function render_xml(query, data) { + return data.map(function (item) { + return `
` + + `${item.additional.heading}` + + `${item.entry}` + + `
` + }).join("\n"); + } + function render_multiple(query, data, type) { let html = ""; data.forEach(item => { @@ -113,6 +122,8 @@ html += `
` + `${item.additional.heading}` + `

${item.entry}

` + `
`; } else if (item.additional.file.endsWith(".html")) { html += render_html(query, [item]); + } else if (item.additional.file.endsWith(".xml")) { + html += render_xml(query, [item]) } else { html += `
` + `${item.additional.heading}` + `

${item.entry}

` + `
`; } diff --git a/src/interface/desktop/main.js b/src/interface/desktop/main.js index 7c0559c3..d38a9e9b 100644 --- a/src/interface/desktop/main.js +++ b/src/interface/desktop/main.js @@ -267,6 +267,12 @@ async function getFolders () { } async function setURL (event, url) { + // Sanitize the URL. Remove trailing slash if present. Add http:// if not present. + url = url.replace(/\/$/, ""); + if (!url.match(/^[a-zA-Z]+:\/\//)) { + url = `http://${url}`; + } + store.set('hostURL', url); return store.get('hostURL'); } diff --git a/src/interface/desktop/renderer.js b/src/interface/desktop/renderer.js index 5586758c..b365ceff 100644 --- a/src/interface/desktop/renderer.js +++ b/src/interface/desktop/renderer.js @@ -174,6 +174,7 @@ urlInput.addEventListener('blur', async () => { new URL(urlInputValue); } catch (e) { console.log(e); + alert('Please enter a valid URL'); return; } diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 67ca3543..5dec86e7 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -25,7 +25,6 @@ from khoj.utils import constants, state from khoj.utils.config import ( SearchType, ) -from khoj.utils.helpers import merge_dicts from khoj.utils.fs_syncer import collect_files from khoj.utils.rawconfig import FullConfig from khoj.routers.indexer import configure_content, load_content, configure_search @@ -83,12 +82,6 @@ class UserAuthenticationBackend(AuthenticationBackend): def initialize_server(config: Optional[FullConfig]): - if config is None: - logger.warning( - f"🚨 Khoj is not configured.\nConfigure it via http://{state.host}:{state.port}/config, plugins or by editing {state.config_file}." - ) - return None - try: configure_server(config, init=True) except Exception as e: @@ -190,7 +183,7 @@ def configure_search_types(config: FullConfig): core_search_types = {e.name: e.value for e in SearchType} # Dynamically generate search type enum by merging core search types with configured plugin search types - return Enum("SearchType", merge_dicts(core_search_types, {})) + return Enum("SearchType", core_search_types) @schedule.repeat(schedule.every(59).minutes) diff --git a/src/khoj/interface/web/assets/icons/copy-solid.svg b/src/khoj/interface/web/assets/icons/copy-solid.svg new file mode 100644 index 00000000..da7020be --- /dev/null +++ b/src/khoj/interface/web/assets/icons/copy-solid.svg @@ -0,0 +1 @@ + diff --git a/src/khoj/interface/web/assets/icons/trash-solid.svg b/src/khoj/interface/web/assets/icons/trash-solid.svg new file mode 100644 index 00000000..768d80f8 --- /dev/null +++ b/src/khoj/interface/web/assets/icons/trash-solid.svg @@ -0,0 +1 @@ + diff --git a/src/khoj/interface/web/base_config.html b/src/khoj/interface/web/base_config.html index 15c3f678..db77787b 100644 --- a/src/khoj/interface/web/base_config.html +++ b/src/khoj/interface/web/base_config.html @@ -52,10 +52,24 @@ justify-self: center; } - - div.section.general-settings { - justify-self: center; - } + .api-settings { + display: grid; + grid-template-columns: 1fr; + grid-template-rows: 1fr 1fr auto; + justify-items: start; + gap: 8px; + padding: 24px 24px; + background: white; + border: 1px solid rgb(229, 229, 229); + border-radius: 4px; + box-shadow: 0px 1px 3px 0px rgba(0,0,0,0.1),0px 1px 2px -1px rgba(0,0,0,0.1); + } + #api-settings-card-description { + margin: 8px 0 0 0; + } + #api-settings-keys-table { + margin-bottom: 16px; + } div.instructions { font-size: large; diff --git a/src/khoj/interface/web/config.html b/src/khoj/interface/web/config.html index 3f504efa..c65615e1 100644 --- a/src/khoj/interface/web/config.html +++ b/src/khoj/interface/web/config.html @@ -287,11 +287,33 @@ -
-
- -
+
+
+
+ API Key +

API Keys

+
+
+

Manage access to your Khoj from client apps

+
+ + + + + + + + + +
NameKeyActions
+
+ +
+
+
@@ -520,14 +542,32 @@ .then(response => response.json()) .then(tokenObj => { apiKeyList.innerHTML += ` -
- ${tokenObj.token} - -
+ + ${tokenObj.name} + ${tokenObj.token} + + Copy API Key + Delete API Key + + `; }); } + function copyAPIKey(token) { + // Copy API key to clipboard + navigator.clipboard.writeText(token); + // Flash the API key copied message + const copyApiKeyButton = document.getElementById(`api-key-${token}`); + original_html = copyApiKeyButton.innerHTML + setTimeout(function() { + copyApiKeyButton.innerHTML = "✅ Copied to your clipboard!"; + setTimeout(function() { + copyApiKeyButton.innerHTML = original_html; + }, 1000); + }, 100); + } + function deleteAPIKey(token) { const apiKeyList = document.getElementById("api-key-list"); fetch(`/auth/token?token=${token}`, { @@ -548,10 +588,14 @@ .then(tokens => { apiKeyList.innerHTML = tokens.map(tokenObj => ` -
- ${tokenObj.token} - -
+ + ${tokenObj.name} + ${tokenObj.token} + + Copy API Key + Delete API Key + + `) .join(""); }); diff --git a/src/khoj/interface/web/index.html b/src/khoj/interface/web/index.html index ccf1ca71..539c96e0 100644 --- a/src/khoj/interface/web/index.html +++ b/src/khoj/interface/web/index.html @@ -94,6 +94,15 @@ }).join("\n"); } + function render_xml(query, data) { + return data.map(function (item) { + return `
` + + `${item.additional.heading}` + + `${item.entry}` + + `
` + }).join("\n"); + } + function render_multiple(query, data, type) { let html = ""; data.forEach(item => { @@ -113,6 +122,8 @@ html += `
` + `${item.additional.heading}` + `

${item.entry}

` + `
`; } else if (item.additional.file.endsWith(".html")) { html += render_html(query, [item]); + } else if (item.additional.file.endsWith(".xml")) { + html += render_xml(query, [item]) } else { html += `
` + `${item.additional.heading}` + `

${item.entry}

` + `
`; } diff --git a/src/khoj/interface/web/login.html b/src/khoj/interface/web/login.html index 550991ed..1bab4221 100644 --- a/src/khoj/interface/web/login.html +++ b/src/khoj/interface/web/login.html @@ -2,7 +2,7 @@ - Khoj - Search + Khoj - Login diff --git a/src/khoj/processor/plaintext/plaintext_to_jsonl.py b/src/khoj/processor/plaintext/plaintext_to_jsonl.py index 965a5a7b..086808b7 100644 --- a/src/khoj/processor/plaintext/plaintext_to_jsonl.py +++ b/src/khoj/processor/plaintext/plaintext_to_jsonl.py @@ -2,12 +2,14 @@ import logging from pathlib import Path from typing import List, Tuple +from bs4 import BeautifulSoup + # Internal Packages from khoj.processor.text_to_jsonl import TextEmbeddings from khoj.utils.helpers import timer -from khoj.utils.rawconfig import Entry, TextContentConfig -from database.models import Embeddings, KhojUser, LocalPlaintextConfig +from khoj.utils.rawconfig import Entry +from database.models import Embeddings, KhojUser logger = logging.getLogger(__name__) @@ -28,6 +30,19 @@ class PlaintextToJsonl(TextEmbeddings): else: deletion_file_names = None + with timer("Scrub plaintext files and extract text", logger): + for file in files: + try: + plaintext_content = files[file] + if file.endswith(("html", "htm", "xml")): + plaintext_content = PlaintextToJsonl.extract_html_content( + plaintext_content, file.split(".")[-1] + ) + files[file] = plaintext_content + except Exception as e: + logger.warning(f"Unable to read file: {file} as plaintext. Skipping file.") + logger.warning(e, exc_info=True) + # Extract Entries from specified plaintext files with timer("Parse entries from plaintext files", logger): current_entries = PlaintextToJsonl.convert_plaintext_entries_to_maps(files) @@ -50,6 +65,15 @@ class PlaintextToJsonl(TextEmbeddings): return num_new_embeddings, num_deleted_embeddings + @staticmethod + def extract_html_content(markup_content: str, markup_type: str): + "Extract content from HTML" + if markup_type == "xml": + soup = BeautifulSoup(markup_content, "xml") + else: + soup = BeautifulSoup(markup_content, "html.parser") + return soup.get_text(strip=True, separator="\n") + @staticmethod def convert_plaintext_entries_to_maps(entry_to_file_map: dict) -> List[Entry]: "Convert each plaintext entries into a dictionary" diff --git a/src/khoj/processor/text_to_jsonl.py b/src/khoj/processor/text_to_jsonl.py index c83c83b1..831a032f 100644 --- a/src/khoj/processor/text_to_jsonl.py +++ b/src/khoj/processor/text_to_jsonl.py @@ -5,7 +5,7 @@ import logging import uuid from tqdm import tqdm from typing import Callable, List, Tuple, Set, Any -from khoj.utils.helpers import timer +from khoj.utils.helpers import timer, batcher # Internal Packages @@ -93,7 +93,7 @@ class TextEmbeddings(ABC): num_deleted_embeddings = 0 with timer("Preparing dataset for regeneration", logger): if regenerate: - logger.info(f"Deleting all embeddings for file type {file_type}") + logger.debug(f"Deleting all embeddings for file type {file_type}") num_deleted_embeddings = EmbeddingsAdapters.delete_all_embeddings(user, file_type) num_new_embeddings = 0 @@ -106,48 +106,54 @@ class TextEmbeddings(ABC): ) existing_entry_hashes = set([entry.hashed_value for entry in existing_entries]) hashes_to_process = hashes_for_file - existing_entry_hashes - # for hashed_val in hashes_for_file: - # if not EmbeddingsAdapters.does_embedding_exist(user, hashed_val): - # hashes_to_process.add(hashed_val) entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process] data_to_embed = [getattr(entry, key) for entry in entries_to_process] embeddings = self.embeddings_model.embed_documents(data_to_embed) with timer("Update the database with new vector embeddings", logger): - embeddings_to_create = [] - for hashed_val, embedding in zip(hashes_to_process, embeddings): - entry = hash_to_current_entries[hashed_val] - embeddings_to_create.append( - Embeddings( - user=user, - embeddings=embedding, - raw=entry.raw, - compiled=entry.compiled, - heading=entry.heading, - file_path=entry.file, - file_type=file_type, - hashed_value=hashed_val, - corpus_id=entry.corpus_id, - ) - ) - new_embeddings = Embeddings.objects.bulk_create(embeddings_to_create) - num_new_embeddings += len(new_embeddings) + num_items = len(hashes_to_process) + assert num_items == len(embeddings) + batch_size = min(200, num_items) + entry_batches = zip(hashes_to_process, embeddings) - dates_to_create = [] - with timer("Create new date associations for new embeddings", logger): - for embedding in new_embeddings: - dates = self.date_filter.extract_dates(embedding.raw) - for date in dates: - dates_to_create.append( - EmbeddingsDates( - date=date, - embeddings=embedding, - ) + for entry_batch in tqdm( + batcher(entry_batches, batch_size), desc="Processing embeddings in batches" + ): + batch_embeddings_to_create = [] + for entry_hash, embedding in entry_batch: + entry = hash_to_current_entries[entry_hash] + batch_embeddings_to_create.append( + Embeddings( + user=user, + embeddings=embedding, + raw=entry.raw, + compiled=entry.compiled, + heading=entry.heading[:1000], # Truncate to max chars of field allowed + file_path=entry.file, + file_type=file_type, + hashed_value=entry_hash, + corpus_id=entry.corpus_id, ) - new_dates = EmbeddingsDates.objects.bulk_create(dates_to_create) - if len(new_dates) > 0: - logger.info(f"Created {len(new_dates)} new date entries") + ) + new_embeddings = Embeddings.objects.bulk_create(batch_embeddings_to_create) + logger.debug(f"Created {len(new_embeddings)} new embeddings") + num_new_embeddings += len(new_embeddings) + + dates_to_create = [] + with timer("Create new date associations for new embeddings", logger): + for embedding in new_embeddings: + dates = self.date_filter.extract_dates(embedding.raw) + for date in dates: + dates_to_create.append( + EmbeddingsDates( + date=date, + embeddings=embedding, + ) + ) + new_dates = EmbeddingsDates.objects.bulk_create(dates_to_create) + if len(new_dates) > 0: + logger.debug(f"Created {len(new_dates)} new date entries") with timer("Identify hashes for removed entries", logger): for file in hashes_by_file: diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 4984ea4c..2607120d 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -723,7 +723,7 @@ async def extract_references_and_questions( # Infer search queries from user message with timer("Extracting search queries took", logger): # If we've reached here, either the user has enabled offline chat or the openai model is enabled. - if await ConversationAdapters.has_offline_chat(user): + if await ConversationAdapters.ahas_offline_chat(user): using_offline_chat = True offline_chat = await ConversationAdapters.get_offline_chat(user) chat_model = offline_chat.chat_model diff --git a/src/khoj/routers/auth.py b/src/khoj/routers/auth.py index 5c375bd0..ebabeb8e 100644 --- a/src/khoj/routers/auth.py +++ b/src/khoj/routers/auth.py @@ -46,7 +46,6 @@ async def login(request: Request): return await oauth.google.authorize_redirect(request, redirect_uri) -@auth_router.post("/redirect") @auth_router.post("/token") @requires(["authenticated"], redirect="login_page") async def generate_token(request: Request, token_name: Optional[str] = None) -> str: diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 8a9e53a7..185217ed 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -31,7 +31,7 @@ def perform_chat_checks(user: KhojUser): async def is_ready_to_chat(user: KhojUser): - has_offline_config = await ConversationAdapters.has_offline_chat(user=user) + has_offline_config = await ConversationAdapters.ahas_offline_chat(user=user) has_openai_config = await ConversationAdapters.has_openai_chat(user=user) if has_offline_config: diff --git a/src/khoj/routers/indexer.py b/src/khoj/routers/indexer.py index e7df65a2..590164fb 100644 --- a/src/khoj/routers/indexer.py +++ b/src/khoj/routers/indexer.py @@ -156,18 +156,17 @@ async def update( host=host, ) + logger.info(f"📪 Content index updated via API call by {client} client") + return Response(content="OK", status_code=200) def configure_search(search_models: SearchModels, search_config: Optional[SearchConfig]) -> Optional[SearchModels]: # Run Validation Checks - if search_config is None: - logger.warning("🚨 No Search configuration available.") - return None if search_models is None: search_models = SearchModels() - if search_config.image: + if search_config and search_config.image: logger.info("🔍 🌄 Setting up image search model") search_models.image_search = image_search.initialize_model(search_config.image) diff --git a/src/khoj/search_filter/date_filter.py b/src/khoj/search_filter/date_filter.py index 88c70101..1d90b9f5 100644 --- a/src/khoj/search_filter/date_filter.py +++ b/src/khoj/search_filter/date_filter.py @@ -127,14 +127,18 @@ class DateFilter(BaseFilter): clean_date_str = re.sub("|".join(future_strings), "", date_str) # parse date passed in query date filter - parsed_date = dtparse.parse( - clean_date_str, - settings={ - "RELATIVE_BASE": relative_base or datetime.now(), - "PREFER_DAY_OF_MONTH": "first", - "PREFER_DATES_FROM": prefer_dates_from, - }, - ) + try: + parsed_date = dtparse.parse( + clean_date_str, + settings={ + "RELATIVE_BASE": relative_base or datetime.now(), + "PREFER_DAY_OF_MONTH": "first", + "PREFER_DATES_FROM": prefer_dates_from, + }, + ) + except Exception as e: + logger.error(f"Failed to parse date string: {date_str} with error: {e}") + return None if parsed_date is None: return None diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index f6418fbd..3bce67a0 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -5,6 +5,7 @@ import datetime from enum import Enum from importlib import import_module from importlib.metadata import version +from itertools import islice import logging from os import path import os @@ -305,3 +306,13 @@ def generate_random_name(): name = f"{adjective} {noun}" return name + + +def batcher(iterable, max_n): + "Split an iterable into chunks of size max_n" + it = iter(iterable) + while True: + chunk = list(islice(it, max_n)) + if not chunk: + return + yield (x for x in chunk if x is not None) diff --git a/tests/test_text_search.py b/tests/test_text_search.py index aeeaa85f..ae7f0c20 100644 --- a/tests/test_text_search.py +++ b/tests/test_text_search.py @@ -22,9 +22,7 @@ logger = logging.getLogger(__name__) # Test # ---------------------------------------------------------------------------------------------------- @pytest.mark.django_db -def test_text_search_setup_with_missing_file_raises_error( - org_config_with_only_new_file: LocalOrgConfig, search_config: SearchConfig -): +def test_text_search_setup_with_missing_file_raises_error(org_config_with_only_new_file: LocalOrgConfig): # Arrange # Ensure file mentioned in org.input-files is missing single_new_file = Path(org_config_with_only_new_file.input_files[0]) @@ -70,22 +68,39 @@ def test_text_search_setup_with_empty_file_raises_error( with caplog.at_level(logging.INFO): text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user) - assert "Created 0 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message + assert "Created 0 new embeddings. Deleted 3 embeddings for user " in caplog.records[-1].message verify_embeddings(0, default_user) # ---------------------------------------------------------------------------------------------------- @pytest.mark.django_db -def test_text_search_setup(content_config, default_user: KhojUser, caplog): +def test_text_indexer_deletes_embedding_before_regenerate( + content_config: ContentConfig, default_user: KhojUser, caplog +): # Arrange org_config = LocalOrgConfig.objects.filter(user=default_user).first() data = get_org_files(org_config) - with caplog.at_level(logging.INFO): + with caplog.at_level(logging.DEBUG): text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user) # Assert - assert "Deleting all embeddings for file type org" in caplog.records[1].message - assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message + assert "Deleting all embeddings for file type org" in caplog.text + assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[-1].message + + +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.django_db +def test_text_search_setup_batch_processes(content_config: ContentConfig, default_user: KhojUser, caplog): + # Arrange + org_config = LocalOrgConfig.objects.filter(user=default_user).first() + data = get_org_files(org_config) + with caplog.at_level(logging.DEBUG): + text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user) + + # Assert + assert "Created 4 new embeddings" in caplog.text + assert "Created 6 new embeddings" in caplog.text + assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[-1].message # ---------------------------------------------------------------------------------------------------- @@ -97,13 +112,13 @@ def test_text_index_same_if_content_unchanged(content_config: ContentConfig, def # Act # Generate initial notes embeddings during asymmetric setup - with caplog.at_level(logging.INFO): + with caplog.at_level(logging.DEBUG): text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user) initial_logs = caplog.text caplog.clear() # Clear logs # Run asymmetric setup again with no changes to data source. Ensure index is not updated - with caplog.at_level(logging.INFO): + with caplog.at_level(logging.DEBUG): text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user) final_logs = caplog.text @@ -175,12 +190,10 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: LocalOrgCon # Assert # verify newly added org-mode entry is split by max tokens - record = caplog.records[1] - assert "Created 2 new embeddings. Deleted 0 embeddings for user " in record.message + assert "Created 2 new embeddings. Deleted 0 embeddings for user " in caplog.records[-1].message # ---------------------------------------------------------------------------------------------------- -# @pytest.mark.skip(reason="Flaky due to compressed_jsonl file being rewritten by other tests") @pytest.mark.django_db def test_entry_chunking_by_max_tokens_not_full_corpus( org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog @@ -232,11 +245,9 @@ conda activate khoj user=default_user, ) - record = caplog.records[1] - # Assert # verify newly added org-mode entry is split by max tokens - assert "Created 2 new embeddings. Deleted 0 embeddings for user " in record.message + assert "Created 2 new embeddings. Deleted 0 embeddings for user " in caplog.records[-1].message # ---------------------------------------------------------------------------------------------------- @@ -251,7 +262,7 @@ def test_regenerate_index_with_new_entry( with caplog.at_level(logging.INFO): text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user) - assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message + assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[-1].message # append org-mode entry to first org input file in config org_config.input_files = [f"{new_org_file}"] @@ -286,20 +297,23 @@ def test_update_index_with_duplicate_entries_in_stable_order( data = get_org_files(org_config_with_only_new_file) # Act - # load embeddings, entries, notes model after adding new org-mode file + # generate embeddings, entries, notes model from scratch after adding new org-mode file with caplog.at_level(logging.INFO): text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user) + initial_logs = caplog.text + caplog.clear() # Clear logs data = get_org_files(org_config_with_only_new_file) - # update embeddings, entries, notes model after adding new org-mode file + # update embeddings, entries, notes model with no new changes with caplog.at_level(logging.INFO): text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user) + final_logs = caplog.text # Assert # verify only 1 entry added even if there are multiple duplicate entries - assert "Created 1 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message - assert "Created 0 new embeddings. Deleted 0 embeddings for user " in caplog.records[4].message + assert "Created 1 new embeddings. Deleted 3 embeddings for user " in initial_logs + assert "Created 0 new embeddings. Deleted 0 embeddings for user " in final_logs verify_embeddings(1, default_user) @@ -319,6 +333,8 @@ def test_update_index_with_deleted_entry(org_config_with_only_new_file: LocalOrg # load embeddings, entries, notes model after adding new org file with 2 entries with caplog.at_level(logging.INFO): text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user) + initial_logs = caplog.text + caplog.clear() # Clear logs # update embeddings, entries, notes model after removing an entry from the org file with open(new_file_to_index, "w") as f: @@ -329,11 +345,12 @@ def test_update_index_with_deleted_entry(org_config_with_only_new_file: LocalOrg # Act with caplog.at_level(logging.INFO): text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user) + final_logs = caplog.text # Assert # verify only 1 entry added even if there are multiple duplicate entries - assert "Created 2 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message - assert "Created 0 new embeddings. Deleted 1 embeddings for user " in caplog.records[4].message + assert "Created 2 new embeddings. Deleted 3 embeddings for user " in initial_logs + assert "Created 0 new embeddings. Deleted 1 embeddings for user " in final_logs verify_embeddings(1, default_user) @@ -346,6 +363,8 @@ def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file data = get_org_files(org_config) with caplog.at_level(logging.INFO): text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user) + initial_logs = caplog.text + caplog.clear() # Clear logs # append org-mode entry to first org input file in config with open(new_org_file, "w") as f: @@ -358,10 +377,11 @@ def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file # update embeddings, entries with the newly added note with caplog.at_level(logging.INFO): text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user) + final_logs = caplog.text # Assert - assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message - assert "Created 1 new embeddings. Deleted 0 embeddings for user " in caplog.records[4].message + assert "Created 10 new embeddings. Deleted 3 embeddings for user " in initial_logs + assert "Created 1 new embeddings. Deleted 0 embeddings for user " in final_logs verify_embeddings(11, default_user)