[Multi-User Part 6]: Address small bugs and upstream PR comments (#518)

- 08654163cb: Add better parsing for XML files
- f3acfac7fb: Add a try/catch around the dateparser in order to avoid internal server errors in app
- 7d43cd62c0: Chunk embeddings generation in order to avoid large memory load
- e02d751eb3: Addresses comments from PR #498 
- a3f393edb4: Addresses comments from PR #503 
- 66eb078286: Addresses comments from PR #511 
- Address various items in https://github.com/khoj-ai/khoj/issues/527
This commit is contained in:
sabaimran 2023-10-31 17:59:53 -07:00 committed by GitHub
parent 5f3f6b7c61
commit 54a387326c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 264 additions and 114 deletions

View file

@ -71,6 +71,7 @@ dependencies = [
"google-auth == 2.23.3", "google-auth == 2.23.3",
"python-multipart == 0.0.6", "python-multipart == 0.0.6",
"gunicorn == 21.2.0", "gunicorn == 21.2.0",
"lxml == 4.9.3",
] ]
dynamic = ["version"] dynamic = ["version"]

View file

@ -1,6 +1,9 @@
import secrets import secrets
from typing import Type, TypeVar, List from typing import Type, TypeVar, List
from datetime import date from datetime import date
import secrets
from typing import Type, TypeVar, List
from datetime import date
from django.db import models from django.db import models
from django.contrib.sessions.backends.db import SessionStore from django.contrib.sessions.backends.db import SessionStore
@ -8,6 +11,10 @@ from pgvector.django import CosineDistance
from django.db.models.manager import BaseManager from django.db.models.manager import BaseManager
from django.db.models import Q from django.db.models import Q
from torch import Tensor from torch import Tensor
from pgvector.django import CosineDistance
from django.db.models.manager import BaseManager
from django.db.models import Q
from torch import Tensor
# Import sync_to_async from Django Channels # Import sync_to_async from Django Channels
from asgiref.sync import sync_to_async from asgiref.sync import sync_to_async
@ -58,7 +65,7 @@ async def set_notion_config(token: str, user: KhojUser):
async def create_khoj_token(user: KhojUser, name=None): async def create_khoj_token(user: KhojUser, name=None):
"Create Khoj API key for user" "Create Khoj API key for user"
token = f"kk-{secrets.token_urlsafe(32)}" token = f"kk-{secrets.token_urlsafe(32)}"
name = name or f"{generate_random_name().title()}'s Secret Key" name = name or f"{generate_random_name().title()}"
api_config = await KhojApiUser.objects.acreate(token=token, user=user, name=name) api_config = await KhojApiUser.objects.acreate(token=token, user=user, name=name)
await api_config.asave() await api_config.asave()
return api_config return api_config
@ -123,15 +130,11 @@ def get_all_users() -> BaseManager[KhojUser]:
def get_user_github_config(user: KhojUser): def get_user_github_config(user: KhojUser):
config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first() config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first()
if not config:
return None
return config return config
def get_user_notion_config(user: KhojUser): def get_user_notion_config(user: KhojUser):
config = NotionConfig.objects.filter(user=user).first() config = NotionConfig.objects.filter(user=user).first()
if not config:
return None
return config return config
@ -240,13 +243,10 @@ class ConversationAdapters:
@staticmethod @staticmethod
def get_enabled_conversation_settings(user: KhojUser): def get_enabled_conversation_settings(user: KhojUser):
openai_config = ConversationAdapters.get_openai_conversation_config(user) openai_config = ConversationAdapters.get_openai_conversation_config(user)
offline_chat_config = ConversationAdapters.get_offline_chat_conversation_config(user)
return { return {
"openai": True if openai_config is not None else False, "openai": True if openai_config is not None else False,
"offline_chat": True "offline_chat": ConversationAdapters.has_offline_chat(user),
if (offline_chat_config is not None and offline_chat_config.enable_offline_chat)
else False,
} }
@staticmethod @staticmethod
@ -264,7 +264,11 @@ class ConversationAdapters:
OfflineChatProcessorConversationConfig.objects.filter(user=user).delete() OfflineChatProcessorConversationConfig.objects.filter(user=user).delete()
@staticmethod @staticmethod
async def has_offline_chat(user: KhojUser): def has_offline_chat(user: KhojUser):
return OfflineChatProcessorConversationConfig.objects.filter(user=user, enable_offline_chat=True).exists()
@staticmethod
async def ahas_offline_chat(user: KhojUser):
return await OfflineChatProcessorConversationConfig.objects.filter( return await OfflineChatProcessorConversationConfig.objects.filter(
user=user, enable_offline_chat=True user=user, enable_offline_chat=True
).aexists() ).aexists()

