[Multi-User Part 6]: Address small bugs and upstream PR comments (#518)

- 08654163cb: Add better parsing for XML files
- f3acfac7fb: Add a try/except around the date parser to avoid internal server errors in the app
- 7d43cd62c0: Chunk embeddings generation to avoid high memory load
- e02d751eb3: Address comments from PR #498
- a3f393edb4: Address comments from PR #503
- 66eb078286: Address comments from PR #511
- Address various items in https://github.com/khoj-ai/khoj/issues/527
Author: sabaimran
Date: 2023-10-31 17:59:53 -07:00 (committed by GitHub)
Parent: 5f3f6b7c61
Commit: 54a387326c
22 changed files with 264 additions and 114 deletions

----------------------------------------------------------------------------------------------------

@ -71,6 +71,7 @@ dependencies = [
"google-auth == 2.23.3",
"python-multipart == 0.0.6",
"gunicorn == 21.2.0",
"lxml == 4.9.3",
]
dynamic = ["version"]

----------------------------------------------------------------------------------------------------

@ -1,6 +1,9 @@
import secrets
from typing import Type, TypeVar, List
from datetime import date
from django.db import models
from django.contrib.sessions.backends.db import SessionStore
@ -8,6 +11,10 @@ from pgvector.django import CosineDistance
from django.db.models.manager import BaseManager
from django.db.models import Q
from torch import Tensor
# Import sync_to_async from Django Channels
from asgiref.sync import sync_to_async
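
For context, sync_to_async is what lets Khoj's async routes call Django's synchronous ORM off the event loop. A minimal sketch of the pattern, assuming a hypothetical lookup helper (KhojUser is the model from this diff):

    from asgiref.sync import sync_to_async
    from database.models import KhojUser

    def _get_user_by_email(email: str):
        # Plain synchronous ORM call; must not run directly on the event loop
        return KhojUser.objects.filter(email=email).first()

    async def aget_user_by_email(email: str):
        # sync_to_async executes the sync function in a worker thread and awaits it
        return await sync_to_async(_get_user_by_email)(email)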
@ -58,7 +65,7 @@ async def set_notion_config(token: str, user: KhojUser):
async def create_khoj_token(user: KhojUser, name=None):
"Create Khoj API key for user"
token = f"kk-{secrets.token_urlsafe(32)}"
name = name or f"{generate_random_name().title()}'s Secret Key"
name = name or f"{generate_random_name().title()}"
api_config = await KhojApiUser.objects.acreate(token=token, user=user, name=name)
await api_config.asave()
return api_config
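
For illustration, a caller would use this helper roughly as follows; the user object and key name are placeholders, and secrets.token_urlsafe(32) yields 43 URL-safe characters after the kk- prefix:

    async def demo(user: KhojUser):
        # name falls back to a randomly generated display name when omitted
        api_key = await create_khoj_token(user, name="CI runner")
        assert api_key.token.startswith("kk-")
        print(api_key.name, api_key.token)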
@ -123,15 +130,11 @@ def get_all_users() -> BaseManager[KhojUser]:
def get_user_github_config(user: KhojUser):
config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first()
if not config:
return None
return config
def get_user_notion_config(user: KhojUser):
config = NotionConfig.objects.filter(user=user).first()
if not config:
return None
return config
@ -240,13 +243,10 @@ class ConversationAdapters:
@staticmethod
def get_enabled_conversation_settings(user: KhojUser):
openai_config = ConversationAdapters.get_openai_conversation_config(user)
-offline_chat_config = ConversationAdapters.get_offline_chat_conversation_config(user)
return {
"openai": True if openai_config is not None else False,
-"offline_chat": True
-if (offline_chat_config is not None and offline_chat_config.enable_offline_chat)
-else False,
+"offline_chat": ConversationAdapters.has_offline_chat(user),
}
@staticmethod
@ -264,7 +264,11 @@ class ConversationAdapters:
OfflineChatProcessorConversationConfig.objects.filter(user=user).delete()
@staticmethod
-async def has_offline_chat(user: KhojUser):
+def has_offline_chat(user: KhojUser):
return OfflineChatProcessorConversationConfig.objects.filter(user=user, enable_offline_chat=True).exists()
@staticmethod
async def ahas_offline_chat(user: KhojUser):
return await OfflineChatProcessorConversationConfig.objects.filter(
user=user, enable_offline_chat=True
).aexists()
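
This split follows Django's convention of prefixing async ORM methods with "a" (exists/aexists): sync callers use has_offline_chat, async callers await ahas_offline_chat. A minimal sketch of the two call shapes (the call sites are illustrative; the adapter methods are from this diff):

    # Synchronous caller, e.g. settings assembly as in get_enabled_conversation_settings
    def offline_chat_enabled(user: KhojUser) -> bool:
        return ConversationAdapters.has_offline_chat(user)

    # Asynchronous caller, e.g. a FastAPI route handler
    async def aoffline_chat_enabled(user: KhojUser) -> bool:
        return await ConversationAdapters.ahas_offline_chat(user)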

----------------------------------------------------------------------------------------------------

@ -2,7 +2,7 @@
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0 maximum-scale=1.0">
-<title>Khoj - Search</title>
+<title>Khoj - Settings</title>
<link rel="icon" type="image/png" sizes="128x128" href="./assets/icons/favicon-128x128.png">
<link rel="manifest" href="./khoj.webmanifest">

----------------------------------------------------------------------------------------------------

@ -94,6 +94,15 @@
}).join("\n");
}
function render_xml(query, data) {
return data.map(function (item) {
return `<div class="results-xml">` +
`<b><a href="${item.additional.file}">${item.additional.heading}</a></b>` +
`<xml>${item.entry}</xml>` +
`</div>`
}).join("\n");
}
function render_multiple(query, data, type) {
let html = "";
data.forEach(item => {
@ -113,6 +122,8 @@
html += `<div class="results-notion">` + `<b><a href="${item.additional.file}">${item.additional.heading}</a></b>` + `<p>${item.entry}</p>` + `</div>`;
} else if (item.additional.file.endsWith(".html")) {
html += render_html(query, [item]);
} else if (item.additional.file.endsWith(".xml")) {
html += render_xml(query, [item]);
} else {
html += `<div class="results-plugin">` + `<b><a href="${item.additional.file}">${item.additional.heading}</a></b>` + `<p>${item.entry}</p>` + `</div>`;
}

----------------------------------------------------------------------------------------------------

@ -267,6 +267,12 @@ async function getFolders () {
}
async function setURL (event, url) {
// Sanitize the URL. Remove trailing slash if present. Add http:// if not present.
url = url.replace(/\/$/, "");
if (!url.match(/^[a-zA-Z]+:\/\//)) {
url = `http://${url}`;
}
store.set('hostURL', url);
return store.get('hostURL');
}

----------------------------------------------------------------------------------------------------

@ -174,6 +174,7 @@ urlInput.addEventListener('blur', async () => {
new URL(urlInputValue);
} catch (e) {
console.log(e);
alert('Please enter a valid URL');
return;
}

----------------------------------------------------------------------------------------------------

@ -25,7 +25,6 @@ from khoj.utils import constants, state
from khoj.utils.config import (
SearchType,
)
-from khoj.utils.helpers import merge_dicts
from khoj.utils.fs_syncer import collect_files
from khoj.utils.rawconfig import FullConfig
from khoj.routers.indexer import configure_content, load_content, configure_search
@ -83,12 +82,6 @@ class UserAuthenticationBackend(AuthenticationBackend):
def initialize_server(config: Optional[FullConfig]):
if config is None:
logger.warning(
f"🚨 Khoj is not configured.\nConfigure it via http://{state.host}:{state.port}/config, plugins or by editing {state.config_file}."
)
return None
try:
configure_server(config, init=True)
except Exception as e:
@ -190,7 +183,7 @@ def configure_search_types(config: FullConfig):
core_search_types = {e.name: e.value for e in SearchType}
# Dynamically generate search type enum by merging core search types with configured plugin search types
return Enum("SearchType", merge_dicts(core_search_types, {}))
return Enum("SearchType", core_search_types)
@schedule.repeat(schedule.every(59).minutes)
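
The replacement uses Python's functional Enum API, which builds an enum from a name-to-value mapping at runtime. A small self-contained example of the same construction (member names here are illustrative, not Khoj's full set of search types):

    from enum import Enum

    core_search_types = {"Org": "org", "Markdown": "markdown", "Plaintext": "plaintext"}
    SearchType = Enum("SearchType", core_search_types)

    print(SearchType.Org.value)         # "org"
    print(SearchType("markdown").name)  # "Markdown"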

----------------------------------------------------------------------------------------------------

@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 448 512"><!--! Font Awesome Pro 6.4.2 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license (Commercial License) Copyright 2023 Fonticons, Inc. --><path d="M208 0H332.1c12.7 0 24.9 5.1 33.9 14.1l67.9 67.9c9 9 14.1 21.2 14.1 33.9V336c0 26.5-21.5 48-48 48H208c-26.5 0-48-21.5-48-48V48c0-26.5 21.5-48 48-48zM48 128h80v64H64V448H256V416h64v48c0 26.5-21.5 48-48 48H48c-26.5 0-48-21.5-48-48V176c0-26.5 21.5-48 48-48z"/></svg>


----------------------------------------------------------------------------------------------------

@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 448 512"><!--! Font Awesome Pro 6.4.2 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license (Commercial License) Copyright 2023 Fonticons, Inc. --><path d="M135.2 17.7L128 32H32C14.3 32 0 46.3 0 64S14.3 96 32 96H416c17.7 0 32-14.3 32-32s-14.3-32-32-32H320l-7.2-14.3C307.4 6.8 296.3 0 284.2 0H163.8c-12.1 0-23.2 6.8-28.6 17.7zM416 128H32L53.2 467c1.6 25.3 22.6 45 47.9 45H346.9c25.3 0 46.3-19.7 47.9-45L416 128z"/></svg>


----------------------------------------------------------------------------------------------------

@ -52,10 +52,24 @@
justify-self: center;
}
div.section.general-settings {
justify-self: center;
}
.api-settings {
display: grid;
grid-template-columns: 1fr;
grid-template-rows: 1fr 1fr auto;
justify-items: start;
gap: 8px;
padding: 24px 24px;
background: white;
border: 1px solid rgb(229, 229, 229);
border-radius: 4px;
box-shadow: 0px 1px 3px 0px rgba(0,0,0,0.1),0px 1px 2px -1px rgba(0,0,0,0.1);
}
#api-settings-card-description {
margin: 8px 0 0 0;
}
#api-settings-keys-table {
margin-bottom: 16px;
}
div.instructions {
font-size: large;

----------------------------------------------------------------------------------------------------

@ -287,11 +287,33 @@
</div>
</div>
</div>
<div class="section general-settings">
<div id="khoj-api-key-section" title="Use Khoj cloud with your Khoj API Key">
<button id="generate-api-key" onclick="generateAPIKey()">Generate API Key</button>
<div id="api-key-list"></div>
<div class="section">
<div class="api-settings">
<div class="card-title-row">
<img class="card-icon" src="/static/assets/icons/key.svg" alt="API Key">
<h3 class="card-title">API Keys</h3>
</div>
<div class="card-description-row">
<p id="api-settings-card-description" class="card-description">Manage access to your Khoj from client apps</p>
</div>
<table id="api-settings-keys-table">
<thead>
<tr>
<th scope="col">Name</th>
<th scope="col">Key</th>
<th scope="col">Actions</th>
</tr>
</thead>
<tbody id="api-key-list"></tbody>
</table>
<div class="card-action-row">
<button class="card-button happy" id="generate-api-key" onclick="generateAPIKey()">
Generate API Key
</button>
</div>
</div>
</div>
<div class="section general-settings">
<div id="results-count" title="Number of items to show in search and use for chat response">
<label for="results-count-slider">Results Count: <span id="results-count-value">5</span></label>
<input type="range" id="results-count-slider" name="results-count-slider" min="1" max="10" step="1" value="5">
@ -520,14 +542,32 @@
.then(response => response.json())
.then(tokenObj => {
apiKeyList.innerHTML += `
<div id="api-key-item-${tokenObj.token}" class="api-key-item">
<span class="api-key">${tokenObj.token}</span>
<button class="delete-api-key" onclick="deleteAPIKey('${tokenObj.token}')">Delete</button>
</div>
<tr id="api-key-item-${tokenObj.token}">
<td><b>${tokenObj.name}</b></td>
<td id="api-key-${tokenObj.token}">${tokenObj.token}</td>
<td>
<img id="api-key-copy-button-${tokenObj.token}" onclick="copyAPIKey('${tokenObj.token}')" class="configured-icon enabled" src="/static/assets/icons/copy-solid.svg" alt="Copy API Key">
<img id="api-key-delete-button-${tokenObj.token}" onclick="deleteAPIKey('${tokenObj.token}')" class="configured-icon enabled" src="/static/assets/icons/trash-solid.svg" alt="Delete API Key">
</td>
</tr>
`;
});
}
function copyAPIKey(token) {
// Copy API key to clipboard
navigator.clipboard.writeText(token);
// Flash the API key copied message
const copyApiKeyButton = document.getElementById(`api-key-${token}`);
const originalHtml = copyApiKeyButton.innerHTML;
setTimeout(function() {
copyApiKeyButton.innerHTML = "✅ Copied to your clipboard!";
setTimeout(function() {
copyApiKeyButton.innerHTML = originalHtml;
}, 1000);
}, 100);
}
function deleteAPIKey(token) {
const apiKeyList = document.getElementById("api-key-list");
fetch(`/auth/token?token=${token}`, {
@ -548,10 +588,14 @@
.then(tokens => {
apiKeyList.innerHTML = tokens.map(tokenObj =>
`
<div id="api-key-item-${tokenObj.token}" class="api-key-item">
<span class="api-key">${tokenObj.token}</span>
<button class="delete-api-key" onclick="deleteAPIKey('${tokenObj.token}')">Delete</button>
</div>
<tr id="api-key-item-${tokenObj.token}">
<td><b>${tokenObj.name}</b></td>
<td id="api-key-${tokenObj.token}">${tokenObj.token}</td>
<td>
<img id="api-key-copy-button-${tokenObj.token}" onclick="copyAPIKey('${tokenObj.token}')" class="configured-icon enabled" src="/static/assets/icons/copy-solid.svg" alt="Copy API Key">
<img id="api-key-delete-button-${tokenObj.token}" onclick="deleteAPIKey('${tokenObj.token}')" class="configured-icon enabled" src="/static/assets/icons/trash-solid.svg" alt="Delete API Key">
</td>
</tr>
`)
.join("");
});

----------------------------------------------------------------------------------------------------

@ -94,6 +94,15 @@
}).join("\n");
}
function render_xml(query, data) {
return data.map(function (item) {
return `<div class="results-xml">` +
`<b><a href="${item.additional.file}">${item.additional.heading}</a></b>` +
`<xml>${item.entry}</xml>` +
`</div>`
}).join("\n");
}
function render_multiple(query, data, type) {
let html = "";
data.forEach(item => {
@ -113,6 +122,8 @@
html += `<div class="results-notion">` + `<b><a href="${item.additional.file}">${item.additional.heading}</a></b>` + `<p>${item.entry}</p>` + `</div>`;
} else if (item.additional.file.endsWith(".html")) {
html += render_html(query, [item]);
} else if (item.additional.file.endsWith(".xml")) {
html += render_xml(query, [item]);
} else {
html += `<div class="results-plugin">` + `<b><a href="${item.additional.file}">${item.additional.heading}</a></b>` + `<p>${item.entry}</p>` + `</div>`;
}

----------------------------------------------------------------------------------------------------

@ -2,7 +2,7 @@
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0 maximum-scale=1.0">
-<title>Khoj - Search</title>
+<title>Khoj - Login</title>
<link rel="icon" type="image/png" sizes="128x128" href="/static/assets/icons/favicon-128x128.png">
<link rel="manifest" href="/static/khoj.webmanifest">

----------------------------------------------------------------------------------------------------

@ -2,12 +2,14 @@
import logging
from pathlib import Path
from typing import List, Tuple
from bs4 import BeautifulSoup
# Internal Packages
from khoj.processor.text_to_jsonl import TextEmbeddings
from khoj.utils.helpers import timer
from khoj.utils.rawconfig import Entry, TextContentConfig
from database.models import Embeddings, KhojUser, LocalPlaintextConfig
from khoj.utils.rawconfig import Entry
from database.models import Embeddings, KhojUser
logger = logging.getLogger(__name__)
@ -28,6 +30,19 @@ class PlaintextToJsonl(TextEmbeddings):
else:
deletion_file_names = None
with timer("Scrub plaintext files and extract text", logger):
for file in files:
try:
plaintext_content = files[file]
if file.endswith(("html", "htm", "xml")):
plaintext_content = PlaintextToJsonl.extract_html_content(
plaintext_content, file.split(".")[-1]
)
files[file] = plaintext_content
except Exception as e:
logger.warning(f"Unable to read file: {file} as plaintext. Skipping file.")
logger.warning(e, exc_info=True)
# Extract Entries from specified plaintext files
with timer("Parse entries from plaintext files", logger):
current_entries = PlaintextToJsonl.convert_plaintext_entries_to_maps(files)
@ -50,6 +65,15 @@ class PlaintextToJsonl(TextEmbeddings):
return num_new_embeddings, num_deleted_embeddings
@staticmethod
def extract_html_content(markup_content: str, markup_type: str):
"Extract content from HTML"
if markup_type == "xml":
soup = BeautifulSoup(markup_content, "xml")
else:
soup = BeautifulSoup(markup_content, "html.parser")
return soup.get_text(strip=True, separator="\n")
@staticmethod
def convert_plaintext_entries_to_maps(entry_to_file_map: dict) -> List[Entry]:
"Convert each plaintext entries into a dictionary"

----------------------------------------------------------------------------------------------------

@ -5,7 +5,7 @@ import logging
import uuid
from tqdm import tqdm
from typing import Callable, List, Tuple, Set, Any
from khoj.utils.helpers import timer
from khoj.utils.helpers import timer, batcher
# Internal Packages
@ -93,7 +93,7 @@ class TextEmbeddings(ABC):
num_deleted_embeddings = 0
with timer("Preparing dataset for regeneration", logger):
if regenerate:
logger.info(f"Deleting all embeddings for file type {file_type}")
logger.debug(f"Deleting all embeddings for file type {file_type}")
num_deleted_embeddings = EmbeddingsAdapters.delete_all_embeddings(user, file_type)
num_new_embeddings = 0
@ -106,48 +106,54 @@ class TextEmbeddings(ABC):
)
existing_entry_hashes = set([entry.hashed_value for entry in existing_entries])
hashes_to_process = hashes_for_file - existing_entry_hashes
# for hashed_val in hashes_for_file:
# if not EmbeddingsAdapters.does_embedding_exist(user, hashed_val):
# hashes_to_process.add(hashed_val)
entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
data_to_embed = [getattr(entry, key) for entry in entries_to_process]
embeddings = self.embeddings_model.embed_documents(data_to_embed)
with timer("Update the database with new vector embeddings", logger):
embeddings_to_create = []
for hashed_val, embedding in zip(hashes_to_process, embeddings):
entry = hash_to_current_entries[hashed_val]
embeddings_to_create.append(
Embeddings(
user=user,
embeddings=embedding,
raw=entry.raw,
compiled=entry.compiled,
heading=entry.heading,
file_path=entry.file,
file_type=file_type,
hashed_value=hashed_val,
corpus_id=entry.corpus_id,
)
)
new_embeddings = Embeddings.objects.bulk_create(embeddings_to_create)
num_new_embeddings += len(new_embeddings)
num_items = len(hashes_to_process)
assert num_items == len(embeddings)
batch_size = min(200, num_items)
entry_batches = zip(hashes_to_process, embeddings)
dates_to_create = []
with timer("Create new date associations for new embeddings", logger):
for embedding in new_embeddings:
dates = self.date_filter.extract_dates(embedding.raw)
for date in dates:
dates_to_create.append(
EmbeddingsDates(
date=date,
embeddings=embedding,
)
for entry_batch in tqdm(
batcher(entry_batches, batch_size), desc="Processing embeddings in batches"
):
batch_embeddings_to_create = []
for entry_hash, embedding in entry_batch:
entry = hash_to_current_entries[entry_hash]
batch_embeddings_to_create.append(
Embeddings(
user=user,
embeddings=embedding,
raw=entry.raw,
compiled=entry.compiled,
heading=entry.heading[:1000], # Truncate to max chars of field allowed
file_path=entry.file,
file_type=file_type,
hashed_value=entry_hash,
corpus_id=entry.corpus_id,
)
new_dates = EmbeddingsDates.objects.bulk_create(dates_to_create)
if len(new_dates) > 0:
logger.info(f"Created {len(new_dates)} new date entries")
)
new_embeddings = Embeddings.objects.bulk_create(batch_embeddings_to_create)
logger.debug(f"Created {len(new_embeddings)} new embeddings")
num_new_embeddings += len(new_embeddings)
dates_to_create = []
with timer("Create new date associations for new embeddings", logger):
for embedding in new_embeddings:
dates = self.date_filter.extract_dates(embedding.raw)
for date in dates:
dates_to_create.append(
EmbeddingsDates(
date=date,
embeddings=embedding,
)
)
new_dates = EmbeddingsDates.objects.bulk_create(dates_to_create)
if len(new_dates) > 0:
logger.debug(f"Created {len(new_dates)} new date entries")
with timer("Identify hashes for removed entries", logger):
for file in hashes_by_file:
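
Distilled from the change above: each entry hash is paired with its embedding, the stream is cut into chunks of up to 200 pairs, and only one chunk of ORM objects is materialized and bulk-inserted at a time, which bounds peak memory. A standalone sketch of that flow, where make_row and bulk_create are hypothetical stand-ins for the Embeddings constructor and Embeddings.objects.bulk_create:

    BATCH_SIZE = 200  # cap used in the change above

    def save_embeddings_in_batches(hashes, embeddings, make_row, bulk_create):
        created = 0
        for batch in batcher(zip(hashes, embeddings), BATCH_SIZE):
            # Materialize one batch of rows at a time instead of all rows at once
            rows = [make_row(entry_hash, embedding) for entry_hash, embedding in batch]
            created += len(bulk_create(rows))
        return created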

----------------------------------------------------------------------------------------------------

@ -723,7 +723,7 @@ async def extract_references_and_questions(
# Infer search queries from user message
with timer("Extracting search queries took", logger):
# If we've reached here, either the user has enabled offline chat or the openai model is enabled.
-if await ConversationAdapters.has_offline_chat(user):
+if await ConversationAdapters.ahas_offline_chat(user):
using_offline_chat = True
offline_chat = await ConversationAdapters.get_offline_chat(user)
chat_model = offline_chat.chat_model

----------------------------------------------------------------------------------------------------

@ -46,7 +46,6 @@ async def login(request: Request):
return await oauth.google.authorize_redirect(request, redirect_uri)
@auth_router.post("/redirect")
@auth_router.post("/token")
@requires(["authenticated"], redirect="login_page")
async def generate_token(request: Request, token_name: Optional[str] = None) -> str:
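
Client-side, generating a key is then a single authenticated POST. A hypothetical sketch with requests; the endpoint and token_name parameter are from this diff, while the host, port, and session cookie are placeholders:

    import requests

    resp = requests.post(
        "http://localhost:42110/auth/token",
        params={"token_name": "my laptop"},
        cookies={"session": "<authenticated session cookie>"},  # placeholder
    )
    print(resp.json())  # expected to contain the generated name and kk-... token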

----------------------------------------------------------------------------------------------------

@ -31,7 +31,7 @@ def perform_chat_checks(user: KhojUser):
async def is_ready_to_chat(user: KhojUser):
-has_offline_config = await ConversationAdapters.has_offline_chat(user=user)
+has_offline_config = await ConversationAdapters.ahas_offline_chat(user=user)
has_openai_config = await ConversationAdapters.has_openai_chat(user=user)
if has_offline_config:

----------------------------------------------------------------------------------------------------

@ -156,18 +156,17 @@ async def update(
host=host,
)
logger.info(f"📪 Content index updated via API call by {client} client")
return Response(content="OK", status_code=200)
def configure_search(search_models: SearchModels, search_config: Optional[SearchConfig]) -> Optional[SearchModels]:
# Run Validation Checks
if search_config is None:
logger.warning("🚨 No Search configuration available.")
return None
if search_models is None:
search_models = SearchModels()
-if search_config.image:
+if search_config and search_config.image:
logger.info("🔍 🌄 Setting up image search model")
search_models.image_search = image_search.initialize_model(search_config.image)

----------------------------------------------------------------------------------------------------

@ -127,14 +127,18 @@ class DateFilter(BaseFilter):
clean_date_str = re.sub("|".join(future_strings), "", date_str)
# parse date passed in query date filter
-parsed_date = dtparse.parse(
-clean_date_str,
-settings={
-"RELATIVE_BASE": relative_base or datetime.now(),
-"PREFER_DAY_OF_MONTH": "first",
-"PREFER_DATES_FROM": prefer_dates_from,
-},
-)
+try:
+parsed_date = dtparse.parse(
+clean_date_str,
+settings={
+"RELATIVE_BASE": relative_base or datetime.now(),
+"PREFER_DAY_OF_MONTH": "first",
+"PREFER_DATES_FROM": prefer_dates_from,
+},
+)
+except Exception as e:
+logger.error(f"Failed to parse date string: {date_str} with error: {e}")
+return None
if parsed_date is None:
return None
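
dateparser can raise on pathological inputs instead of returning None, which is what the new guard absorbs; returning None pushes the failure into the existing "no date parsed" path rather than an internal server error. A minimal reproduction of the guarded-call pattern, assuming dtparse aliases the dateparser package as the surrounding code suggests:

    from datetime import datetime
    import dateparser as dtparse

    def safe_parse(date_str: str):
        try:
            return dtparse.parse(
                date_str,
                settings={"RELATIVE_BASE": datetime.now(), "PREFER_DAY_OF_MONTH": "first"},
            )
        except Exception:
            return None  # treat unparseable input as "no date" instead of an error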

----------------------------------------------------------------------------------------------------

@ -5,6 +5,7 @@ import datetime
from enum import Enum
from importlib import import_module
from importlib.metadata import version
from itertools import islice
import logging
from os import path
import os
@ -305,3 +306,13 @@ def generate_random_name():
name = f"{adjective} {noun}"
return name
def batcher(iterable, max_n):
"Split an iterable into chunks of size max_n"
it = iter(iterable)
while True:
chunk = list(islice(it, max_n))
if not chunk:
return
yield (x for x in chunk if x is not None)
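
A usage note: batcher yields generators rather than lists, so each chunk is consumed lazily and nothing beyond the current chunk is held in memory; wrap a chunk in list() if you need its length or indexing. For example:

    chunks = [list(chunk) for chunk in batcher(range(7), 3)]
    print(chunks)  # [[0, 1, 2], [3, 4, 5], [6]]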

----------------------------------------------------------------------------------------------------

@ -22,9 +22,7 @@ logger = logging.getLogger(__name__)
# Test
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
-def test_text_search_setup_with_missing_file_raises_error(
-org_config_with_only_new_file: LocalOrgConfig, search_config: SearchConfig
-):
+def test_text_search_setup_with_missing_file_raises_error(org_config_with_only_new_file: LocalOrgConfig):
# Arrange
# Ensure file mentioned in org.input-files is missing
single_new_file = Path(org_config_with_only_new_file.input_files[0])
@ -70,22 +68,39 @@ def test_text_search_setup_with_empty_file_raises_error(
with caplog.at_level(logging.INFO):
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
assert "Created 0 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
assert "Created 0 new embeddings. Deleted 3 embeddings for user " in caplog.records[-1].message
verify_embeddings(0, default_user)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
-def test_text_search_setup(content_config, default_user: KhojUser, caplog):
+def test_text_indexer_deletes_embedding_before_regenerate(
+content_config: ContentConfig, default_user: KhojUser, caplog
+):
# Arrange
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
data = get_org_files(org_config)
-with caplog.at_level(logging.INFO):
+with caplog.at_level(logging.DEBUG):
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
# Assert
assert "Deleting all embeddings for file type org" in caplog.records[1].message
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
assert "Deleting all embeddings for file type org" in caplog.text
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[-1].message
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
def test_text_search_setup_batch_processes(content_config: ContentConfig, default_user: KhojUser, caplog):
# Arrange
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
data = get_org_files(org_config)
with caplog.at_level(logging.DEBUG):
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
# Assert
assert "Created 4 new embeddings" in caplog.text
assert "Created 6 new embeddings" in caplog.text
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[-1].message
# ----------------------------------------------------------------------------------------------------
@ -97,13 +112,13 @@ def test_text_index_same_if_content_unchanged(content_config: ContentConfig, def
# Act
# Generate initial notes embeddings during asymmetric setup
-with caplog.at_level(logging.INFO):
+with caplog.at_level(logging.DEBUG):
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
initial_logs = caplog.text
caplog.clear() # Clear logs
# Run asymmetric setup again with no changes to data source. Ensure index is not updated
-with caplog.at_level(logging.INFO):
+with caplog.at_level(logging.DEBUG):
text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)
final_logs = caplog.text
@ -175,12 +190,10 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: LocalOrgCon
# Assert
# verify newly added org-mode entry is split by max tokens
-record = caplog.records[1]
-assert "Created 2 new embeddings. Deleted 0 embeddings for user " in record.message
+assert "Created 2 new embeddings. Deleted 0 embeddings for user " in caplog.records[-1].message
# ----------------------------------------------------------------------------------------------------
# @pytest.mark.skip(reason="Flaky due to compressed_jsonl file being rewritten by other tests")
@pytest.mark.django_db
def test_entry_chunking_by_max_tokens_not_full_corpus(
org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog
@ -232,11 +245,9 @@ conda activate khoj
user=default_user,
)
-record = caplog.records[1]
# Assert
# verify newly added org-mode entry is split by max tokens
-assert "Created 2 new embeddings. Deleted 0 embeddings for user " in record.message
+assert "Created 2 new embeddings. Deleted 0 embeddings for user " in caplog.records[-1].message
# ----------------------------------------------------------------------------------------------------
@ -251,7 +262,7 @@ def test_regenerate_index_with_new_entry(
with caplog.at_level(logging.INFO):
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[-1].message
# append org-mode entry to first org input file in config
org_config.input_files = [f"{new_org_file}"]
@ -286,20 +297,23 @@ def test_update_index_with_duplicate_entries_in_stable_order(
data = get_org_files(org_config_with_only_new_file)
# Act
-# load embeddings, entries, notes model after adding new org-mode file
+# generate embeddings, entries, notes model from scratch after adding new org-mode file
with caplog.at_level(logging.INFO):
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
initial_logs = caplog.text
caplog.clear() # Clear logs
data = get_org_files(org_config_with_only_new_file)
-# update embeddings, entries, notes model after adding new org-mode file
+# update embeddings, entries, notes model with no new changes
with caplog.at_level(logging.INFO):
text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)
final_logs = caplog.text
# Assert
# verify only 1 entry added even if there are multiple duplicate entries
assert "Created 1 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
assert "Created 0 new embeddings. Deleted 0 embeddings for user " in caplog.records[4].message
assert "Created 1 new embeddings. Deleted 3 embeddings for user " in initial_logs
assert "Created 0 new embeddings. Deleted 0 embeddings for user " in final_logs
verify_embeddings(1, default_user)
@ -319,6 +333,8 @@ def test_update_index_with_deleted_entry(org_config_with_only_new_file: LocalOrg
# load embeddings, entries, notes model after adding new org file with 2 entries
with caplog.at_level(logging.INFO):
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
initial_logs = caplog.text
caplog.clear() # Clear logs
# update embeddings, entries, notes model after removing an entry from the org file
with open(new_file_to_index, "w") as f:
@ -329,11 +345,12 @@ def test_update_index_with_deleted_entry(org_config_with_only_new_file: LocalOrg
# Act
with caplog.at_level(logging.INFO):
text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)
final_logs = caplog.text
# Assert
# verify only 1 entry added even if there are multiple duplicate entries
assert "Created 2 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
assert "Created 0 new embeddings. Deleted 1 embeddings for user " in caplog.records[4].message
assert "Created 2 new embeddings. Deleted 3 embeddings for user " in initial_logs
assert "Created 0 new embeddings. Deleted 1 embeddings for user " in final_logs
verify_embeddings(1, default_user)
@ -346,6 +363,8 @@ def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file
data = get_org_files(org_config)
with caplog.at_level(logging.INFO):
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
initial_logs = caplog.text
caplog.clear() # Clear logs
# append org-mode entry to first org input file in config
with open(new_org_file, "w") as f:
@ -358,10 +377,11 @@ def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file
# update embeddings, entries with the newly added note
with caplog.at_level(logging.INFO):
text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)
final_logs = caplog.text
# Assert
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
assert "Created 1 new embeddings. Deleted 0 embeddings for user " in caplog.records[4].message
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in initial_logs
assert "Created 1 new embeddings. Deleted 0 embeddings for user " in final_logs
verify_embeddings(11, default_user)
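
The recurring pattern in these test changes is to stop asserting against fixed caplog.records indices, which break whenever a new log line is added, and instead check the last record or the full caplog.text. A minimal pytest sketch of the robust form (logger name and messages are illustrative):

    import logging

    def test_summary_is_last_log_line(caplog):
        logger = logging.getLogger("khoj")
        with caplog.at_level(logging.DEBUG):
            logger.debug("Created 2 new embeddings")
            logger.info("Created 2 new embeddings. Deleted 0 embeddings for user x")
        # Assert on the last record / full text, not a fixed index
        assert "Deleted 0 embeddings" in caplog.records[-1].message
        assert "Created 2 new embeddings" in caplog.text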