mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 15:38:55 +01:00
[Multi-User Part 6]: Address small bugs and upstream PR comments (#518)
- 08654163cb: Add better parsing for XML files
- f3acfac7fb: Add a try/catch around the dateparser in order to avoid internal server errors in app
- 7d43cd62c0: Chunk embeddings generation in order to avoid large memory load
- e02d751eb3: Addresses comments from PR #498
- a3f393edb4: Addresses comments from PR #503
- 66eb078286: Addresses comments from PR #511
- Address various items in https://github.com/khoj-ai/khoj/issues/527
This commit is contained in:
parent
5f3f6b7c61
commit
54a387326c
22 changed files with 264 additions and 114 deletions
|
@ -71,6 +71,7 @@ dependencies = [
|
|||
"google-auth == 2.23.3",
|
||||
"python-multipart == 0.0.6",
|
||||
"gunicorn == 21.2.0",
|
||||
"lxml == 4.9.3",
|
||||
]
|
||||
dynamic = ["version"]
|
||||
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
import secrets
|
||||
from typing import Type, TypeVar, List
|
||||
from datetime import date
|
||||
import secrets
|
||||
from typing import Type, TypeVar, List
|
||||
from datetime import date
|
||||
|
||||
from django.db import models
|
||||
from django.contrib.sessions.backends.db import SessionStore
|
||||
|
@ -8,6 +11,10 @@ from pgvector.django import CosineDistance
|
|||
from django.db.models.manager import BaseManager
|
||||
from django.db.models import Q
|
||||
from torch import Tensor
|
||||
from pgvector.django import CosineDistance
|
||||
from django.db.models.manager import BaseManager
|
||||
from django.db.models import Q
|
||||
from torch import Tensor
|
||||
|
||||
# Import sync_to_async from Django Channels
|
||||
from asgiref.sync import sync_to_async
|
||||
|
@ -58,7 +65,7 @@ async def set_notion_config(token: str, user: KhojUser):
|
|||
async def create_khoj_token(user: KhojUser, name=None):
|
||||
"Create Khoj API key for user"
|
||||
token = f"kk-{secrets.token_urlsafe(32)}"
|
||||
name = name or f"{generate_random_name().title()}'s Secret Key"
|
||||
name = name or f"{generate_random_name().title()}"
|
||||
api_config = await KhojApiUser.objects.acreate(token=token, user=user, name=name)
|
||||
await api_config.asave()
|
||||
return api_config
|
||||
|
@ -123,15 +130,11 @@ def get_all_users() -> BaseManager[KhojUser]:
|
|||
|
||||
def get_user_github_config(user: KhojUser):
|
||||
config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first()
|
||||
if not config:
|
||||
return None
|
||||
return config
|
||||
|
||||
|
||||
def get_user_notion_config(user: KhojUser):
|
||||
config = NotionConfig.objects.filter(user=user).first()
|
||||
if not config:
|
||||
return None
|
||||
return config
|
||||
|
||||
|
||||
|
@ -240,13 +243,10 @@ class ConversationAdapters:
|
|||
@staticmethod
|
||||
def get_enabled_conversation_settings(user: KhojUser):
|
||||
openai_config = ConversationAdapters.get_openai_conversation_config(user)
|
||||
offline_chat_config = ConversationAdapters.get_offline_chat_conversation_config(user)
|
||||
|
||||
return {
|
||||
"openai": True if openai_config is not None else False,
|
||||
"offline_chat": True
|
||||
if (offline_chat_config is not None and offline_chat_config.enable_offline_chat)
|
||||
else False,
|
||||
"offline_chat": ConversationAdapters.has_offline_chat(user),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
|
@ -264,7 +264,11 @@ class ConversationAdapters:
|
|||
OfflineChatProcessorConversationConfig.objects.filter(user=user).delete()
|
||||
|
||||
@staticmethod
|
||||
async def has_offline_chat(user: KhojUser):
|
||||
def has_offline_chat(user: KhojUser):
|
||||
return OfflineChatProcessorConversationConfig.objects.filter(user=user, enable_offline_chat=True).exists()
|
||||
|
||||
@staticmethod
|
||||
async def ahas_offline_chat(user: KhojUser):
|
||||
return await OfflineChatProcessorConversationConfig.objects.filter(
|
||||
user=user, enable_offline_chat=True
|
||||
).aexists()
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0 maximum-scale=1.0">
|
||||
<title>Khoj - Search</title>
|
||||
<title>Khoj - Settings</title>
|
||||
|
||||
<link rel="icon" type="image/png" sizes="128x128" href="./assets/icons/favicon-128x128.png">
|
||||
<link rel="manifest" href="./khoj.webmanifest">
|
||||
|
|
|
@ -94,6 +94,15 @@
|
|||
}).join("\n");
|
||||
}
|
||||
|
||||
function render_xml(query, data) {
|
||||
return data.map(function (item) {
|
||||
return `<div class="results-xml">` +
|
||||
`<b><a href="${item.additional.file}">${item.additional.heading}</a></b>` +
|
||||
`<xml>${item.entry}</xml>` +
|
||||
`</div>`
|
||||
}).join("\n");
|
||||
}
|
||||
|
||||
function render_multiple(query, data, type) {
|
||||
let html = "";
|
||||
data.forEach(item => {
|
||||
|
@ -113,6 +122,8 @@
|
|||
html += `<div class="results-notion">` + `<b><a href="${item.additional.file}">${item.additional.heading}</a></b>` + `<p>${item.entry}</p>` + `</div>`;
|
||||
} else if (item.additional.file.endsWith(".html")) {
|
||||
html += render_html(query, [item]);
|
||||
} else if (item.additional.file.endsWith(".xml")) {
|
||||
html += render_xml(query, [item])
|
||||
} else {
|
||||
html += `<div class="results-plugin">` + `<b><a href="${item.additional.file}">${item.additional.heading}</a></b>` + `<p>${item.entry}</p>` + `</div>`;
|
||||
}
|
||||
|
|
|
@ -267,6 +267,12 @@ async function getFolders () {
|
|||
}
|
||||
|
||||
async function setURL (event, url) {
|
||||
// Sanitize the URL. Remove trailing slash if present. Add http:// if not present.
|
||||
url = url.replace(/\/$/, "");
|
||||
if (!url.match(/^[a-zA-Z]+:\/\//)) {
|
||||
url = `http://${url}`;
|
||||
}
|
||||
|
||||
store.set('hostURL', url);
|
||||
return store.get('hostURL');
|
||||
}
|
||||
|
|
|
@ -174,6 +174,7 @@ urlInput.addEventListener('blur', async () => {
|
|||
new URL(urlInputValue);
|
||||
} catch (e) {
|
||||
console.log(e);
|
||||
alert('Please enter a valid URL');
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
|
@ -25,7 +25,6 @@ from khoj.utils import constants, state
|
|||
from khoj.utils.config import (
|
||||
SearchType,
|
||||
)
|
||||
from khoj.utils.helpers import merge_dicts
|
||||
from khoj.utils.fs_syncer import collect_files
|
||||
from khoj.utils.rawconfig import FullConfig
|
||||
from khoj.routers.indexer import configure_content, load_content, configure_search
|
||||
|
@ -83,12 +82,6 @@ class UserAuthenticationBackend(AuthenticationBackend):
|
|||
|
||||
|
||||
def initialize_server(config: Optional[FullConfig]):
|
||||
if config is None:
|
||||
logger.warning(
|
||||
f"🚨 Khoj is not configured.\nConfigure it via http://{state.host}:{state.port}/config, plugins or by editing {state.config_file}."
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
configure_server(config, init=True)
|
||||
except Exception as e:
|
||||
|
@ -190,7 +183,7 @@ def configure_search_types(config: FullConfig):
|
|||
core_search_types = {e.name: e.value for e in SearchType}
|
||||
|
||||
# Dynamically generate search type enum by merging core search types with configured plugin search types
|
||||
return Enum("SearchType", merge_dicts(core_search_types, {}))
|
||||
return Enum("SearchType", core_search_types)
|
||||
|
||||
|
||||
@schedule.repeat(schedule.every(59).minutes)
|
||||
|
|
1
src/khoj/interface/web/assets/icons/copy-solid.svg
Normal file
1
src/khoj/interface/web/assets/icons/copy-solid.svg
Normal file
|
@ -0,0 +1 @@
|
|||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 448 512"><!--! Font Awesome Pro 6.4.2 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license (Commercial License) Copyright 2023 Fonticons, Inc. --><path d="M208 0H332.1c12.7 0 24.9 5.1 33.9 14.1l67.9 67.9c9 9 14.1 21.2 14.1 33.9V336c0 26.5-21.5 48-48 48H208c-26.5 0-48-21.5-48-48V48c0-26.5 21.5-48 48-48zM48 128h80v64H64V448H256V416h64v48c0 26.5-21.5 48-48 48H48c-26.5 0-48-21.5-48-48V176c0-26.5 21.5-48 48-48z"/></svg>
|
After Width: | Height: | Size: 503 B |
1
src/khoj/interface/web/assets/icons/trash-solid.svg
Normal file
1
src/khoj/interface/web/assets/icons/trash-solid.svg
Normal file
|
@ -0,0 +1 @@
|
|||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 448 512"><!--! Font Awesome Pro 6.4.2 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license (Commercial License) Copyright 2023 Fonticons, Inc. --><path d="M135.2 17.7L128 32H32C14.3 32 0 46.3 0 64S14.3 96 32 96H416c17.7 0 32-14.3 32-32s-14.3-32-32-32H320l-7.2-14.3C307.4 6.8 296.3 0 284.2 0H163.8c-12.1 0-23.2 6.8-28.6 17.7zM416 128H32L53.2 467c1.6 25.3 22.6 45 47.9 45H346.9c25.3 0 46.3-19.7 47.9-45L416 128z"/></svg>
|
After Width: | Height: | Size: 503 B |
|
@ -52,10 +52,24 @@
|
|||
justify-self: center;
|
||||
}
|
||||
|
||||
|
||||
div.section.general-settings {
|
||||
justify-self: center;
|
||||
}
|
||||
.api-settings {
|
||||
display: grid;
|
||||
grid-template-columns: 1fr;
|
||||
grid-template-rows: 1fr 1fr auto;
|
||||
justify-items: start;
|
||||
gap: 8px;
|
||||
padding: 24px 24px;
|
||||
background: white;
|
||||
border: 1px solid rgb(229, 229, 229);
|
||||
border-radius: 4px;
|
||||
box-shadow: 0px 1px 3px 0px rgba(0,0,0,0.1),0px 1px 2px -1px rgba(0,0,0,0.1);
|
||||
}
|
||||
#api-settings-card-description {
|
||||
margin: 8px 0 0 0;
|
||||
}
|
||||
#api-settings-keys-table {
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
|
||||
div.instructions {
|
||||
font-size: large;
|
||||
|
|
|
@ -287,11 +287,33 @@
|
|||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="section general-settings">
|
||||
<div id="khoj-api-key-section" title="Use Khoj cloud with your Khoj API Key">
|
||||
<button id="generate-api-key" onclick="generateAPIKey()">Generate API Key</button>
|
||||
<div id="api-key-list"></div>
|
||||
<div class="section">
|
||||
<div class="api-settings">
|
||||
<div class="card-title-row">
|
||||
<img class="card-icon" src="/static/assets/icons/key.svg" alt="API Key">
|
||||
<h3 class="card-title">API Keys</h3>
|
||||
</div>
|
||||
<div class="card-description-row">
|
||||
<p id="api-settings-card-description" class="card-description">Manage access to your Khoj from client apps</p>
|
||||
</div>
|
||||
<table id="api-settings-keys-table">
|
||||
<thead>
|
||||
<tr>
|
||||
<th scope="col">Name</th>
|
||||
<th scope="col">Key</th>
|
||||
<th scope="col">Actions</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody id="api-key-list"></tbody>
|
||||
</table>
|
||||
<div class="card-action-row">
|
||||
<button class="card-button happy" id="generate-api-key" onclick="generateAPIKey()">
|
||||
Generate API Key
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="section general-settings">
|
||||
<div id="results-count" title="Number of items to show in search and use for chat response">
|
||||
<label for="results-count-slider">Results Count: <span id="results-count-value">5</span></label>
|
||||
<input type="range" id="results-count-slider" name="results-count-slider" min="1" max="10" step="1" value="5">
|
||||
|
@ -520,14 +542,32 @@
|
|||
.then(response => response.json())
|
||||
.then(tokenObj => {
|
||||
apiKeyList.innerHTML += `
|
||||
<div id="api-key-item-${tokenObj.token}" class="api-key-item">
|
||||
<span class="api-key">${tokenObj.token}</span>
|
||||
<button class="delete-api-key" onclick="deleteAPIKey('${tokenObj.token}')">Delete</button>
|
||||
</div>
|
||||
<tr id="api-key-item-${tokenObj.token}">
|
||||
<td><b>${tokenObj.name}</b></td>
|
||||
<td id="api-key-${tokenObj.token}">${tokenObj.token}</td>
|
||||
<td>
|
||||
<img id="api-key-copy-button-${tokenObj.token}" onclick="copyAPIKey('${tokenObj.token}')" class="configured-icon enabled" src="/static/assets/icons/copy-solid.svg" alt="Copy API Key">
|
||||
<img id="api-key-delete-button-${tokenObj.token}" onclick="deleteAPIKey('${tokenObj.token}')" class="configured-icon enabled" src="/static/assets/icons/trash-solid.svg" alt="Delete API Key">
|
||||
</td>
|
||||
</tr>
|
||||
`;
|
||||
});
|
||||
}
|
||||
|
||||
function copyAPIKey(token) {
|
||||
// Copy API key to clipboard
|
||||
navigator.clipboard.writeText(token);
|
||||
// Flash the API key copied message
|
||||
const copyApiKeyButton = document.getElementById(`api-key-${token}`);
|
||||
original_html = copyApiKeyButton.innerHTML
|
||||
setTimeout(function() {
|
||||
copyApiKeyButton.innerHTML = "✅ Copied to your clipboard!";
|
||||
setTimeout(function() {
|
||||
copyApiKeyButton.innerHTML = original_html;
|
||||
}, 1000);
|
||||
}, 100);
|
||||
}
|
||||
|
||||
function deleteAPIKey(token) {
|
||||
const apiKeyList = document.getElementById("api-key-list");
|
||||
fetch(`/auth/token?token=${token}`, {
|
||||
|
@ -548,10 +588,14 @@
|
|||
.then(tokens => {
|
||||
apiKeyList.innerHTML = tokens.map(tokenObj =>
|
||||
`
|
||||
<div id="api-key-item-${tokenObj.token}" class="api-key-item">
|
||||
<span class="api-key">${tokenObj.token}</span>
|
||||
<button class="delete-api-key" onclick="deleteAPIKey('${tokenObj.token}')">Delete</button>
|
||||
</div>
|
||||
<tr id="api-key-item-${tokenObj.token}">
|
||||
<td><b>${tokenObj.name}</b></td>
|
||||
<td id="api-key-${tokenObj.token}">${tokenObj.token}</td>
|
||||
<td>
|
||||
<img id="api-key-copy-button-${tokenObj.token}" onclick="copyAPIKey('${tokenObj.token}')" class="configured-icon enabled" src="/static/assets/icons/copy-solid.svg" alt="Copy API Key">
|
||||
<img id="api-key-delete-button-${tokenObj.token}" onclick="deleteAPIKey('${tokenObj.token}')" class="configured-icon enabled" src="/static/assets/icons/trash-solid.svg" alt="Delete API Key">
|
||||
</td>
|
||||
</tr>
|
||||
`)
|
||||
.join("");
|
||||
});
|
||||
|
|
|
@ -94,6 +94,15 @@
|
|||
}).join("\n");
|
||||
}
|
||||
|
||||
function render_xml(query, data) {
|
||||
return data.map(function (item) {
|
||||
return `<div class="results-xml">` +
|
||||
`<b><a href="${item.additional.file}">${item.additional.heading}</a></b>` +
|
||||
`<xml>${item.entry}</xml>` +
|
||||
`</div>`
|
||||
}).join("\n");
|
||||
}
|
||||
|
||||
function render_multiple(query, data, type) {
|
||||
let html = "";
|
||||
data.forEach(item => {
|
||||
|
@ -113,6 +122,8 @@
|
|||
html += `<div class="results-notion">` + `<b><a href="${item.additional.file}">${item.additional.heading}</a></b>` + `<p>${item.entry}</p>` + `</div>`;
|
||||
} else if (item.additional.file.endsWith(".html")) {
|
||||
html += render_html(query, [item]);
|
||||
} else if (item.additional.file.endsWith(".xml")) {
|
||||
html += render_xml(query, [item])
|
||||
} else {
|
||||
html += `<div class="results-plugin">` + `<b><a href="${item.additional.file}">${item.additional.heading}</a></b>` + `<p>${item.entry}</p>` + `</div>`;
|
||||
}
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0 maximum-scale=1.0">
|
||||
<title>Khoj - Search</title>
|
||||
<title>Khoj - Login</title>
|
||||
|
||||
<link rel="icon" type="image/png" sizes="128x128" href="/static/assets/icons/favicon-128x128.png">
|
||||
<link rel="manifest" href="/static/khoj.webmanifest">
|
||||
|
|
|
@ -2,12 +2,14 @@
|
|||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
|
||||
# Internal Packages
|
||||
from khoj.processor.text_to_jsonl import TextEmbeddings
|
||||
from khoj.utils.helpers import timer
|
||||
from khoj.utils.rawconfig import Entry, TextContentConfig
|
||||
from database.models import Embeddings, KhojUser, LocalPlaintextConfig
|
||||
from khoj.utils.rawconfig import Entry
|
||||
from database.models import Embeddings, KhojUser
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -28,6 +30,19 @@ class PlaintextToJsonl(TextEmbeddings):
|
|||
else:
|
||||
deletion_file_names = None
|
||||
|
||||
with timer("Scrub plaintext files and extract text", logger):
|
||||
for file in files:
|
||||
try:
|
||||
plaintext_content = files[file]
|
||||
if file.endswith(("html", "htm", "xml")):
|
||||
plaintext_content = PlaintextToJsonl.extract_html_content(
|
||||
plaintext_content, file.split(".")[-1]
|
||||
)
|
||||
files[file] = plaintext_content
|
||||
except Exception as e:
|
||||
logger.warning(f"Unable to read file: {file} as plaintext. Skipping file.")
|
||||
logger.warning(e, exc_info=True)
|
||||
|
||||
# Extract Entries from specified plaintext files
|
||||
with timer("Parse entries from plaintext files", logger):
|
||||
current_entries = PlaintextToJsonl.convert_plaintext_entries_to_maps(files)
|
||||
|
@ -50,6 +65,15 @@ class PlaintextToJsonl(TextEmbeddings):
|
|||
|
||||
return num_new_embeddings, num_deleted_embeddings
|
||||
|
||||
@staticmethod
|
||||
def extract_html_content(markup_content: str, markup_type: str):
|
||||
"Extract content from HTML"
|
||||
if markup_type == "xml":
|
||||
soup = BeautifulSoup(markup_content, "xml")
|
||||
else:
|
||||
soup = BeautifulSoup(markup_content, "html.parser")
|
||||
return soup.get_text(strip=True, separator="\n")
|
||||
|
||||
@staticmethod
|
||||
def convert_plaintext_entries_to_maps(entry_to_file_map: dict) -> List[Entry]:
|
||||
"Convert each plaintext entries into a dictionary"
|
||||
|
|
|
@ -5,7 +5,7 @@ import logging
|
|||
import uuid
|
||||
from tqdm import tqdm
|
||||
from typing import Callable, List, Tuple, Set, Any
|
||||
from khoj.utils.helpers import timer
|
||||
from khoj.utils.helpers import timer, batcher
|
||||
|
||||
|
||||
# Internal Packages
|
||||
|
@ -93,7 +93,7 @@ class TextEmbeddings(ABC):
|
|||
num_deleted_embeddings = 0
|
||||
with timer("Preparing dataset for regeneration", logger):
|
||||
if regenerate:
|
||||
logger.info(f"Deleting all embeddings for file type {file_type}")
|
||||
logger.debug(f"Deleting all embeddings for file type {file_type}")
|
||||
num_deleted_embeddings = EmbeddingsAdapters.delete_all_embeddings(user, file_type)
|
||||
|
||||
num_new_embeddings = 0
|
||||
|
@ -106,48 +106,54 @@ class TextEmbeddings(ABC):
|
|||
)
|
||||
existing_entry_hashes = set([entry.hashed_value for entry in existing_entries])
|
||||
hashes_to_process = hashes_for_file - existing_entry_hashes
|
||||
# for hashed_val in hashes_for_file:
|
||||
# if not EmbeddingsAdapters.does_embedding_exist(user, hashed_val):
|
||||
# hashes_to_process.add(hashed_val)
|
||||
|
||||
entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
|
||||
data_to_embed = [getattr(entry, key) for entry in entries_to_process]
|
||||
embeddings = self.embeddings_model.embed_documents(data_to_embed)
|
||||
|
||||
with timer("Update the database with new vector embeddings", logger):
|
||||
embeddings_to_create = []
|
||||
for hashed_val, embedding in zip(hashes_to_process, embeddings):
|
||||
entry = hash_to_current_entries[hashed_val]
|
||||
embeddings_to_create.append(
|
||||
Embeddings(
|
||||
user=user,
|
||||
embeddings=embedding,
|
||||
raw=entry.raw,
|
||||
compiled=entry.compiled,
|
||||
heading=entry.heading,
|
||||
file_path=entry.file,
|
||||
file_type=file_type,
|
||||
hashed_value=hashed_val,
|
||||
corpus_id=entry.corpus_id,
|
||||
)
|
||||
)
|
||||
new_embeddings = Embeddings.objects.bulk_create(embeddings_to_create)
|
||||
num_new_embeddings += len(new_embeddings)
|
||||
num_items = len(hashes_to_process)
|
||||
assert num_items == len(embeddings)
|
||||
batch_size = min(200, num_items)
|
||||
entry_batches = zip(hashes_to_process, embeddings)
|
||||
|
||||
dates_to_create = []
|
||||
with timer("Create new date associations for new embeddings", logger):
|
||||
for embedding in new_embeddings:
|
||||
dates = self.date_filter.extract_dates(embedding.raw)
|
||||
for date in dates:
|
||||
dates_to_create.append(
|
||||
EmbeddingsDates(
|
||||
date=date,
|
||||
embeddings=embedding,
|
||||
)
|
||||
for entry_batch in tqdm(
|
||||
batcher(entry_batches, batch_size), desc="Processing embeddings in batches"
|
||||
):
|
||||
batch_embeddings_to_create = []
|
||||
for entry_hash, embedding in entry_batch:
|
||||
entry = hash_to_current_entries[entry_hash]
|
||||
batch_embeddings_to_create.append(
|
||||
Embeddings(
|
||||
user=user,
|
||||
embeddings=embedding,
|
||||
raw=entry.raw,
|
||||
compiled=entry.compiled,
|
||||
heading=entry.heading[:1000], # Truncate to max chars of field allowed
|
||||
file_path=entry.file,
|
||||
file_type=file_type,
|
||||
hashed_value=entry_hash,
|
||||
corpus_id=entry.corpus_id,
|
||||
)
|
||||
new_dates = EmbeddingsDates.objects.bulk_create(dates_to_create)
|
||||
if len(new_dates) > 0:
|
||||
logger.info(f"Created {len(new_dates)} new date entries")
|
||||
)
|
||||
new_embeddings = Embeddings.objects.bulk_create(batch_embeddings_to_create)
|
||||
logger.debug(f"Created {len(new_embeddings)} new embeddings")
|
||||
num_new_embeddings += len(new_embeddings)
|
||||
|
||||
dates_to_create = []
|
||||
with timer("Create new date associations for new embeddings", logger):
|
||||
for embedding in new_embeddings:
|
||||
dates = self.date_filter.extract_dates(embedding.raw)
|
||||
for date in dates:
|
||||
dates_to_create.append(
|
||||
EmbeddingsDates(
|
||||
date=date,
|
||||
embeddings=embedding,
|
||||
)
|
||||
)
|
||||
new_dates = EmbeddingsDates.objects.bulk_create(dates_to_create)
|
||||
if len(new_dates) > 0:
|
||||
logger.debug(f"Created {len(new_dates)} new date entries")
|
||||
|
||||
with timer("Identify hashes for removed entries", logger):
|
||||
for file in hashes_by_file:
|
||||
|
|
|
@ -723,7 +723,7 @@ async def extract_references_and_questions(
|
|||
# Infer search queries from user message
|
||||
with timer("Extracting search queries took", logger):
|
||||
# If we've reached here, either the user has enabled offline chat or the openai model is enabled.
|
||||
if await ConversationAdapters.has_offline_chat(user):
|
||||
if await ConversationAdapters.ahas_offline_chat(user):
|
||||
using_offline_chat = True
|
||||
offline_chat = await ConversationAdapters.get_offline_chat(user)
|
||||
chat_model = offline_chat.chat_model
|
||||
|
|
|
@ -46,7 +46,6 @@ async def login(request: Request):
|
|||
return await oauth.google.authorize_redirect(request, redirect_uri)
|
||||
|
||||
|
||||
@auth_router.post("/redirect")
|
||||
@auth_router.post("/token")
|
||||
@requires(["authenticated"], redirect="login_page")
|
||||
async def generate_token(request: Request, token_name: Optional[str] = None) -> str:
|
||||
|
|
|
@ -31,7 +31,7 @@ def perform_chat_checks(user: KhojUser):
|
|||
|
||||
|
||||
async def is_ready_to_chat(user: KhojUser):
|
||||
has_offline_config = await ConversationAdapters.has_offline_chat(user=user)
|
||||
has_offline_config = await ConversationAdapters.ahas_offline_chat(user=user)
|
||||
has_openai_config = await ConversationAdapters.has_openai_chat(user=user)
|
||||
|
||||
if has_offline_config:
|
||||
|
|
|
@ -156,18 +156,17 @@ async def update(
|
|||
host=host,
|
||||
)
|
||||
|
||||
logger.info(f"📪 Content index updated via API call by {client} client")
|
||||
|
||||
return Response(content="OK", status_code=200)
|
||||
|
||||
|
||||
def configure_search(search_models: SearchModels, search_config: Optional[SearchConfig]) -> Optional[SearchModels]:
|
||||
# Run Validation Checks
|
||||
if search_config is None:
|
||||
logger.warning("🚨 No Search configuration available.")
|
||||
return None
|
||||
if search_models is None:
|
||||
search_models = SearchModels()
|
||||
|
||||
if search_config.image:
|
||||
if search_config and search_config.image:
|
||||
logger.info("🔍 🌄 Setting up image search model")
|
||||
search_models.image_search = image_search.initialize_model(search_config.image)
|
||||
|
||||
|
|
|
@ -127,14 +127,18 @@ class DateFilter(BaseFilter):
|
|||
clean_date_str = re.sub("|".join(future_strings), "", date_str)
|
||||
|
||||
# parse date passed in query date filter
|
||||
parsed_date = dtparse.parse(
|
||||
clean_date_str,
|
||||
settings={
|
||||
"RELATIVE_BASE": relative_base or datetime.now(),
|
||||
"PREFER_DAY_OF_MONTH": "first",
|
||||
"PREFER_DATES_FROM": prefer_dates_from,
|
||||
},
|
||||
)
|
||||
try:
|
||||
parsed_date = dtparse.parse(
|
||||
clean_date_str,
|
||||
settings={
|
||||
"RELATIVE_BASE": relative_base or datetime.now(),
|
||||
"PREFER_DAY_OF_MONTH": "first",
|
||||
"PREFER_DATES_FROM": prefer_dates_from,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse date string: {date_str} with error: {e}")
|
||||
return None
|
||||
|
||||
if parsed_date is None:
|
||||
return None
|
||||
|
|
|
@ -5,6 +5,7 @@ import datetime
|
|||
from enum import Enum
|
||||
from importlib import import_module
|
||||
from importlib.metadata import version
|
||||
from itertools import islice
|
||||
import logging
|
||||
from os import path
|
||||
import os
|
||||
|
@ -305,3 +306,13 @@ def generate_random_name():
|
|||
name = f"{adjective} {noun}"
|
||||
|
||||
return name
|
||||
|
||||
|
||||
def batcher(iterable, max_n):
|
||||
"Split an iterable into chunks of size max_n"
|
||||
it = iter(iterable)
|
||||
while True:
|
||||
chunk = list(islice(it, max_n))
|
||||
if not chunk:
|
||||
return
|
||||
yield (x for x in chunk if x is not None)
|
||||
|
|
|
@ -22,9 +22,7 @@ logger = logging.getLogger(__name__)
|
|||
# Test
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db
|
||||
def test_text_search_setup_with_missing_file_raises_error(
|
||||
org_config_with_only_new_file: LocalOrgConfig, search_config: SearchConfig
|
||||
):
|
||||
def test_text_search_setup_with_missing_file_raises_error(org_config_with_only_new_file: LocalOrgConfig):
|
||||
# Arrange
|
||||
# Ensure file mentioned in org.input-files is missing
|
||||
single_new_file = Path(org_config_with_only_new_file.input_files[0])
|
||||
|
@ -70,22 +68,39 @@ def test_text_search_setup_with_empty_file_raises_error(
|
|||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
|
||||
|
||||
assert "Created 0 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
|
||||
assert "Created 0 new embeddings. Deleted 3 embeddings for user " in caplog.records[-1].message
|
||||
verify_embeddings(0, default_user)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db
|
||||
def test_text_search_setup(content_config, default_user: KhojUser, caplog):
|
||||
def test_text_indexer_deletes_embedding_before_regenerate(
|
||||
content_config: ContentConfig, default_user: KhojUser, caplog
|
||||
):
|
||||
# Arrange
|
||||
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
|
||||
data = get_org_files(org_config)
|
||||
with caplog.at_level(logging.INFO):
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
|
||||
|
||||
# Assert
|
||||
assert "Deleting all embeddings for file type org" in caplog.records[1].message
|
||||
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
|
||||
assert "Deleting all embeddings for file type org" in caplog.text
|
||||
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[-1].message
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db
|
||||
def test_text_search_setup_batch_processes(content_config: ContentConfig, default_user: KhojUser, caplog):
|
||||
# Arrange
|
||||
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
|
||||
data = get_org_files(org_config)
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
|
||||
|
||||
# Assert
|
||||
assert "Created 4 new embeddings" in caplog.text
|
||||
assert "Created 6 new embeddings" in caplog.text
|
||||
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[-1].message
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
|
@ -97,13 +112,13 @@ def test_text_index_same_if_content_unchanged(content_config: ContentConfig, def
|
|||
|
||||
# Act
|
||||
# Generate initial notes embeddings during asymmetric setup
|
||||
with caplog.at_level(logging.INFO):
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
|
||||
initial_logs = caplog.text
|
||||
caplog.clear() # Clear logs
|
||||
|
||||
# Run asymmetric setup again with no changes to data source. Ensure index is not updated
|
||||
with caplog.at_level(logging.INFO):
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)
|
||||
final_logs = caplog.text
|
||||
|
||||
|
@ -175,12 +190,10 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: LocalOrgCon
|
|||
|
||||
# Assert
|
||||
# verify newly added org-mode entry is split by max tokens
|
||||
record = caplog.records[1]
|
||||
assert "Created 2 new embeddings. Deleted 0 embeddings for user " in record.message
|
||||
assert "Created 2 new embeddings. Deleted 0 embeddings for user " in caplog.records[-1].message
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
# @pytest.mark.skip(reason="Flaky due to compressed_jsonl file being rewritten by other tests")
|
||||
@pytest.mark.django_db
|
||||
def test_entry_chunking_by_max_tokens_not_full_corpus(
|
||||
org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog
|
||||
|
@ -232,11 +245,9 @@ conda activate khoj
|
|||
user=default_user,
|
||||
)
|
||||
|
||||
record = caplog.records[1]
|
||||
|
||||
# Assert
|
||||
# verify newly added org-mode entry is split by max tokens
|
||||
assert "Created 2 new embeddings. Deleted 0 embeddings for user " in record.message
|
||||
assert "Created 2 new embeddings. Deleted 0 embeddings for user " in caplog.records[-1].message
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
|
@ -251,7 +262,7 @@ def test_regenerate_index_with_new_entry(
|
|||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
|
||||
|
||||
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
|
||||
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[-1].message
|
||||
|
||||
# append org-mode entry to first org input file in config
|
||||
org_config.input_files = [f"{new_org_file}"]
|
||||
|
@ -286,20 +297,23 @@ def test_update_index_with_duplicate_entries_in_stable_order(
|
|||
data = get_org_files(org_config_with_only_new_file)
|
||||
|
||||
# Act
|
||||
# load embeddings, entries, notes model after adding new org-mode file
|
||||
# generate embeddings, entries, notes model from scratch after adding new org-mode file
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
|
||||
initial_logs = caplog.text
|
||||
caplog.clear() # Clear logs
|
||||
|
||||
data = get_org_files(org_config_with_only_new_file)
|
||||
|
||||
# update embeddings, entries, notes model after adding new org-mode file
|
||||
# update embeddings, entries, notes model with no new changes
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)
|
||||
final_logs = caplog.text
|
||||
|
||||
# Assert
|
||||
# verify only 1 entry added even if there are multiple duplicate entries
|
||||
assert "Created 1 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
|
||||
assert "Created 0 new embeddings. Deleted 0 embeddings for user " in caplog.records[4].message
|
||||
assert "Created 1 new embeddings. Deleted 3 embeddings for user " in initial_logs
|
||||
assert "Created 0 new embeddings. Deleted 0 embeddings for user " in final_logs
|
||||
|
||||
verify_embeddings(1, default_user)
|
||||
|
||||
|
@ -319,6 +333,8 @@ def test_update_index_with_deleted_entry(org_config_with_only_new_file: LocalOrg
|
|||
# load embeddings, entries, notes model after adding new org file with 2 entries
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
|
||||
initial_logs = caplog.text
|
||||
caplog.clear() # Clear logs
|
||||
|
||||
# update embeddings, entries, notes model after removing an entry from the org file
|
||||
with open(new_file_to_index, "w") as f:
|
||||
|
@ -329,11 +345,12 @@ def test_update_index_with_deleted_entry(org_config_with_only_new_file: LocalOrg
|
|||
# Act
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)
|
||||
final_logs = caplog.text
|
||||
|
||||
# Assert
|
||||
# verify only 1 entry added even if there are multiple duplicate entries
|
||||
assert "Created 2 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
|
||||
assert "Created 0 new embeddings. Deleted 1 embeddings for user " in caplog.records[4].message
|
||||
assert "Created 2 new embeddings. Deleted 3 embeddings for user " in initial_logs
|
||||
assert "Created 0 new embeddings. Deleted 1 embeddings for user " in final_logs
|
||||
|
||||
verify_embeddings(1, default_user)
|
||||
|
||||
|
@ -346,6 +363,8 @@ def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file
|
|||
data = get_org_files(org_config)
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
|
||||
initial_logs = caplog.text
|
||||
caplog.clear() # Clear logs
|
||||
|
||||
# append org-mode entry to first org input file in config
|
||||
with open(new_org_file, "w") as f:
|
||||
|
@ -358,10 +377,11 @@ def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file
|
|||
# update embeddings, entries with the newly added note
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)
|
||||
final_logs = caplog.text
|
||||
|
||||
# Assert
|
||||
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
|
||||
assert "Created 1 new embeddings. Deleted 0 embeddings for user " in caplog.records[4].message
|
||||
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in initial_logs
|
||||
assert "Created 1 new embeddings. Deleted 0 embeddings for user " in final_logs
|
||||
|
||||
verify_embeddings(11, default_user)
|
||||
|
||||
|
|
Loading…
Reference in a new issue