View file

@ -2,7 +2,7 @@
<head> <head>
<meta charset="utf-8"> <meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0 maximum-scale=1.0"> <meta name="viewport" content="width=device-width, initial-scale=1.0 maximum-scale=1.0">
<title>Khoj - Search</title> <title>Khoj - Settings</title>
<link rel="icon" type="image/png" sizes="128x128" href="./assets/icons/favicon-128x128.png"> <link rel="icon" type="image/png" sizes="128x128" href="./assets/icons/favicon-128x128.png">
<link rel="manifest" href="./khoj.webmanifest"> <link rel="manifest" href="./khoj.webmanifest">

View file

@ -94,6 +94,15 @@
}).join("\n"); }).join("\n");
} }
function render_xml(query, data) {
return data.map(function (item) {
return `<div class="results-xml">` +
`<b><a href="${item.additional.file}">${item.additional.heading}</a></b>` +
`<xml>${item.entry}</xml>` +
`</div>`
}).join("\n");
}
function render_multiple(query, data, type) { function render_multiple(query, data, type) {
let html = ""; let html = "";
data.forEach(item => { data.forEach(item => {
@ -113,6 +122,8 @@
html += `<div class="results-notion">` + `<b><a href="${item.additional.file}">${item.additional.heading}</a></b>` + `<p>${item.entry}</p>` + `</div>`; html += `<div class="results-notion">` + `<b><a href="${item.additional.file}">${item.additional.heading}</a></b>` + `<p>${item.entry}</p>` + `</div>`;
} else if (item.additional.file.endsWith(".html")) { } else if (item.additional.file.endsWith(".html")) {
html += render_html(query, [item]); html += render_html(query, [item]);
} else if (item.additional.file.endsWith(".xml")) {
html += render_xml(query, [item])
} else { } else {
html += `<div class="results-plugin">` + `<b><a href="${item.additional.file}">${item.additional.heading}</a></b>` + `<p>${item.entry}</p>` + `</div>`; html += `<div class="results-plugin">` + `<b><a href="${item.additional.file}">${item.additional.heading}</a></b>` + `<p>${item.entry}</p>` + `</div>`;
} }

View file

@ -267,6 +267,12 @@ async function getFolders () {
} }
async function setURL (event, url) { async function setURL (event, url) {
// Sanitize the URL. Remove trailing slash if present. Add http:// if not present.
url = url.replace(/\/$/, "");
if (!url.match(/^[a-zA-Z]+:\/\//)) {
url = `http://${url}`;
}
store.set('hostURL', url); store.set('hostURL', url);
return store.get('hostURL'); return store.get('hostURL');
} }

View file

@ -174,6 +174,7 @@ urlInput.addEventListener('blur', async () => {
new URL(urlInputValue); new URL(urlInputValue);
} catch (e) { } catch (e) {
console.log(e); console.log(e);
alert('Please enter a valid URL');
return; return;
} }

View file

@ -25,7 +25,6 @@ from khoj.utils import constants, state
from khoj.utils.config import ( from khoj.utils.config import (
SearchType, SearchType,
) )
from khoj.utils.helpers import merge_dicts
from khoj.utils.fs_syncer import collect_files from khoj.utils.fs_syncer import collect_files
from khoj.utils.rawconfig import FullConfig from khoj.utils.rawconfig import FullConfig
from khoj.routers.indexer import configure_content, load_content, configure_search from khoj.routers.indexer import configure_content, load_content, configure_search
@ -83,12 +82,6 @@ class UserAuthenticationBackend(AuthenticationBackend):
def initialize_server(config: Optional[FullConfig]): def initialize_server(config: Optional[FullConfig]):
if config is None:
logger.warning(
f"🚨 Khoj is not configured.\nConfigure it via http://{state.host}:{state.port}/config, plugins or by editing {state.config_file}."
)
return None
try: try:
configure_server(config, init=True) configure_server(config, init=True)
except Exception as e: except Exception as e:
@ -190,7 +183,7 @@ def configure_search_types(config: FullConfig):
core_search_types = {e.name: e.value for e in SearchType} core_search_types = {e.name: e.value for e in SearchType}
# Dynamically generate search type enum by merging core search types with configured plugin search types # Dynamically generate search type enum by merging core search types with configured plugin search types
return Enum("SearchType", merge_dicts(core_search_types, {})) return Enum("SearchType", core_search_types)
@schedule.repeat(schedule.every(59).minutes) @schedule.repeat(schedule.every(59).minutes)

View file

@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 448 512"><!--! Font Awesome Pro 6.4.2 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license (Commercial License) Copyright 2023 Fonticons, Inc. --><path d="M208 0H332.1c12.7 0 24.9 5.1 33.9 14.1l67.9 67.9c9 9 14.1 21.2 14.1 33.9V336c0 26.5-21.5 48-48 48H208c-26.5 0-48-21.5-48-48V48c0-26.5 21.5-48 48-48zM48 128h80v64H64V448H256V416h64v48c0 26.5-21.5 48-48 48H48c-26.5 0-48-21.5-48-48V176c0-26.5 21.5-48 48-48z"/></svg>

After

Width:  |  Height:  |  Size: 503 B

View file

@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 448 512"><!--! Font Awesome Pro 6.4.2 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license (Commercial License) Copyright 2023 Fonticons, Inc. --><path d="M135.2 17.7L128 32H32C14.3 32 0 46.3 0 64S14.3 96 32 96H416c17.7 0 32-14.3 32-32s-14.3-32-32-32H320l-7.2-14.3C307.4 6.8 296.3 0 284.2 0H163.8c-12.1 0-23.2 6.8-28.6 17.7zM416 128H32L53.2 467c1.6 25.3 22.6 45 47.9 45H346.9c25.3 0 46.3-19.7 47.9-45L416 128z"/></svg>

After

Width:  |  Height:  |  Size: 503 B

View file

@ -52,9 +52,23 @@
justify-self: center; justify-self: center;
} }
.api-settings {
div.section.general-settings { display: grid;
justify-self: center; grid-template-columns: 1fr;
grid-template-rows: 1fr 1fr auto;
justify-items: start;
gap: 8px;
padding: 24px 24px;
background: white;
border: 1px solid rgb(229, 229, 229);
border-radius: 4px;
box-shadow: 0px 1px 3px 0px rgba(0,0,0,0.1),0px 1px 2px -1px rgba(0,0,0,0.1);
}
#api-settings-card-description {
margin: 8px 0 0 0;
}
#api-settings-keys-table {
margin-bottom: 16px;
} }
div.instructions { div.instructions {

View file

@ -287,11 +287,33 @@
</div> </div>
</div> </div>
</div> </div>
<div class="section general-settings"> <div class="section">
<div id="khoj-api-key-section" title="Use Khoj cloud with your Khoj API Key"> <div class="api-settings">
<button id="generate-api-key" onclick="generateAPIKey()">Generate API Key</button> <div class="card-title-row">
<div id="api-key-list"></div> <img class="card-icon" src="/static/assets/icons/key.svg" alt="API Key">
<h3 class="card-title">API Keys</h3>
</div> </div>
<div class="card-description-row">
<p id="api-settings-card-description" class="card-description">Manage access to your Khoj from client apps</p>
</div>
<table id="api-settings-keys-table">
<thead>
<tr>
<th scope="col">Name</th>
<th scope="col">Key</th>
<th scope="col">Actions</th>
</tr>
</thead>
<tbody id="api-key-list"></tbody>
</table>
<div class="card-action-row">
<button class="card-button happy" id="generate-api-key" onclick="generateAPIKey()">
Generate API Key
</button>
</div>
</div>
</div>
<div class="section general-settings">
<div id="results-count" title="Number of items to show in search and use for chat response"> <div id="results-count" title="Number of items to show in search and use for chat response">
<label for="results-count-slider">Results Count: <span id="results-count-value">5</span></label> <label for="results-count-slider">Results Count: <span id="results-count-value">5</span></label>
<input type="range" id="results-count-slider" name="results-count-slider" min="1" max="10" step="1" value="5"> <input type="range" id="results-count-slider" name="results-count-slider" min="1" max="10" step="1" value="5">
@ -520,14 +542,32 @@
.then(response => response.json()) .then(response => response.json())
.then(tokenObj => { .then(tokenObj => {
apiKeyList.innerHTML += ` apiKeyList.innerHTML += `
<div id="api-key-item-${tokenObj.token}" class="api-key-item"> <tr id="api-key-item-${tokenObj.token}">
<span class="api-key">${tokenObj.token}</span> <td><b>${tokenObj.name}</b></td>
<button class="delete-api-key" onclick="deleteAPIKey('${tokenObj.token}')">Delete</button> <td id="api-key-${tokenObj.token}">${tokenObj.token}</td>
</div> <td>
<img id="api-key-copy-button-${tokenObj.token}" onclick="copyAPIKey('${tokenObj.token}')" class="configured-icon enabled" src="/static/assets/icons/copy-solid.svg" alt="Copy API Key">
<img id="api-key-delete-button-${tokenObj.token}" onclick="deleteAPIKey('${tokenObj.token}')" class="configured-icon enabled" src="/static/assets/icons/trash-solid.svg" alt="Delete API Key">
</td>
</tr>
`; `;
}); });
} }
function copyAPIKey(token) {
// Copy API key to clipboard
navigator.clipboard.writeText(token);
// Flash the API key copied message
const copyApiKeyButton = document.getElementById(`api-key-${token}`);
original_html = copyApiKeyButton.innerHTML
setTimeout(function() {
copyApiKeyButton.innerHTML = "✅ Copied to your clipboard!";
setTimeout(function() {
copyApiKeyButton.innerHTML = original_html;
}, 1000);
}, 100);
}
function deleteAPIKey(token) { function deleteAPIKey(token) {
const apiKeyList = document.getElementById("api-key-list"); const apiKeyList = document.getElementById("api-key-list");
fetch(`/auth/token?token=${token}`, { fetch(`/auth/token?token=${token}`, {
@ -548,10 +588,14 @@
.then(tokens => { .then(tokens => {
apiKeyList.innerHTML = tokens.map(tokenObj => apiKeyList.innerHTML = tokens.map(tokenObj =>
` `
<div id="api-key-item-${tokenObj.token}" class="api-key-item"> <tr id="api-key-item-${tokenObj.token}">
<span class="api-key">${tokenObj.token}</span> <td><b>${tokenObj.name}</b></td>
<button class="delete-api-key" onclick="deleteAPIKey('${tokenObj.token}')">Delete</button> <td id="api-key-${tokenObj.token}">${tokenObj.token}</td>
</div> <td>
<img id="api-key-copy-button-${tokenObj.token}" onclick="copyAPIKey('${tokenObj.token}')" class="configured-icon enabled" src="/static/assets/icons/copy-solid.svg" alt="Copy API Key">
<img id="api-key-delete-button-${tokenObj.token}" onclick="deleteAPIKey('${tokenObj.token}')" class="configured-icon enabled" src="/static/assets/icons/trash-solid.svg" alt="Delete API Key">
</td>
</tr>
`) `)
.join(""); .join("");
}); });

View file

@ -94,6 +94,15 @@
}).join("\n"); }).join("\n");
} }
function render_xml(query, data) {
return data.map(function (item) {
return `<div class="results-xml">` +
`<b><a href="${item.additional.file}">${item.additional.heading}</a></b>` +
`<xml>${item.entry}</xml>` +
`</div>`
}).join("\n");
}
function render_multiple(query, data, type) { function render_multiple(query, data, type) {
let html = ""; let html = "";
data.forEach(item => { data.forEach(item => {
@ -113,6 +122,8 @@
html += `<div class="results-notion">` + `<b><a href="${item.additional.file}">${item.additional.heading}</a></b>` + `<p>${item.entry}</p>` + `</div>`; html += `<div class="results-notion">` + `<b><a href="${item.additional.file}">${item.additional.heading}</a></b>` + `<p>${item.entry}</p>` + `</div>`;
} else if (item.additional.file.endsWith(".html")) { } else if (item.additional.file.endsWith(".html")) {
html += render_html(query, [item]); html += render_html(query, [item]);
} else if (item.additional.file.endsWith(".xml")) {
html += render_xml(query, [item])
} else { } else {
html += `<div class="results-plugin">` + `<b><a href="${item.additional.file}">${item.additional.heading}</a></b>` + `<p>${item.entry}</p>` + `</div>`; html += `<div class="results-plugin">` + `<b><a href="${item.additional.file}">${item.additional.heading}</a></b>` + `<p>${item.entry}</p>` + `</div>`;
} }

View file

@ -2,7 +2,7 @@
<head> <head>
<meta charset="utf-8"> <meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0 maximum-scale=1.0"> <meta name="viewport" content="width=device-width, initial-scale=1.0 maximum-scale=1.0">
<title>Khoj - Search</title> <title>Khoj - Login</title>
<link rel="icon" type="image/png" sizes="128x128" href="/static/assets/icons/favicon-128x128.png"> <link rel="icon" type="image/png" sizes="128x128" href="/static/assets/icons/favicon-128x128.png">
<link rel="manifest" href="/static/khoj.webmanifest"> <link rel="manifest" href="/static/khoj.webmanifest">

View file

@ -2,12 +2,14 @@
import logging import logging
from pathlib import Path from pathlib import Path
from typing import List, Tuple from typing import List, Tuple
from bs4 import BeautifulSoup
# Internal Packages # Internal Packages
from khoj.processor.text_to_jsonl import TextEmbeddings from khoj.processor.text_to_jsonl import TextEmbeddings
from khoj.utils.helpers import timer from khoj.utils.helpers import timer
from khoj.utils.rawconfig import Entry, TextContentConfig from khoj.utils.rawconfig import Entry
from database.models import Embeddings, KhojUser, LocalPlaintextConfig from database.models import Embeddings, KhojUser
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -28,6 +30,19 @@ class PlaintextToJsonl(TextEmbeddings):
else: else:
deletion_file_names = None deletion_file_names = None
with timer("Scrub plaintext files and extract text", logger):
for file in files:
try:
plaintext_content = files[file]
if file.endswith(("html", "htm", "xml")):
plaintext_content = PlaintextToJsonl.extract_html_content(
plaintext_content, file.split(".")[-1]
)
files[file] = plaintext_content
except Exception as e:
logger.warning(f"Unable to read file: {file} as plaintext. Skipping file.")
logger.warning(e, exc_info=True)
# Extract Entries from specified plaintext files # Extract Entries from specified plaintext files
with timer("Parse entries from plaintext files", logger): with timer("Parse entries from plaintext files", logger):
current_entries = PlaintextToJsonl.convert_plaintext_entries_to_maps(files) current_entries = PlaintextToJsonl.convert_plaintext_entries_to_maps(files)
@ -50,6 +65,15 @@ class PlaintextToJsonl(TextEmbeddings):
return num_new_embeddings, num_deleted_embeddings return num_new_embeddings, num_deleted_embeddings
@staticmethod
def extract_html_content(markup_content: str, markup_type: str):
"Extract content from HTML"
if markup_type == "xml":
soup = BeautifulSoup(markup_content, "xml")
else:
soup = BeautifulSoup(markup_content, "html.parser")
return soup.get_text(strip=True, separator="\n")
@staticmethod @staticmethod
def convert_plaintext_entries_to_maps(entry_to_file_map: dict) -> List[Entry]: def convert_plaintext_entries_to_maps(entry_to_file_map: dict) -> List[Entry]:
"Convert each plaintext entries into a dictionary" "Convert each plaintext entries into a dictionary"

View file

@ -5,7 +5,7 @@ import logging
import uuid import uuid
from tqdm import tqdm from tqdm import tqdm
from typing import Callable, List, Tuple, Set, Any from typing import Callable, List, Tuple, Set, Any
from khoj.utils.helpers import timer from khoj.utils.helpers import timer, batcher
# Internal Packages # Internal Packages
@ -93,7 +93,7 @@ class TextEmbeddings(ABC):
num_deleted_embeddings = 0 num_deleted_embeddings = 0
with timer("Preparing dataset for regeneration", logger): with timer("Preparing dataset for regeneration", logger):
if regenerate: if regenerate:
logger.info(f"Deleting all embeddings for file type {file_type}") logger.debug(f"Deleting all embeddings for file type {file_type}")
num_deleted_embeddings = EmbeddingsAdapters.delete_all_embeddings(user, file_type) num_deleted_embeddings = EmbeddingsAdapters.delete_all_embeddings(user, file_type)
num_new_embeddings = 0 num_new_embeddings = 0
@ -106,32 +106,38 @@ class TextEmbeddings(ABC):
) )
existing_entry_hashes = set([entry.hashed_value for entry in existing_entries]) existing_entry_hashes = set([entry.hashed_value for entry in existing_entries])
hashes_to_process = hashes_for_file - existing_entry_hashes hashes_to_process = hashes_for_file - existing_entry_hashes
# for hashed_val in hashes_for_file:
# if not EmbeddingsAdapters.does_embedding_exist(user, hashed_val):
# hashes_to_process.add(hashed_val)
entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process] entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
data_to_embed = [getattr(entry, key) for entry in entries_to_process] data_to_embed = [getattr(entry, key) for entry in entries_to_process]
embeddings = self.embeddings_model.embed_documents(data_to_embed) embeddings = self.embeddings_model.embed_documents(data_to_embed)
with timer("Update the database with new vector embeddings", logger): with timer("Update the database with new vector embeddings", logger):
embeddings_to_create = [] num_items = len(hashes_to_process)
for hashed_val, embedding in zip(hashes_to_process, embeddings): assert num_items == len(embeddings)
entry = hash_to_current_entries[hashed_val] batch_size = min(200, num_items)
embeddings_to_create.append( entry_batches = zip(hashes_to_process, embeddings)
for entry_batch in tqdm(
batcher(entry_batches, batch_size), desc="Processing embeddings in batches"
):
batch_embeddings_to_create = []
for entry_hash, embedding in entry_batch:
entry = hash_to_current_entries[entry_hash]
batch_embeddings_to_create.append(
Embeddings( Embeddings(
user=user, user=user,
embeddings=embedding, embeddings=embedding,
raw=entry.raw, raw=entry.raw,
compiled=entry.compiled, compiled=entry.compiled,
heading=entry.heading, heading=entry.heading[:1000], # Truncate to max chars of field allowed
file_path=entry.file, file_path=entry.file,
file_type=file_type, file_type=file_type,
hashed_value=hashed_val, hashed_value=entry_hash,
corpus_id=entry.corpus_id, corpus_id=entry.corpus_id,
) )
) )
new_embeddings = Embeddings.objects.bulk_create(embeddings_to_create) new_embeddings = Embeddings.objects.bulk_create(batch_embeddings_to_create)
logger.debug(f"Created {len(new_embeddings)} new embeddings")
num_new_embeddings += len(new_embeddings) num_new_embeddings += len(new_embeddings)
dates_to_create = [] dates_to_create = []
@ -147,7 +153,7 @@ class TextEmbeddings(ABC):
) )
new_dates = EmbeddingsDates.objects.bulk_create(dates_to_create) new_dates = EmbeddingsDates.objects.bulk_create(dates_to_create)
if len(new_dates) > 0: if len(new_dates) > 0:
logger.info(f"Created {len(new_dates)} new date entries") logger.debug(f"Created {len(new_dates)} new date entries")
with timer("Identify hashes for removed entries", logger): with timer("Identify hashes for removed entries", logger):
for file in hashes_by_file: for file in hashes_by_file:

