[Multi-User Part 6]: Address small bugs and upstream PR comments (#518)
- 08654163cb: Add better parsing for XML files
- f3acfac7fb: Add a try/catch around the dateparser in order to avoid internal server errors in app
- 7d43cd62c0: Chunk embeddings generation in order to avoid large memory load
- e02d751eb3: Addresses comments from PR #498
- a3f393edb4: Addresses comments from PR #503
- 66eb078286: Addresses comments from PR #511
- Address various items in https://github.com/khoj-ai/khoj/issues/527

Parent: 5f3f6b7c61
Commit: 54a387326c

22 changed files with 264 additions and 114 deletions
@@ -71,6 +71,7 @@ dependencies = [
     "google-auth == 2.23.3",
     "python-multipart == 0.0.6",
     "gunicorn == 21.2.0",
+    "lxml == 4.9.3",
 ]

 dynamic = ["version"]

@@ -1,6 +1,9 @@
 import secrets
 from typing import Type, TypeVar, List
 from datetime import date
+import secrets
+from typing import Type, TypeVar, List
+from datetime import date

 from django.db import models
 from django.contrib.sessions.backends.db import SessionStore

@@ -8,6 +11,10 @@ from pgvector.django import CosineDistance
 from django.db.models.manager import BaseManager
 from django.db.models import Q
 from torch import Tensor
+from pgvector.django import CosineDistance
+from django.db.models.manager import BaseManager
+from django.db.models import Q
+from torch import Tensor

 # Import sync_to_async from Django Channels
 from asgiref.sync import sync_to_async

@@ -58,7 +65,7 @@ async def set_notion_config(token: str, user: KhojUser):
 async def create_khoj_token(user: KhojUser, name=None):
     "Create Khoj API key for user"
     token = f"kk-{secrets.token_urlsafe(32)}"
-    name = name or f"{generate_random_name().title()}'s Secret Key"
+    name = name or f"{generate_random_name()}"
     api_config = await KhojApiUser.objects.acreate(token=token, user=user, name=name)
     await api_config.asave()
     return api_config

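The token helper is async end to end. A minimal usage sketch (the calling context and the name argument are assumptions, not part of the commit):

    # Hedged sketch: assumes an async context and an existing KhojUser.
    # The generated key is stored on the returned KhojApiUser record.
    api_config = await create_khoj_token(user, name="My CLI")
    print(api_config.token)  # "kk-..." from secrets.token_urlsafe(32)
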
@@ -123,15 +130,11 @@ def get_all_users() -> BaseManager[KhojUser]:

 def get_user_github_config(user: KhojUser):
     config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first()
-    if not config:
-        return None
     return config


 def get_user_notion_config(user: KhojUser):
     config = NotionConfig.objects.filter(user=user).first()
-    if not config:
-        return None
     return config

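The dropped `if not config` guards were redundant: Django's QuerySet.first() already returns None when no row matches, so the filtered config can be returned directly.
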
@@ -240,13 +243,10 @@ class ConversationAdapters:
     @staticmethod
     def get_enabled_conversation_settings(user: KhojUser):
         openai_config = ConversationAdapters.get_openai_conversation_config(user)
-        offline_chat_config = ConversationAdapters.get_offline_chat_conversation_config(user)

         return {
             "openai": True if openai_config is not None else False,
-            "offline_chat": True
-            if (offline_chat_config is not None and offline_chat_config.enable_offline_chat)
-            else False,
+            "offline_chat": ConversationAdapters.has_offline_chat(user),
         }

     @staticmethod

@@ -264,7 +264,11 @@ class ConversationAdapters:
         OfflineChatProcessorConversationConfig.objects.filter(user=user).delete()

     @staticmethod
-    async def has_offline_chat(user: KhojUser):
+    def has_offline_chat(user: KhojUser):
+        return OfflineChatProcessorConversationConfig.objects.filter(user=user, enable_offline_chat=True).exists()
+
+    @staticmethod
+    async def ahas_offline_chat(user: KhojUser):
         return await OfflineChatProcessorConversationConfig.objects.filter(
             user=user, enable_offline_chat=True
         ).aexists()

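Splitting the helper into `has_offline_chat` and `ahas_offline_chat` gives synchronous and asynchronous callers each a natural entry point. A rough sketch of the two calling conventions (the call sites are assumptions):

    # From synchronous code, e.g. another adapter method:
    enabled = ConversationAdapters.has_offline_chat(user)

    # From asynchronous code, e.g. a FastAPI route, as done in
    # extract_references_and_questions further down this diff:
    enabled = await ConversationAdapters.ahas_offline_chat(user)
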
@@ -2,7 +2,7 @@
 <head>
     <meta charset="utf-8">
     <meta name="viewport" content="width=device-width, initial-scale=1.0 maximum-scale=1.0">
-    <title>Khoj - Search</title>
+    <title>Khoj - Settings</title>

    <link rel="icon" type="image/png" sizes="128x128" href="./assets/icons/favicon-128x128.png">
    <link rel="manifest" href="./khoj.webmanifest">

@@ -94,6 +94,15 @@
        }).join("\n");
    }

+    function render_xml(query, data) {
+        return data.map(function (item) {
+            return `<div class="results-xml">` +
+                `<b><a href="${item.additional.file}">${item.additional.heading}</a></b>` +
+                `<xml>${item.entry}</xml>` +
+                `</div>`
+        }).join("\n");
+    }
+
    function render_multiple(query, data, type) {
        let html = "";
        data.forEach(item => {

@@ -113,6 +122,8 @@
                html += `<div class="results-notion">` + `<b><a href="${item.additional.file}">${item.additional.heading}</a></b>` + `<p>${item.entry}</p>` + `</div>`;
            } else if (item.additional.file.endsWith(".html")) {
                html += render_html(query, [item]);
+            } else if (item.additional.file.endsWith(".xml")) {
+                html += render_xml(query, [item])
            } else {
                html += `<div class="results-plugin">` + `<b><a href="${item.additional.file}">${item.additional.heading}</a></b>` + `<p>${item.entry}</p>` + `</div>`;
            }

@@ -267,6 +267,12 @@ async function getFolders () {
 }

 async function setURL (event, url) {
+    // Sanitize the URL. Remove trailing slash if present. Add http:// if not present.
+    url = url.replace(/\/$/, "");
+    if (!url.match(/^[a-zA-Z]+:\/\//)) {
+        url = `http://${url}`;
+    }
+
     store.set('hostURL', url);
     return store.get('hostURL');
 }

@@ -174,6 +174,7 @@ urlInput.addEventListener('blur', async () => {
        new URL(urlInputValue);
    } catch (e) {
        console.log(e);
+        alert('Please enter a valid URL');
        return;
    }

@@ -25,7 +25,6 @@ from khoj.utils import constants, state
 from khoj.utils.config import (
     SearchType,
 )
-from khoj.utils.helpers import merge_dicts
 from khoj.utils.fs_syncer import collect_files
 from khoj.utils.rawconfig import FullConfig
 from khoj.routers.indexer import configure_content, load_content, configure_search

@@ -83,12 +82,6 @@ class UserAuthenticationBackend(AuthenticationBackend):


 def initialize_server(config: Optional[FullConfig]):
-    if config is None:
-        logger.warning(
-            f"🚨 Khoj is not configured.\nConfigure it via http://{state.host}:{state.port}/config, plugins or by editing {state.config_file}."
-        )
-        return None
-
     try:
         configure_server(config, init=True)
     except Exception as e:

@@ -190,7 +183,7 @@ def configure_search_types(config: FullConfig):
     core_search_types = {e.name: e.value for e in SearchType}

     # Dynamically generate search type enum by merging core search types with configured plugin search types
-    return Enum("SearchType", merge_dicts(core_search_types, {}))
+    return Enum("SearchType", core_search_types)


 @schedule.repeat(schedule.every(59).minutes)

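Merging `core_search_types` with an empty dict was a no-op, so passing the dict straight to `Enum` is behaviorally identical. A standalone illustration of the functional Enum pattern (the member names here are illustrative, not the project's full list):

    from enum import Enum

    core_search_types = {"Org": "org", "Markdown": "markdown"}
    SearchType = Enum("SearchType", core_search_types)

    assert SearchType.Org.value == "org"
    assert SearchType("markdown") is SearchType.Markdown
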
src/khoj/interface/web/assets/icons/copy-solid.svg (new file)
@@ -0,0 +1 @@
+<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 448 512"><!--! Font Awesome Pro 6.4.2 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license (Commercial License) Copyright 2023 Fonticons, Inc. --><path d="M208 0H332.1c12.7 0 24.9 5.1 33.9 14.1l67.9 67.9c9 9 14.1 21.2 14.1 33.9V336c0 26.5-21.5 48-48 48H208c-26.5 0-48-21.5-48-48V48c0-26.5 21.5-48 48-48zM48 128h80v64H64V448H256V416h64v48c0 26.5-21.5 48-48 48H48c-26.5 0-48-21.5-48-48V176c0-26.5 21.5-48 48-48z"/></svg>

src/khoj/interface/web/assets/icons/trash-solid.svg (new file)
@@ -0,0 +1 @@
+<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 448 512"><!--! Font Awesome Pro 6.4.2 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license (Commercial License) Copyright 2023 Fonticons, Inc. --><path d="M135.2 17.7L128 32H32C14.3 32 0 46.3 0 64S14.3 96 32 96H416c17.7 0 32-14.3 32-32s-14.3-32-32-32H320l-7.2-14.3C307.4 6.8 296.3 0 284.2 0H163.8c-12.1 0-23.2 6.8-28.6 17.7zM416 128H32L53.2 467c1.6 25.3 22.6 45 47.9 45H346.9c25.3 0 46.3-19.7 47.9-45L416 128z"/></svg>

@@ -52,10 +52,24 @@
     justify-self: center;
 }

-div.section.general-settings {
-    justify-self: center;
-}
+.api-settings {
+    display: grid;
+    grid-template-columns: 1fr;
+    grid-template-rows: 1fr 1fr auto;
+    justify-items: start;
+    gap: 8px;
+    padding: 24px 24px;
+    background: white;
+    border: 1px solid rgb(229, 229, 229);
+    border-radius: 4px;
+    box-shadow: 0px 1px 3px 0px rgba(0,0,0,0.1),0px 1px 2px -1px rgba(0,0,0,0.1);
+}
+#api-settings-card-description {
+    margin: 8px 0 0 0;
+}
+#api-settings-keys-table {
+    margin-bottom: 16px;
+}

 div.instructions {
     font-size: large;

@@ -287,11 +287,33 @@
            </div>
        </div>
    </div>
-    <div class="section general-settings">
-        <div id="khoj-api-key-section" title="Use Khoj cloud with your Khoj API Key">
-            <button id="generate-api-key" onclick="generateAPIKey()">Generate API Key</button>
-            <div id="api-key-list"></div>
+    <div class="section">
+        <div class="api-settings">
+            <div class="card-title-row">
+                <img class="card-icon" src="/static/assets/icons/key.svg" alt="API Key">
+                <h3 class="card-title">API Keys</h3>
+            </div>
+            <div class="card-description-row">
+                <p id="api-settings-card-description" class="card-description">Manage access to your Khoj from client apps</p>
+            </div>
+            <table id="api-settings-keys-table">
+                <thead>
+                    <tr>
+                        <th scope="col">Name</th>
+                        <th scope="col">Key</th>
+                        <th scope="col">Actions</th>
+                    </tr>
+                </thead>
+                <tbody id="api-key-list"></tbody>
+            </table>
+            <div class="card-action-row">
+                <button class="card-button happy" id="generate-api-key" onclick="generateAPIKey()">
+                    Generate API Key
+                </button>
+            </div>
        </div>
    </div>
+    <div class="section general-settings">
        <div id="results-count" title="Number of items to show in search and use for chat response">
            <label for="results-count-slider">Results Count: <span id="results-count-value">5</span></label>
            <input type="range" id="results-count-slider" name="results-count-slider" min="1" max="10" step="1" value="5">

@@ -520,14 +542,32 @@
            .then(response => response.json())
            .then(tokenObj => {
                apiKeyList.innerHTML += `
-                    <div id="api-key-item-${tokenObj.token}" class="api-key-item">
-                        <span class="api-key">${tokenObj.token}</span>
-                        <button class="delete-api-key" onclick="deleteAPIKey('${tokenObj.token}')">Delete</button>
-                    </div>
+                    <tr id="api-key-item-${tokenObj.token}">
+                        <td><b>${tokenObj.name}</b></td>
+                        <td id="api-key-${tokenObj.token}">${tokenObj.token}</td>
+                        <td>
+                            <img id="api-key-copy-button-${tokenObj.token}" onclick="copyAPIKey('${tokenObj.token}')" class="configured-icon enabled" src="/static/assets/icons/copy-solid.svg" alt="Copy API Key">
+                            <img id="api-key-delete-button-${tokenObj.token}" onclick="deleteAPIKey('${tokenObj.token}')" class="configured-icon enabled" src="/static/assets/icons/trash-solid.svg" alt="Delete API Key">
+                        </td>
+                    </tr>
                `;
            });
        }

+        function copyAPIKey(token) {
+            // Copy API key to clipboard
+            navigator.clipboard.writeText(token);
+            // Flash the API key copied message
+            const copyApiKeyButton = document.getElementById(`api-key-${token}`);
+            original_html = copyApiKeyButton.innerHTML
+            setTimeout(function() {
+                copyApiKeyButton.innerHTML = "✅ Copied to your clipboard!";
+                setTimeout(function() {
+                    copyApiKeyButton.innerHTML = original_html;
+                }, 1000);
+            }, 100);
+        }
+
        function deleteAPIKey(token) {
            const apiKeyList = document.getElementById("api-key-list");
            fetch(`/auth/token?token=${token}`, {

@@ -548,10 +588,14 @@
            .then(tokens => {
                apiKeyList.innerHTML = tokens.map(tokenObj =>
                    `
-                    <div id="api-key-item-${tokenObj.token}" class="api-key-item">
-                        <span class="api-key">${tokenObj.token}</span>
-                        <button class="delete-api-key" onclick="deleteAPIKey('${tokenObj.token}')">Delete</button>
-                    </div>
+                    <tr id="api-key-item-${tokenObj.token}">
+                        <td><b>${tokenObj.name}</b></td>
+                        <td id="api-key-${tokenObj.token}">${tokenObj.token}</td>
+                        <td>
+                            <img id="api-key-copy-button-${tokenObj.token}" onclick="copyAPIKey('${tokenObj.token}')" class="configured-icon enabled" src="/static/assets/icons/copy-solid.svg" alt="Copy API Key">
+                            <img id="api-key-delete-button-${tokenObj.token}" onclick="deleteAPIKey('${tokenObj.token}')" class="configured-icon enabled" src="/static/assets/icons/trash-solid.svg" alt="Delete API Key">
+                        </td>
+                    </tr>
                    `)
                    .join("");
            });

@@ -94,6 +94,15 @@
        }).join("\n");
    }

+    function render_xml(query, data) {
+        return data.map(function (item) {
+            return `<div class="results-xml">` +
+                `<b><a href="${item.additional.file}">${item.additional.heading}</a></b>` +
+                `<xml>${item.entry}</xml>` +
+                `</div>`
+        }).join("\n");
+    }
+
    function render_multiple(query, data, type) {
        let html = "";
        data.forEach(item => {

@@ -113,6 +122,8 @@
                html += `<div class="results-notion">` + `<b><a href="${item.additional.file}">${item.additional.heading}</a></b>` + `<p>${item.entry}</p>` + `</div>`;
            } else if (item.additional.file.endsWith(".html")) {
                html += render_html(query, [item]);
+            } else if (item.additional.file.endsWith(".xml")) {
+                html += render_xml(query, [item])
            } else {
                html += `<div class="results-plugin">` + `<b><a href="${item.additional.file}">${item.additional.heading}</a></b>` + `<p>${item.entry}</p>` + `</div>`;
            }

@@ -2,7 +2,7 @@
 <head>
    <meta charset="utf-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0 maximum-scale=1.0">
-    <title>Khoj - Search</title>
+    <title>Khoj - Login</title>

    <link rel="icon" type="image/png" sizes="128x128" href="/static/assets/icons/favicon-128x128.png">
    <link rel="manifest" href="/static/khoj.webmanifest">

@@ -2,12 +2,14 @@
 import logging
 from pathlib import Path
 from typing import List, Tuple

+from bs4 import BeautifulSoup
+
 # Internal Packages
 from khoj.processor.text_to_jsonl import TextEmbeddings
 from khoj.utils.helpers import timer
-from khoj.utils.rawconfig import Entry, TextContentConfig
-from database.models import Embeddings, KhojUser, LocalPlaintextConfig
+from khoj.utils.rawconfig import Entry
+from database.models import Embeddings, KhojUser


 logger = logging.getLogger(__name__)

@@ -28,6 +30,19 @@ class PlaintextToJsonl(TextEmbeddings):
        else:
            deletion_file_names = None

+        with timer("Scrub plaintext files and extract text", logger):
+            for file in files:
+                try:
+                    plaintext_content = files[file]
+                    if file.endswith(("html", "htm", "xml")):
+                        plaintext_content = PlaintextToJsonl.extract_html_content(
+                            plaintext_content, file.split(".")[-1]
+                        )
+                    files[file] = plaintext_content
+                except Exception as e:
+                    logger.warning(f"Unable to read file: {file} as plaintext. Skipping file.")
+                    logger.warning(e, exc_info=True)
+
        # Extract Entries from specified plaintext files
        with timer("Parse entries from plaintext files", logger):
            current_entries = PlaintextToJsonl.convert_plaintext_entries_to_maps(files)

@@ -50,6 +65,15 @@ class PlaintextToJsonl(TextEmbeddings):

        return num_new_embeddings, num_deleted_embeddings

+    @staticmethod
+    def extract_html_content(markup_content: str, markup_type: str):
+        "Extract content from HTML"
+        if markup_type == "xml":
+            soup = BeautifulSoup(markup_content, "xml")
+        else:
+            soup = BeautifulSoup(markup_content, "html.parser")
+        return soup.get_text(strip=True, separator="\n")
+
    @staticmethod
    def convert_plaintext_entries_to_maps(entry_to_file_map: dict) -> List[Entry]:
        "Convert each plaintext entries into a dictionary"

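BeautifulSoup's "xml" feature is backed by lxml, which is what the `lxml == 4.9.3` dependency added in pyproject.toml above provides. A quick standalone sketch of the extraction behavior (the sample document is an assumption):

    from bs4 import BeautifulSoup  # the "xml" feature requires lxml to be installed

    xml_doc = "<note><heading>Reminder</heading><body>Buy milk</body></note>"
    soup = BeautifulSoup(xml_doc, "xml")
    print(soup.get_text(strip=True, separator="\n"))
    # Reminder
    # Buy milk
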
@@ -5,7 +5,7 @@ import logging
 import uuid
 from tqdm import tqdm
 from typing import Callable, List, Tuple, Set, Any
-from khoj.utils.helpers import timer
+from khoj.utils.helpers import timer, batcher


 # Internal Packages

@@ -93,7 +93,7 @@ class TextEmbeddings(ABC):
        num_deleted_embeddings = 0
        with timer("Preparing dataset for regeneration", logger):
            if regenerate:
-                logger.info(f"Deleting all embeddings for file type {file_type}")
+                logger.debug(f"Deleting all embeddings for file type {file_type}")
                num_deleted_embeddings = EmbeddingsAdapters.delete_all_embeddings(user, file_type)

        num_new_embeddings = 0

@@ -106,48 +106,54 @@ class TextEmbeddings(ABC):
                )
                existing_entry_hashes = set([entry.hashed_value for entry in existing_entries])
                hashes_to_process = hashes_for_file - existing_entry_hashes
-                # for hashed_val in hashes_for_file:
-                #     if not EmbeddingsAdapters.does_embedding_exist(user, hashed_val):
-                #         hashes_to_process.add(hashed_val)

        entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
        data_to_embed = [getattr(entry, key) for entry in entries_to_process]
        embeddings = self.embeddings_model.embed_documents(data_to_embed)

        with timer("Update the database with new vector embeddings", logger):
-            embeddings_to_create = []
-            for hashed_val, embedding in zip(hashes_to_process, embeddings):
-                entry = hash_to_current_entries[hashed_val]
-                embeddings_to_create.append(
-                    Embeddings(
-                        user=user,
-                        embeddings=embedding,
-                        raw=entry.raw,
-                        compiled=entry.compiled,
-                        heading=entry.heading,
-                        file_path=entry.file,
-                        file_type=file_type,
-                        hashed_value=hashed_val,
-                        corpus_id=entry.corpus_id,
-                    )
-                )
-            new_embeddings = Embeddings.objects.bulk_create(embeddings_to_create)
-            num_new_embeddings += len(new_embeddings)
-
-            dates_to_create = []
-            with timer("Create new date associations for new embeddings", logger):
-                for embedding in new_embeddings:
-                    dates = self.date_filter.extract_dates(embedding.raw)
-                    for date in dates:
-                        dates_to_create.append(
-                            EmbeddingsDates(
-                                date=date,
-                                embeddings=embedding,
-                            )
-                        )
-                new_dates = EmbeddingsDates.objects.bulk_create(dates_to_create)
-                if len(new_dates) > 0:
-                    logger.info(f"Created {len(new_dates)} new date entries")
+            num_items = len(hashes_to_process)
+            assert num_items == len(embeddings)
+            batch_size = min(200, num_items)
+            entry_batches = zip(hashes_to_process, embeddings)
+
+            for entry_batch in tqdm(
+                batcher(entry_batches, batch_size), desc="Processing embeddings in batches"
+            ):
+                batch_embeddings_to_create = []
+                for entry_hash, embedding in entry_batch:
+                    entry = hash_to_current_entries[entry_hash]
+                    batch_embeddings_to_create.append(
+                        Embeddings(
+                            user=user,
+                            embeddings=embedding,
+                            raw=entry.raw,
+                            compiled=entry.compiled,
+                            heading=entry.heading[:1000],  # Truncate to max chars of field allowed
+                            file_path=entry.file,
+                            file_type=file_type,
+                            hashed_value=entry_hash,
+                            corpus_id=entry.corpus_id,
+                        )
+                    )
+                new_embeddings = Embeddings.objects.bulk_create(batch_embeddings_to_create)
+                logger.debug(f"Created {len(new_embeddings)} new embeddings")
+                num_new_embeddings += len(new_embeddings)
+
+                dates_to_create = []
+                with timer("Create new date associations for new embeddings", logger):
+                    for embedding in new_embeddings:
+                        dates = self.date_filter.extract_dates(embedding.raw)
+                        for date in dates:
+                            dates_to_create.append(
+                                EmbeddingsDates(
+                                    date=date,
+                                    embeddings=embedding,
+                                )
+                            )
+                    new_dates = EmbeddingsDates.objects.bulk_create(dates_to_create)
+                    if len(new_dates) > 0:
+                        logger.debug(f"Created {len(new_dates)} new date entries")

        with timer("Identify hashes for removed entries", logger):
            for file in hashes_by_file:

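This is the memory fix described in the commit message: instead of materializing one Embeddings object for every new entry before a single bulk_create, the loop now builds and writes at most batch_size (capped at 200) rows per pass through batcher, and creates the matching EmbeddingsDates rows per batch, so peak memory is bounded by the batch size rather than the corpus size.
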
@@ -723,7 +723,7 @@ async def extract_references_and_questions(
    # Infer search queries from user message
    with timer("Extracting search queries took", logger):
        # If we've reached here, either the user has enabled offline chat or the openai model is enabled.
-        if await ConversationAdapters.has_offline_chat(user):
+        if await ConversationAdapters.ahas_offline_chat(user):
            using_offline_chat = True
            offline_chat = await ConversationAdapters.get_offline_chat(user)
            chat_model = offline_chat.chat_model

@@ -46,7 +46,6 @@ async def login(request: Request):
    return await oauth.google.authorize_redirect(request, redirect_uri)


-@auth_router.post("/redirect")
 @auth_router.post("/token")
 @requires(["authenticated"], redirect="login_page")
 async def generate_token(request: Request, token_name: Optional[str] = None) -> str:

@@ -31,7 +31,7 @@ def perform_chat_checks(user: KhojUser):


 async def is_ready_to_chat(user: KhojUser):
-    has_offline_config = await ConversationAdapters.has_offline_chat(user=user)
+    has_offline_config = await ConversationAdapters.ahas_offline_chat(user=user)
     has_openai_config = await ConversationAdapters.has_openai_chat(user=user)

     if has_offline_config:

@@ -156,18 +156,17 @@ async def update(
        host=host,
    )

+    logger.info(f"📪 Content index updated via API call by {client} client")
+
    return Response(content="OK", status_code=200)


 def configure_search(search_models: SearchModels, search_config: Optional[SearchConfig]) -> Optional[SearchModels]:
    # Run Validation Checks
-    if search_config is None:
-        logger.warning("🚨 No Search configuration available.")
-        return None
    if search_models is None:
        search_models = SearchModels()

-    if search_config.image:
+    if search_config and search_config.image:
        logger.info("🔍 🌄 Setting up image search model")
        search_models.image_search = image_search.initialize_model(search_config.image)

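With the early return on a missing search config removed, the truthiness check `if search_config and search_config.image:` simply skips image-model setup when no config is present, so a server without search configuration no longer aborts configuration with a warning.
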
@@ -127,14 +127,18 @@ class DateFilter(BaseFilter):
        clean_date_str = re.sub("|".join(future_strings), "", date_str)

        # parse date passed in query date filter
-        parsed_date = dtparse.parse(
-            clean_date_str,
-            settings={
-                "RELATIVE_BASE": relative_base or datetime.now(),
-                "PREFER_DAY_OF_MONTH": "first",
-                "PREFER_DATES_FROM": prefer_dates_from,
-            },
-        )
+        try:
+            parsed_date = dtparse.parse(
+                clean_date_str,
+                settings={
+                    "RELATIVE_BASE": relative_base or datetime.now(),
+                    "PREFER_DAY_OF_MONTH": "first",
+                    "PREFER_DATES_FROM": prefer_dates_from,
+                },
+            )
+        except Exception as e:
+            logger.error(f"Failed to parse date string: {date_str} with error: {e}")
+            return None

        if parsed_date is None:
            return None

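For reference, a standalone sketch of the dateparser call being wrapped, using the library's module-level parse with the same settings keys (the sample input and base date are assumptions):

    import dateparser
    from datetime import datetime

    parsed = dateparser.parse(
        "april",
        settings={
            "RELATIVE_BASE": datetime(2023, 11, 1),
            "PREFER_DAY_OF_MONTH": "first",
            "PREFER_DATES_FROM": "future",
        },
    )
    print(parsed)  # 2024-04-01 00:00:00, the first day of the next April after the base date
    # Unparseable input normally yields None, but some inputs make the library
    # raise instead, which is what the new try/except guards against.
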
@@ -5,6 +5,7 @@ import datetime
 from enum import Enum
 from importlib import import_module
 from importlib.metadata import version
+from itertools import islice
 import logging
 from os import path
 import os

@@ -305,3 +306,13 @@ def generate_random_name():
     name = f"{adjective} {noun}"

     return name
+
+
+def batcher(iterable, max_n):
+    "Split an iterable into chunks of size max_n"
+    it = iter(iterable)
+    while True:
+        chunk = list(islice(it, max_n))
+        if not chunk:
+            return
+        yield (x for x in chunk if x is not None)

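A quick usage sketch of the new helper; note each yielded chunk is itself a generator, so it should be consumed before advancing to the next batch:

    list(map(list, batcher(range(5), 2)))
    # [[0, 1], [2, 3], [4]]
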
@@ -22,9 +22,7 @@ logger = logging.getLogger(__name__)
 # Test
 # ----------------------------------------------------------------------------------------------------
 @pytest.mark.django_db
-def test_text_search_setup_with_missing_file_raises_error(
-    org_config_with_only_new_file: LocalOrgConfig, search_config: SearchConfig
-):
+def test_text_search_setup_with_missing_file_raises_error(org_config_with_only_new_file: LocalOrgConfig):
     # Arrange
     # Ensure file mentioned in org.input-files is missing
     single_new_file = Path(org_config_with_only_new_file.input_files[0])

@@ -70,22 +68,39 @@ def test_text_search_setup_with_empty_file_raises_error(
     with caplog.at_level(logging.INFO):
         text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)

-    assert "Created 0 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
+    assert "Created 0 new embeddings. Deleted 3 embeddings for user " in caplog.records[-1].message
     verify_embeddings(0, default_user)


 # ----------------------------------------------------------------------------------------------------
 @pytest.mark.django_db
-def test_text_search_setup(content_config, default_user: KhojUser, caplog):
+def test_text_indexer_deletes_embedding_before_regenerate(
+    content_config: ContentConfig, default_user: KhojUser, caplog
+):
     # Arrange
     org_config = LocalOrgConfig.objects.filter(user=default_user).first()
     data = get_org_files(org_config)
-    with caplog.at_level(logging.INFO):
+    with caplog.at_level(logging.DEBUG):
         text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)

     # Assert
-    assert "Deleting all embeddings for file type org" in caplog.records[1].message
-    assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
+    assert "Deleting all embeddings for file type org" in caplog.text
+    assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[-1].message


+# ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db
+def test_text_search_setup_batch_processes(content_config: ContentConfig, default_user: KhojUser, caplog):
+    # Arrange
+    org_config = LocalOrgConfig.objects.filter(user=default_user).first()
+    data = get_org_files(org_config)
+    with caplog.at_level(logging.DEBUG):
+        text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
+
+    # Assert
+    assert "Created 4 new embeddings" in caplog.text
+    assert "Created 6 new embeddings" in caplog.text
+    assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[-1].message
+
+
 # ----------------------------------------------------------------------------------------------------

@@ -97,13 +112,13 @@ def test_text_index_same_if_content_unchanged(content_config: ContentConfig, def

     # Act
     # Generate initial notes embeddings during asymmetric setup
-    with caplog.at_level(logging.INFO):
+    with caplog.at_level(logging.DEBUG):
         text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
     initial_logs = caplog.text
     caplog.clear()  # Clear logs

     # Run asymmetric setup again with no changes to data source. Ensure index is not updated
-    with caplog.at_level(logging.INFO):
+    with caplog.at_level(logging.DEBUG):
         text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)
     final_logs = caplog.text

@@ -175,12 +190,10 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: LocalOrgCon

     # Assert
     # verify newly added org-mode entry is split by max tokens
-    record = caplog.records[1]
-    assert "Created 2 new embeddings. Deleted 0 embeddings for user " in record.message
+    assert "Created 2 new embeddings. Deleted 0 embeddings for user " in caplog.records[-1].message


 # ----------------------------------------------------------------------------------------------------
-# @pytest.mark.skip(reason="Flaky due to compressed_jsonl file being rewritten by other tests")
 @pytest.mark.django_db
 def test_entry_chunking_by_max_tokens_not_full_corpus(
     org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog

@@ -232,11 +245,9 @@ conda activate khoj
         user=default_user,
     )

-    record = caplog.records[1]
-
     # Assert
     # verify newly added org-mode entry is split by max tokens
-    assert "Created 2 new embeddings. Deleted 0 embeddings for user " in record.message
+    assert "Created 2 new embeddings. Deleted 0 embeddings for user " in caplog.records[-1].message


 # ----------------------------------------------------------------------------------------------------

@@ -251,7 +262,7 @@ def test_regenerate_index_with_new_entry(
     with caplog.at_level(logging.INFO):
         text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)

-    assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
+    assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[-1].message

     # append org-mode entry to first org input file in config
     org_config.input_files = [f"{new_org_file}"]

@@ -286,20 +297,23 @@ def test_update_index_with_duplicate_entries_in_stable_order(
     data = get_org_files(org_config_with_only_new_file)

     # Act
-    # load embeddings, entries, notes model after adding new org-mode file
+    # generate embeddings, entries, notes model from scratch after adding new org-mode file
     with caplog.at_level(logging.INFO):
         text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
+    initial_logs = caplog.text
+    caplog.clear()  # Clear logs

     data = get_org_files(org_config_with_only_new_file)

-    # update embeddings, entries, notes model after adding new org-mode file
+    # update embeddings, entries, notes model with no new changes
     with caplog.at_level(logging.INFO):
         text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)
+    final_logs = caplog.text

     # Assert
     # verify only 1 entry added even if there are multiple duplicate entries
-    assert "Created 1 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
-    assert "Created 0 new embeddings. Deleted 0 embeddings for user " in caplog.records[4].message
+    assert "Created 1 new embeddings. Deleted 3 embeddings for user " in initial_logs
+    assert "Created 0 new embeddings. Deleted 0 embeddings for user " in final_logs

     verify_embeddings(1, default_user)

@@ -319,6 +333,8 @@ def test_update_index_with_deleted_entry(org_config_with_only_new_file: LocalOrg
     # load embeddings, entries, notes model after adding new org file with 2 entries
     with caplog.at_level(logging.INFO):
         text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
+    initial_logs = caplog.text
+    caplog.clear()  # Clear logs

     # update embeddings, entries, notes model after removing an entry from the org file
     with open(new_file_to_index, "w") as f:

@@ -329,11 +345,12 @@ def test_update_index_with_deleted_entry(org_config_with_only_new_file: LocalOrg
     # Act
     with caplog.at_level(logging.INFO):
         text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)
+    final_logs = caplog.text

     # Assert
     # verify only 1 entry added even if there are multiple duplicate entries
-    assert "Created 2 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
-    assert "Created 0 new embeddings. Deleted 1 embeddings for user " in caplog.records[4].message
+    assert "Created 2 new embeddings. Deleted 3 embeddings for user " in initial_logs
+    assert "Created 0 new embeddings. Deleted 1 embeddings for user " in final_logs

     verify_embeddings(1, default_user)

@@ -346,6 +363,8 @@ def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file
     data = get_org_files(org_config)
     with caplog.at_level(logging.INFO):
         text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
+    initial_logs = caplog.text
+    caplog.clear()  # Clear logs

     # append org-mode entry to first org input file in config
     with open(new_org_file, "w") as f:

@@ -358,10 +377,11 @@ def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file
     # update embeddings, entries with the newly added note
     with caplog.at_level(logging.INFO):
         text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)
+    final_logs = caplog.text

     # Assert
-    assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
-    assert "Created 1 new embeddings. Deleted 0 embeddings for user " in caplog.records[4].message
+    assert "Created 10 new embeddings. Deleted 3 embeddings for user " in initial_logs
+    assert "Created 1 new embeddings. Deleted 0 embeddings for user " in final_logs

     verify_embeddings(11, default_user)