View file

@ -723,7 +723,7 @@ async def extract_references_and_questions(
# Infer search queries from user message # Infer search queries from user message
with timer("Extracting search queries took", logger): with timer("Extracting search queries took", logger):
# If we've reached here, either the user has enabled offline chat or the openai model is enabled. # If we've reached here, either the user has enabled offline chat or the openai model is enabled.
if await ConversationAdapters.has_offline_chat(user): if await ConversationAdapters.ahas_offline_chat(user):
using_offline_chat = True using_offline_chat = True
offline_chat = await ConversationAdapters.get_offline_chat(user) offline_chat = await ConversationAdapters.get_offline_chat(user)
chat_model = offline_chat.chat_model chat_model = offline_chat.chat_model

View file

@ -46,7 +46,6 @@ async def login(request: Request):
return await oauth.google.authorize_redirect(request, redirect_uri) return await oauth.google.authorize_redirect(request, redirect_uri)
@auth_router.post("/redirect")
@auth_router.post("/token") @auth_router.post("/token")
@requires(["authenticated"], redirect="login_page") @requires(["authenticated"], redirect="login_page")
async def generate_token(request: Request, token_name: Optional[str] = None) -> str: async def generate_token(request: Request, token_name: Optional[str] = None) -> str:

View file

@ -31,7 +31,7 @@ def perform_chat_checks(user: KhojUser):
async def is_ready_to_chat(user: KhojUser): async def is_ready_to_chat(user: KhojUser):
has_offline_config = await ConversationAdapters.has_offline_chat(user=user) has_offline_config = await ConversationAdapters.ahas_offline_chat(user=user)
has_openai_config = await ConversationAdapters.has_openai_chat(user=user) has_openai_config = await ConversationAdapters.has_openai_chat(user=user)
if has_offline_config: if has_offline_config:

View file

@ -156,18 +156,17 @@ async def update(
host=host, host=host,
) )
logger.info(f"📪 Content index updated via API call by {client} client")
return Response(content="OK", status_code=200) return Response(content="OK", status_code=200)
def configure_search(search_models: SearchModels, search_config: Optional[SearchConfig]) -> Optional[SearchModels]: def configure_search(search_models: SearchModels, search_config: Optional[SearchConfig]) -> Optional[SearchModels]:
# Run Validation Checks # Run Validation Checks
if search_config is None:
logger.warning("🚨 No Search configuration available.")
return None
if search_models is None: if search_models is None:
search_models = SearchModels() search_models = SearchModels()
if search_config.image: if search_config and search_config.image:
logger.info("🔍 🌄 Setting up image search model") logger.info("🔍 🌄 Setting up image search model")
search_models.image_search = image_search.initialize_model(search_config.image) search_models.image_search = image_search.initialize_model(search_config.image)

View file

@ -127,6 +127,7 @@ class DateFilter(BaseFilter):
clean_date_str = re.sub("|".join(future_strings), "", date_str) clean_date_str = re.sub("|".join(future_strings), "", date_str)
# parse date passed in query date filter # parse date passed in query date filter
try:
parsed_date = dtparse.parse( parsed_date = dtparse.parse(
clean_date_str, clean_date_str,
settings={ settings={
@ -135,6 +136,9 @@ class DateFilter(BaseFilter):
"PREFER_DATES_FROM": prefer_dates_from, "PREFER_DATES_FROM": prefer_dates_from,
}, },
) )
except Exception as e:
logger.error(f"Failed to parse date string: {date_str} with error: {e}")
return None
if parsed_date is None: if parsed_date is None:
return None return None

View file

@ -5,6 +5,7 @@ import datetime
from enum import Enum from enum import Enum
from importlib import import_module from importlib import import_module
from importlib.metadata import version from importlib.metadata import version
from itertools import islice
import logging import logging
from os import path from os import path
import os import os
@ -305,3 +306,13 @@ def generate_random_name():
name = f"{adjective} {noun}" name = f"{adjective} {noun}"
return name return name
def batcher(iterable, max_n):
"Split an iterable into chunks of size max_n"
it = iter(iterable)
while True:
chunk = list(islice(it, max_n))
if not chunk:
return
yield (x for x in chunk if x is not None)

View file

@ -22,9 +22,7 @@ logger = logging.getLogger(__name__)
# Test # Test
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db @pytest.mark.django_db
def test_text_search_setup_with_missing_file_raises_error( def test_text_search_setup_with_missing_file_raises_error(org_config_with_only_new_file: LocalOrgConfig):
org_config_with_only_new_file: LocalOrgConfig, search_config: SearchConfig
):
# Arrange # Arrange
# Ensure file mentioned in org.input-files is missing # Ensure file mentioned in org.input-files is missing
single_new_file = Path(org_config_with_only_new_file.input_files[0]) single_new_file = Path(org_config_with_only_new_file.input_files[0])
@ -70,22 +68,39 @@ def test_text_search_setup_with_empty_file_raises_error(
with caplog.at_level(logging.INFO): with caplog.at_level(logging.INFO):
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user) text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
assert "Created 0 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message assert "Created 0 new embeddings. Deleted 3 embeddings for user " in caplog.records[-1].message
verify_embeddings(0, default_user) verify_embeddings(0, default_user)
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db @pytest.mark.django_db
def test_text_search_setup(content_config, default_user: KhojUser, caplog): def test_text_indexer_deletes_embedding_before_regenerate(
content_config: ContentConfig, default_user: KhojUser, caplog
):
# Arrange # Arrange
org_config = LocalOrgConfig.objects.filter(user=default_user).first() org_config = LocalOrgConfig.objects.filter(user=default_user).first()
data = get_org_files(org_config) data = get_org_files(org_config)
with caplog.at_level(logging.INFO): with caplog.at_level(logging.DEBUG):
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user) text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
# Assert # Assert
assert "Deleting all embeddings for file type org" in caplog.records[1].message assert "Deleting all embeddings for file type org" in caplog.text
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[-1].message
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
def test_text_search_setup_batch_processes(content_config: ContentConfig, default_user: KhojUser, caplog):
# Arrange
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
data = get_org_files(org_config)
with caplog.at_level(logging.DEBUG):
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
# Assert
assert "Created 4 new embeddings" in caplog.text
assert "Created 6 new embeddings" in caplog.text
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[-1].message
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@ -97,13 +112,13 @@ def test_text_index_same_if_content_unchanged(content_config: ContentConfig, def
# Act # Act
# Generate initial notes embeddings during asymmetric setup # Generate initial notes embeddings during asymmetric setup
with caplog.at_level(logging.INFO): with caplog.at_level(logging.DEBUG):
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user) text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
initial_logs = caplog.text initial_logs = caplog.text
caplog.clear() # Clear logs caplog.clear() # Clear logs
# Run asymmetric setup again with no changes to data source. Ensure index is not updated # Run asymmetric setup again with no changes to data source. Ensure index is not updated
with caplog.at_level(logging.INFO): with caplog.at_level(logging.DEBUG):
text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user) text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)
final_logs = caplog.text final_logs = caplog.text
@ -175,12 +190,10 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: LocalOrgCon
# Assert # Assert
# verify newly added org-mode entry is split by max tokens # verify newly added org-mode entry is split by max tokens
record = caplog.records[1] assert "Created 2 new embeddings. Deleted 0 embeddings for user " in caplog.records[-1].message
assert "Created 2 new embeddings. Deleted 0 embeddings for user " in record.message
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
# @pytest.mark.skip(reason="Flaky due to compressed_jsonl file being rewritten by other tests")
@pytest.mark.django_db @pytest.mark.django_db
def test_entry_chunking_by_max_tokens_not_full_corpus( def test_entry_chunking_by_max_tokens_not_full_corpus(
org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog
@ -232,11 +245,9 @@ conda activate khoj
user=default_user, user=default_user,
) )
record = caplog.records[1]
# Assert # Assert
# verify newly added org-mode entry is split by max tokens # verify newly added org-mode entry is split by max tokens
assert "Created 2 new embeddings. Deleted 0 embeddings for user " in record.message assert "Created 2 new embeddings. Deleted 0 embeddings for user " in caplog.records[-1].message
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@ -251,7 +262,7 @@ def test_regenerate_index_with_new_entry(
with caplog.at_level(logging.INFO): with caplog.at_level(logging.INFO):
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user) text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[-1].message
# append org-mode entry to first org input file in config # append org-mode entry to first org input file in config
org_config.input_files = [f"{new_org_file}"] org_config.input_files = [f"{new_org_file}"]
@ -286,20 +297,23 @@ def test_update_index_with_duplicate_entries_in_stable_order(
data = get_org_files(org_config_with_only_new_file) data = get_org_files(org_config_with_only_new_file)
# Act # Act
# load embeddings, entries, notes model after adding new org-mode file # generate embeddings, entries, notes model from scratch after adding new org-mode file
with caplog.at_level(logging.INFO): with caplog.at_level(logging.INFO):
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user) text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
initial_logs = caplog.text
caplog.clear() # Clear logs
data = get_org_files(org_config_with_only_new_file) data = get_org_files(org_config_with_only_new_file)
# update embeddings, entries, notes model after adding new org-mode file # update embeddings, entries, notes model with no new changes
with caplog.at_level(logging.INFO): with caplog.at_level(logging.INFO):
text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user) text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)
final_logs = caplog.text
# Assert # Assert
# verify only 1 entry added even if there are multiple duplicate entries # verify only 1 entry added even if there are multiple duplicate entries
assert "Created 1 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message assert "Created 1 new embeddings. Deleted 3 embeddings for user " in initial_logs
assert "Created 0 new embeddings. Deleted 0 embeddings for user " in caplog.records[4].message assert "Created 0 new embeddings. Deleted 0 embeddings for user " in final_logs
verify_embeddings(1, default_user) verify_embeddings(1, default_user)
@ -319,6 +333,8 @@ def test_update_index_with_deleted_entry(org_config_with_only_new_file: LocalOrg
# load embeddings, entries, notes model after adding new org file with 2 entries # load embeddings, entries, notes model after adding new org file with 2 entries
with caplog.at_level(logging.INFO): with caplog.at_level(logging.INFO):
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user) text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
initial_logs = caplog.text
caplog.clear() # Clear logs
# update embeddings, entries, notes model after removing an entry from the org file # update embeddings, entries, notes model after removing an entry from the org file
with open(new_file_to_index, "w") as f: with open(new_file_to_index, "w") as f:
@ -329,11 +345,12 @@ def test_update_index_with_deleted_entry(org_config_with_only_new_file: LocalOrg
# Act # Act
with caplog.at_level(logging.INFO): with caplog.at_level(logging.INFO):
text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user) text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)
final_logs = caplog.text
# Assert # Assert
# verify only 1 entry added even if there are multiple duplicate entries # verify only 1 entry added even if there are multiple duplicate entries
assert "Created 2 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message assert "Created 2 new embeddings. Deleted 3 embeddings for user " in initial_logs
assert "Created 0 new embeddings. Deleted 1 embeddings for user " in caplog.records[4].message assert "Created 0 new embeddings. Deleted 1 embeddings for user " in final_logs
verify_embeddings(1, default_user) verify_embeddings(1, default_user)
@ -346,6 +363,8 @@ def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file
data = get_org_files(org_config) data = get_org_files(org_config)
with caplog.at_level(logging.INFO): with caplog.at_level(logging.INFO):
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user) text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
initial_logs = caplog.text
caplog.clear() # Clear logs
# append org-mode entry to first org input file in config # append org-mode entry to first org input file in config
with open(new_org_file, "w") as f: with open(new_org_file, "w") as f:
@ -358,10 +377,11 @@ def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file
# update embeddings, entries with the newly added note # update embeddings, entries with the newly added note
with caplog.at_level(logging.INFO): with caplog.at_level(logging.INFO):
text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user) text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)
final_logs = caplog.text
# Assert # Assert
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message assert "Created 10 new embeddings. Deleted 3 embeddings for user " in initial_logs
assert "Created 1 new embeddings. Deleted 0 embeddings for user " in caplog.records[4].message assert "Created 1 new embeddings. Deleted 0 embeddings for user " in final_logs
verify_embeddings(11, default_user) verify_embeddings(11, default_user)