mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 23:48:56 +01:00
Add a plugin which allows users to index their Notion pages (#284)
* For the demo instance, re-instate the scheduler, but infrequently for api updates - In constants, determine the cadence based on whether it's a demo instance or not - This allow us to collect telemetry again. This will also allow us to save the chat session * Conditionally skip updating the index altogether if it's a demo isntance * Add backend support for Notion data parsing - Add a NotionToJsonl class which parses the text of Notion documents made accessible to the API token - Make corresponding updates to the default config, raw config to support the new notion addition * Add corresponding views to support configuring Notion from the web-based settings page - Support backend APIs for deleting/configuring notion setup as well - Streamline some of the index updating code * Use defaults for search and chat queries results count * Update pagination of retrieving pages from Notion * Update state conversation processor when update is hit * frequency_penalty should be passed to gpt through kwargs * Add check for notion in render_multiple method * Add headings to Notion render * Revert results count slider and split Notion files by blocks * Clean/fix misc things in the function to update index - Use the successText and errorText variables appropriately - Name parameters in function calls - Add emojis, woohoo * Clean up and further modularize code for processing data in Notion
This commit is contained in:
parent
77755c0284
commit
62704cac09
15 changed files with 520 additions and 64 deletions
|
@ -17,6 +17,7 @@ from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
|
||||||
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
||||||
from khoj.processor.pdf.pdf_to_jsonl import PdfToJsonl
|
from khoj.processor.pdf.pdf_to_jsonl import PdfToJsonl
|
||||||
from khoj.processor.github.github_to_jsonl import GithubToJsonl
|
from khoj.processor.github.github_to_jsonl import GithubToJsonl
|
||||||
|
from khoj.processor.notion.notion_to_jsonl import NotionToJsonl
|
||||||
from khoj.search_type import image_search, text_search
|
from khoj.search_type import image_search, text_search
|
||||||
from khoj.utils import constants, state
|
from khoj.utils import constants, state
|
||||||
from khoj.utils.config import SearchType, SearchModels, ProcessorConfigModel, ConversationProcessorConfigModel
|
from khoj.utils.config import SearchType, SearchModels, ProcessorConfigModel, ConversationProcessorConfigModel
|
||||||
|
@ -169,6 +170,18 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
|
||||||
regenerate=regenerate,
|
regenerate=regenerate,
|
||||||
filters=[DateFilter(), WordFilter(), FileFilter()],
|
filters=[DateFilter(), WordFilter(), FileFilter()],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Initialize Notion Search
|
||||||
|
if (t == None or t in state.SearchType) and config.content_type.notion:
|
||||||
|
logger.info("🔌 Setting up search for notion")
|
||||||
|
model.notion_search = text_search.setup(
|
||||||
|
NotionToJsonl,
|
||||||
|
config.content_type.notion,
|
||||||
|
search_config=config.search_type.asymmetric,
|
||||||
|
regenerate=regenerate,
|
||||||
|
filters=[DateFilter(), WordFilter(), FileFilter()],
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("🚨 Failed to setup search")
|
logger.error("🚨 Failed to setup search")
|
||||||
raise e
|
raise e
|
||||||
|
@ -248,7 +261,7 @@ def save_chat_session():
|
||||||
|
|
||||||
@schedule.repeat(schedule.every(59).minutes)
|
@schedule.repeat(schedule.every(59).minutes)
|
||||||
def upload_telemetry():
|
def upload_telemetry():
|
||||||
if not state.config or not state.config.app.should_log_telemetry or not state.telemetry:
|
if not state.config or not state.config.app or not state.config.app.should_log_telemetry or not state.telemetry:
|
||||||
message = "📡 No telemetry to upload" if not state.telemetry else "📡 Telemetry logging disabled"
|
message = "📡 No telemetry to upload" if not state.telemetry else "📡 Telemetry logging disabled"
|
||||||
logger.debug(message)
|
logger.debug(message)
|
||||||
return
|
return
|
||||||
|
|
4
src/khoj/interface/web/assets/icons/notion.svg
Normal file
4
src/khoj/interface/web/assets/icons/notion.svg
Normal file
|
@ -0,0 +1,4 @@
|
||||||
|
<svg width="100" height="100" viewBox="0 0 100 100" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||||
|
<path d="M6.017 4.313l55.333 -4.087c6.797 -0.583 8.543 -0.19 12.817 2.917l17.663 12.443c2.913 2.14 3.883 2.723 3.883 5.053v68.243c0 4.277 -1.553 6.807 -6.99 7.193L24.467 99.967c-4.08 0.193 -6.023 -0.39 -8.16 -3.113L3.3 79.94c-2.333 -3.113 -3.3 -5.443 -3.3 -8.167V11.113c0 -3.497 1.553 -6.413 6.017 -6.8z" fill="#fff"/>
|
||||||
|
<path fill-rule="evenodd" clip-rule="evenodd" d="M61.35 0.227l-55.333 4.087C1.553 4.7 0 7.617 0 11.113v60.66c0 2.723 0.967 5.053 3.3 8.167l13.007 16.913c2.137 2.723 4.08 3.307 8.16 3.113l64.257 -3.89c5.433 -0.387 6.99 -2.917 6.99 -7.193V20.64c0 -2.21 -0.873 -2.847 -3.443 -4.733L74.167 3.143c-4.273 -3.107 -6.02 -3.5 -12.817 -2.917zM25.92 19.523c-5.247 0.353 -6.437 0.433 -9.417 -1.99L8.927 11.507c-0.77 -0.78 -0.383 -1.753 1.557 -1.947l53.193 -3.887c4.467 -0.39 6.793 1.167 8.54 2.527l9.123 6.61c0.39 0.197 1.36 1.36 0.193 1.36l-54.933 3.307 -0.68 0.047zM19.803 88.3V30.367c0 -2.53 0.777 -3.697 3.103 -3.893L86 22.78c2.14 -0.193 3.107 1.167 3.107 3.693v57.547c0 2.53 -0.39 4.67 -3.883 4.863l-60.377 3.5c-3.493 0.193 -5.043 -0.97 -5.043 -4.083zm59.6 -54.827c0.387 1.75 0 3.5 -1.75 3.7l-2.91 0.577v42.773c-2.527 1.36 -4.853 2.137 -6.797 2.137 -3.107 0 -3.883 -0.973 -6.21 -3.887l-19.03 -29.94v28.967l6.02 1.363s0 3.5 -4.857 3.5l-13.39 0.777c-0.39 -0.78 0 -2.723 1.357 -3.11l3.497 -0.97v-38.3L30.48 40.667c-0.39 -1.75 0.58 -4.277 3.3 -4.473l14.367 -0.967 19.8 30.327v-26.83l-5.047 -0.58c-0.39 -2.143 1.163 -3.7 3.103 -3.89l13.4 -0.78z" fill="#000"/>
|
||||||
|
</svg>
|
After Width: | Height: | Size: 1.5 KiB |
|
@ -93,6 +93,7 @@
|
||||||
|
|
||||||
// Decode message chunk from stream
|
// Decode message chunk from stream
|
||||||
const chunk = decoder.decode(value, { stream: true });
|
const chunk = decoder.decode(value, { stream: true });
|
||||||
|
|
||||||
if (chunk.includes("### compiled references:")) {
|
if (chunk.includes("### compiled references:")) {
|
||||||
const additionalResponse = chunk.split("### compiled references:")[0];
|
const additionalResponse = chunk.split("### compiled references:")[0];
|
||||||
new_response_text.innerHTML += additionalResponse;
|
new_response_text.innerHTML += additionalResponse;
|
||||||
|
|
|
@ -14,7 +14,6 @@
|
||||||
<img id="configured-icon-github" class="configured-icon" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
|
<img id="configured-icon-github" class="configured-icon" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
|
||||||
{% endif %}
|
{% endif %}
|
||||||
</h3>
|
</h3>
|
||||||
|
|
||||||
</div>
|
</div>
|
||||||
<div class="card-description-row">
|
<div class="card-description-row">
|
||||||
<p class="card-description">Set repositories for Khoj to index</p>
|
<p class="card-description">Set repositories for Khoj to index</p>
|
||||||
|
@ -37,6 +36,37 @@
|
||||||
</div>
|
</div>
|
||||||
{% endif %}
|
{% endif %}
|
||||||
</div>
|
</div>
|
||||||
|
<div class="card">
|
||||||
|
<div class="card-title-row">
|
||||||
|
<img class="card-icon" src="/static/assets/icons/notion.svg" alt="Notion">
|
||||||
|
<h3 class="card-title">
|
||||||
|
Notion
|
||||||
|
{% if current_config.content_type.notion %}
|
||||||
|
<img id="configured-icon-notion" class="configured-icon" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
|
||||||
|
{% endif %}
|
||||||
|
</h3>
|
||||||
|
</div>
|
||||||
|
<div class="card-description-row">
|
||||||
|
<p class="card-description">Configure your settings from Notion</p>
|
||||||
|
</div>
|
||||||
|
<div class="card-action-row">
|
||||||
|
<a class="card-button" href="/config/content_type/notion">
|
||||||
|
{% if current_config.content_type.content %}
|
||||||
|
Update
|
||||||
|
{% else %}
|
||||||
|
Setup
|
||||||
|
{% endif %}
|
||||||
|
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
|
||||||
|
</a>
|
||||||
|
</div>
|
||||||
|
{% if current_config.content_type.notion %}
|
||||||
|
<div id="clear-notion" class="card-action-row">
|
||||||
|
<button class="card-button" onclick="clearContentType('notion')">
|
||||||
|
Disable
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
{% endif %}
|
||||||
|
</div>
|
||||||
<div class="card">
|
<div class="card">
|
||||||
<div class="card-title-row">
|
<div class="card-title-row">
|
||||||
<img class="card-icon" src="/static/assets/icons/markdown.svg" alt="markdown">
|
<img class="card-icon" src="/static/assets/icons/markdown.svg" alt="markdown">
|
||||||
|
@ -224,40 +254,32 @@
|
||||||
var configure = document.getElementById("configure");
|
var configure = document.getElementById("configure");
|
||||||
configure.addEventListener("click", function(event) {
|
configure.addEventListener("click", function(event) {
|
||||||
event.preventDefault();
|
event.preventDefault();
|
||||||
configure.disabled = true;
|
updateIndex(
|
||||||
configure.innerHTML = "Configuring...";
|
force=false,
|
||||||
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
|
successText="Configured successfully!",
|
||||||
fetch('/api/update?&client=web', {
|
errorText="Unable to configure. Raise issue on Khoj <a href='https://github.com/khoj-ai/khoj/issues'>Github</a> or <a href='https://discord.gg/BDgyabRM6e'>Discord</a>.",
|
||||||
method: 'GET',
|
button=configure,
|
||||||
headers: {
|
loadingText="Configuring...",
|
||||||
'Content-Type': 'application/json',
|
emoji="⚙️");
|
||||||
'X-CSRFToken': csrfToken
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.then(response => response.json())
|
|
||||||
.then(data => {
|
|
||||||
console.log('Success:', data);
|
|
||||||
document.getElementById("status").innerHTML = "Configured successfully!";
|
|
||||||
document.getElementById("status").style.display = "block";
|
|
||||||
configure.disabled = false;
|
|
||||||
configure.innerHTML = "⚙️ Configured";
|
|
||||||
})
|
|
||||||
.catch((error) => {
|
|
||||||
console.error('Error:', error);
|
|
||||||
document.getElementById("status").innerHTML = "Unable to save configuration. Raise issue on Khoj Discord or Github.";
|
|
||||||
document.getElementById("status").style.display = "block";
|
|
||||||
configure.disabled = false;
|
|
||||||
configure.innerHTML = "⚙️ Configure";
|
|
||||||
});
|
|
||||||
});
|
});
|
||||||
|
|
||||||
var reinitialize = document.getElementById("reinitialize");
|
var reinitialize = document.getElementById("reinitialize");
|
||||||
reinitialize.addEventListener("click", function(event) {
|
reinitialize.addEventListener("click", function(event) {
|
||||||
event.preventDefault();
|
event.preventDefault();
|
||||||
reinitialize.disabled = true;
|
updateIndex(
|
||||||
reinitialize.innerHTML = "Reinitializing...";
|
force=true,
|
||||||
|
successText="Reinitialized successfully!",
|
||||||
|
errorText="Unable to reinitialize. Raise issue on Khoj <a href='https://github.com/khoj-ai/khoj/issues'>Github</a> or <a href='https://discord.gg/BDgyabRM6e'>Discord</a>.",
|
||||||
|
button=reinitialize,
|
||||||
|
loadingText="Reinitializing...",
|
||||||
|
emoji="🔄");
|
||||||
|
});
|
||||||
|
|
||||||
|
function updateIndex(force, successText, errorText, button, loadingText, emoji) {
|
||||||
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
|
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
|
||||||
fetch('/api/update?&client=web&force=True', {
|
button.disabled = true;
|
||||||
|
button.innerHTML = emoji + loadingText;
|
||||||
|
fetch('/api/update?&client=web&force=' + force, {
|
||||||
method: 'GET',
|
method: 'GET',
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
|
@ -267,19 +289,22 @@
|
||||||
.then(response => response.json())
|
.then(response => response.json())
|
||||||
.then(data => {
|
.then(data => {
|
||||||
console.log('Success:', data);
|
console.log('Success:', data);
|
||||||
document.getElementById("status").innerHTML = "Reinitialized successfully!";
|
if (data.detail != null) {
|
||||||
|
throw new Error(data.detail);
|
||||||
|
}
|
||||||
|
document.getElementById("status").innerHTML = emoji + successText;
|
||||||
document.getElementById("status").style.display = "block";
|
document.getElementById("status").style.display = "block";
|
||||||
reinitialize.disabled = false;
|
button.disabled = false;
|
||||||
reinitialize.innerHTML = "🔄 Reinitialized";
|
button.innerHTML = '✅ Done!';
|
||||||
})
|
})
|
||||||
.catch((error) => {
|
.catch((error) => {
|
||||||
console.error('Error:', error);
|
console.error('Error:', error);
|
||||||
document.getElementById("status").innerHTML = "Unable to reinitialize. Raise issue on Khoj Discord or Github.";
|
document.getElementById("status").innerHTML = emoji + errorText
|
||||||
document.getElementById("status").style.display = "block";
|
document.getElementById("status").style.display = "block";
|
||||||
reinitialize.disabled = false;
|
button.disabled = false;
|
||||||
reinitialize.innerHTML = "🔄 Reinitialize";
|
button.innerHTML = '⚠️ Unsuccessful';
|
||||||
});
|
|
||||||
});
|
});
|
||||||
|
}
|
||||||
|
|
||||||
// Setup the results count slider
|
// Setup the results count slider
|
||||||
const resultsCountSlider = document.getElementById('results-count-slider');
|
const resultsCountSlider = document.getElementById('results-count-slider');
|
||||||
|
|
86
src/khoj/interface/web/content_type_notion_input.html
Normal file
86
src/khoj/interface/web/content_type_notion_input.html
Normal file
|
@ -0,0 +1,86 @@
|
||||||
|
{% extends "base_config.html" %}
|
||||||
|
{% block content %}
|
||||||
|
<div class="page">
|
||||||
|
<div class="section">
|
||||||
|
<h2 class="section-title">
|
||||||
|
<img class="card-icon" src="/static/assets/icons/notion.svg" alt="Notion">
|
||||||
|
<span class="card-title-text">Notion</span>
|
||||||
|
</h2>
|
||||||
|
<form>
|
||||||
|
<table>
|
||||||
|
<tr>
|
||||||
|
<td>
|
||||||
|
<label for="token">Token</label>
|
||||||
|
</td>
|
||||||
|
<td>
|
||||||
|
<input type="text" id="token" name="pat" value="{{ current_config['token'] }}">
|
||||||
|
</td>
|
||||||
|
</tr>
|
||||||
|
</table>
|
||||||
|
<table style="display: none;" >
|
||||||
|
<tr>
|
||||||
|
<td>
|
||||||
|
<label for="compressed-jsonl">Compressed JSONL (Output)</label>
|
||||||
|
</td>
|
||||||
|
<td>
|
||||||
|
<input type="text" id="compressed-jsonl" name="compressed-jsonl" value="{{ current_config['compressed_jsonl'] }}">
|
||||||
|
</td>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td>
|
||||||
|
<label for="embeddings-file">Embeddings File (Output)</label>
|
||||||
|
</td>
|
||||||
|
<td>
|
||||||
|
<input type="text" id="embeddings-file" name="embeddings-file" value="{{ current_config['embeddings_file'] }}">
|
||||||
|
</td>
|
||||||
|
</tr>
|
||||||
|
</table>
|
||||||
|
<div class="section">
|
||||||
|
<div id="success" style="display: none;"></div>
|
||||||
|
<button id="submit" type="submit">Save</button>
|
||||||
|
</div>
|
||||||
|
</form>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<script>
|
||||||
|
const submit = document.getElementById("submit");
|
||||||
|
|
||||||
|
submit.addEventListener("click", function(event) {
|
||||||
|
event.preventDefault();
|
||||||
|
|
||||||
|
const compressed_jsonl = document.getElementById("compressed-jsonl").value;
|
||||||
|
const embeddings_file = document.getElementById("embeddings-file").value;
|
||||||
|
const token = document.getElementById("token").value;
|
||||||
|
|
||||||
|
if (token == "") {
|
||||||
|
document.getElementById("success").innerHTML = "❌ Please enter a Notion Token.";
|
||||||
|
document.getElementById("success").style.display = "block";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
|
||||||
|
fetch('/api/config/data/content_type/notion', {
|
||||||
|
method: 'POST',
|
||||||
|
headers: {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
'X-CSRFToken': csrfToken,
|
||||||
|
},
|
||||||
|
body: JSON.stringify({
|
||||||
|
"token": token,
|
||||||
|
"compressed_jsonl": compressed_jsonl,
|
||||||
|
"embeddings_file": embeddings_file,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.then(response => response.json())
|
||||||
|
.then(data => {
|
||||||
|
if (data["status"] == "ok") {
|
||||||
|
document.getElementById("success").innerHTML = "✅ Successfully updated. Go to your <a href='/config'>settings page</a> to complete setup.";
|
||||||
|
document.getElementById("success").style.display = "block";
|
||||||
|
} else {
|
||||||
|
document.getElementById("success").innerHTML = "⚠️ Failed to update settings.";
|
||||||
|
document.getElementById("success").style.display = "block";
|
||||||
|
}
|
||||||
|
})
|
||||||
|
});
|
||||||
|
</script>
|
||||||
|
{% endblock %}
|
|
@ -71,6 +71,8 @@
|
||||||
html += render_markdown(query, [item]);
|
html += render_markdown(query, [item]);
|
||||||
} else if (item.additional.file.endsWith(".pdf")) {
|
} else if (item.additional.file.endsWith(".pdf")) {
|
||||||
html += render_pdf(query, [item]);
|
html += render_pdf(query, [item]);
|
||||||
|
} else if (item.additional.file.includes("notion.so")) {
|
||||||
|
html += `<div class="results-notion">` + `<b><a href="${item.additional.file}">${item.additional.heading}</a></b>` + `<p>${item.entry}</p>` + `</div>`;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
return html;
|
return html;
|
||||||
|
@ -86,7 +88,7 @@
|
||||||
results = data.map(render_image).join('');
|
results = data.map(render_image).join('');
|
||||||
} else if (type === "pdf") {
|
} else if (type === "pdf") {
|
||||||
results = render_pdf(query, data);
|
results = render_pdf(query, data);
|
||||||
} else if (type === "github" || type === "all") {
|
} else if (type === "github" || type === "all" || type === "notion") {
|
||||||
results = render_multiple(query, data, type);
|
results = render_multiple(query, data, type);
|
||||||
} else {
|
} else {
|
||||||
results = data.map((item) => `<div class="results-plugin">` + `<p>${item.entry}</p>` + `</div>`).join("\n")
|
results = data.map((item) => `<div class="results-plugin">` + `<p>${item.entry}</p>` + `</div>`).join("\n")
|
||||||
|
@ -127,7 +129,7 @@
|
||||||
setQueryFieldInUrl(query);
|
setQueryFieldInUrl(query);
|
||||||
|
|
||||||
// Execute Search and Render Results
|
// Execute Search and Render Results
|
||||||
url = createRequestUrl(query, type, results_count, rerank);
|
url = createRequestUrl(query, type, results_count || 5, rerank);
|
||||||
fetch(url)
|
fetch(url)
|
||||||
.then(response => response.json())
|
.then(response => response.json())
|
||||||
.then(data => {
|
.then(data => {
|
||||||
|
@ -347,6 +349,7 @@
|
||||||
white-space: pre-wrap;
|
white-space: pre-wrap;
|
||||||
}
|
}
|
||||||
.results-pdf,
|
.results-pdf,
|
||||||
|
.results-notion,
|
||||||
.results-plugin {
|
.results-plugin {
|
||||||
text-align: left;
|
text-align: left;
|
||||||
white-space: pre-line;
|
white-space: pre-line;
|
||||||
|
@ -404,6 +407,7 @@
|
||||||
|
|
||||||
div#results-error,
|
div#results-error,
|
||||||
div.results-markdown,
|
div.results-markdown,
|
||||||
|
div.results-notion,
|
||||||
div.results-org,
|
div.results-org,
|
||||||
div.results-pdf {
|
div.results-pdf {
|
||||||
text-align: left;
|
text-align: left;
|
||||||
|
|
|
@ -32,8 +32,7 @@ def summarize(session, model, api_key=None, temperature=0.5, max_tokens=200):
|
||||||
model_name=model,
|
model_name=model,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
frequency_penalty=0.2,
|
model_kwargs={"stop": ['"""'], "frequency_penalty": 0.2},
|
||||||
model_kwargs={"stop": ['"""']},
|
|
||||||
openai_api_key=api_key,
|
openai_api_key=api_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
243
src/khoj/processor/notion/notion_to_jsonl.py
Normal file
243
src/khoj/processor/notion/notion_to_jsonl.py
Normal file
|
@ -0,0 +1,243 @@
|
||||||
|
# Standard Packages
|
||||||
|
import logging
|
||||||
|
|
||||||
|
# External Packages
|
||||||
|
import requests
|
||||||
|
|
||||||
|
# Internal Packages
|
||||||
|
from khoj.utils.helpers import timer
|
||||||
|
from khoj.utils.rawconfig import Entry, NotionContentConfig
|
||||||
|
from khoj.processor.text_to_jsonl import TextToJsonl
|
||||||
|
from khoj.utils.jsonl import dump_jsonl, compress_jsonl_data
|
||||||
|
from khoj.utils.rawconfig import Entry
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class NotionBlockType(Enum):
|
||||||
|
PARAGRAPH = "paragraph"
|
||||||
|
HEADING_1 = "heading_1"
|
||||||
|
HEADING_2 = "heading_2"
|
||||||
|
HEADING_3 = "heading_3"
|
||||||
|
BULLETED_LIST_ITEM = "bulleted_list_item"
|
||||||
|
NUMBERED_LIST_ITEM = "numbered_list_item"
|
||||||
|
TO_DO = "to_do"
|
||||||
|
TOGGLE = "toggle"
|
||||||
|
CHILD_PAGE = "child_page"
|
||||||
|
UNSUPPORTED = "unsupported"
|
||||||
|
BOOKMARK = "bookmark"
|
||||||
|
DIVIDER = "divider"
|
||||||
|
PDF = "pdf"
|
||||||
|
IMAGE = "image"
|
||||||
|
EMBED = "embed"
|
||||||
|
VIDEO = "video"
|
||||||
|
FILE = "file"
|
||||||
|
SYNCED_BLOCK = "synced_block"
|
||||||
|
TABLE_OF_CONTENTS = "table_of_contents"
|
||||||
|
COLUMN = "column"
|
||||||
|
EQUATION = "equation"
|
||||||
|
LINK_PREVIEW = "link_preview"
|
||||||
|
COLUMN_LIST = "column_list"
|
||||||
|
QUOTE = "quote"
|
||||||
|
BREADCRUMB = "breadcrumb"
|
||||||
|
LINK_TO_PAGE = "link_to_page"
|
||||||
|
CHILD_DATABASE = "child_database"
|
||||||
|
TEMPLATE = "template"
|
||||||
|
CALLOUT = "callout"
|
||||||
|
|
||||||
|
|
||||||
|
class NotionToJsonl(TextToJsonl):
|
||||||
|
def __init__(self, config: NotionContentConfig):
|
||||||
|
super().__init__(config)
|
||||||
|
self.config = config
|
||||||
|
self.session = requests.Session()
|
||||||
|
self.session.headers.update({"Authorization": f"Bearer {config.token}", "Notion-Version": "2022-02-22"})
|
||||||
|
self.unsupported_block_types = [
|
||||||
|
NotionBlockType.BOOKMARK.value,
|
||||||
|
NotionBlockType.DIVIDER.value,
|
||||||
|
NotionBlockType.CHILD_DATABASE.value,
|
||||||
|
NotionBlockType.TEMPLATE.value,
|
||||||
|
NotionBlockType.CALLOUT.value,
|
||||||
|
NotionBlockType.UNSUPPORTED.value,
|
||||||
|
]
|
||||||
|
|
||||||
|
self.display_block_block_types = [
|
||||||
|
NotionBlockType.PARAGRAPH.value,
|
||||||
|
NotionBlockType.HEADING_1.value,
|
||||||
|
NotionBlockType.HEADING_2.value,
|
||||||
|
NotionBlockType.HEADING_3.value,
|
||||||
|
NotionBlockType.BULLETED_LIST_ITEM.value,
|
||||||
|
NotionBlockType.NUMBERED_LIST_ITEM.value,
|
||||||
|
NotionBlockType.TO_DO.value,
|
||||||
|
NotionBlockType.TOGGLE.value,
|
||||||
|
NotionBlockType.CHILD_PAGE.value,
|
||||||
|
NotionBlockType.BOOKMARK.value,
|
||||||
|
NotionBlockType.DIVIDER.value,
|
||||||
|
]
|
||||||
|
|
||||||
|
def process(self, previous_entries=None):
|
||||||
|
current_entries = []
|
||||||
|
|
||||||
|
# Get all pages
|
||||||
|
with timer("Getting all pages via search endpoint", logger=logger):
|
||||||
|
responses = []
|
||||||
|
|
||||||
|
while True:
|
||||||
|
result = self.session.post(
|
||||||
|
"https://api.notion.com/v1/search",
|
||||||
|
json={"page_size": 100},
|
||||||
|
).json()
|
||||||
|
responses.append(result)
|
||||||
|
if result["has_more"] == False:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
self.session.params = {"start_cursor": responses[-1]["next_cursor"]}
|
||||||
|
|
||||||
|
for response in responses:
|
||||||
|
with timer("Processing response", logger=logger):
|
||||||
|
pages_or_databases = response["results"]
|
||||||
|
|
||||||
|
# Get all pages content
|
||||||
|
for p_or_d in pages_or_databases:
|
||||||
|
with timer(f"Processing {p_or_d['object']} {p_or_d['id']}", logger=logger):
|
||||||
|
if p_or_d["object"] == "database":
|
||||||
|
# TODO: Handle databases
|
||||||
|
continue
|
||||||
|
elif p_or_d["object"] == "page":
|
||||||
|
page_entries = self.process_page(p_or_d)
|
||||||
|
current_entries.extend(page_entries)
|
||||||
|
|
||||||
|
return self.update_entries_with_ids(current_entries, previous_entries)
|
||||||
|
|
||||||
|
def process_page(self, page):
|
||||||
|
page_id = page["id"]
|
||||||
|
title, content = self.get_page_content(page_id)
|
||||||
|
|
||||||
|
if title == None or content == None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
current_entries = []
|
||||||
|
curr_heading = ""
|
||||||
|
for block in content["results"]:
|
||||||
|
block_type = block.get("type")
|
||||||
|
|
||||||
|
if block_type == None:
|
||||||
|
continue
|
||||||
|
block_data = block[block_type]
|
||||||
|
|
||||||
|
if block_data.get("rich_text") == None or len(block_data["rich_text"]) == 0:
|
||||||
|
# There's no text to handle here.
|
||||||
|
continue
|
||||||
|
|
||||||
|
raw_content = ""
|
||||||
|
if block_type in ["heading_1", "heading_2", "heading_3"]:
|
||||||
|
# If the current block is a heading, we can consider the previous block processing completed.
|
||||||
|
# Add it as an entry and move on to processing the next chunk of the page.
|
||||||
|
if raw_content != "":
|
||||||
|
current_entries.append(
|
||||||
|
Entry(
|
||||||
|
compiled=raw_content,
|
||||||
|
raw=raw_content,
|
||||||
|
heading=title,
|
||||||
|
file=page["url"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
curr_heading = block_data["rich_text"][0]["plain_text"]
|
||||||
|
else:
|
||||||
|
if curr_heading != "":
|
||||||
|
# Add the last known heading to the content for additional context
|
||||||
|
raw_content = self.process_heading(curr_heading)
|
||||||
|
for text in block_data["rich_text"]:
|
||||||
|
raw_content += self.process_text(text)
|
||||||
|
|
||||||
|
if block.get("has_children", True):
|
||||||
|
raw_content += "\n"
|
||||||
|
raw_content = self.process_nested_children(
|
||||||
|
self.get_block_children(block["id"]), raw_content, block_type
|
||||||
|
)
|
||||||
|
|
||||||
|
if raw_content != "":
|
||||||
|
current_entries.append(
|
||||||
|
Entry(
|
||||||
|
compiled=raw_content,
|
||||||
|
raw=raw_content,
|
||||||
|
heading=title,
|
||||||
|
file=page["url"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return current_entries
|
||||||
|
|
||||||
|
def process_heading(self, heading):
|
||||||
|
return f"\n<b>{heading}</b>\n"
|
||||||
|
|
||||||
|
def process_nested_children(self, children, raw_content, block_type=None):
|
||||||
|
for child in children["results"]:
|
||||||
|
child_type = child.get("type")
|
||||||
|
if child_type == None:
|
||||||
|
continue
|
||||||
|
child_data = child[child_type]
|
||||||
|
if child_data.get("rich_text") and len(child_data["rich_text"]) > 0:
|
||||||
|
for text in child_data["rich_text"]:
|
||||||
|
raw_content += self.process_text(text, block_type)
|
||||||
|
if child_data.get("has_children", True):
|
||||||
|
return self.process_nested_children(self.get_block_children(child["id"]), raw_content, block_type)
|
||||||
|
|
||||||
|
return raw_content
|
||||||
|
|
||||||
|
def process_text(self, text, block_type=None):
|
||||||
|
text_type = text.get("type", None)
|
||||||
|
if text_type in self.unsupported_block_types:
|
||||||
|
return ""
|
||||||
|
if text.get("href", None):
|
||||||
|
return f"<a href='{text['href']}'>{text['plain_text']}</a>"
|
||||||
|
raw_text = text["plain_text"]
|
||||||
|
if text_type in self.display_block_block_types or block_type in self.display_block_block_types:
|
||||||
|
return f"\n{raw_text}\n"
|
||||||
|
return raw_text
|
||||||
|
|
||||||
|
def get_block_children(self, block_id):
|
||||||
|
return self.session.get(f"https://api.notion.com/v1/blocks/{block_id}/children").json()
|
||||||
|
|
||||||
|
def get_page(self, page_id):
|
||||||
|
return self.session.get(f"https://api.notion.com/v1/pages/{page_id}").json()
|
||||||
|
|
||||||
|
def get_page_children(self, page_id):
|
||||||
|
return self.session.get(f"https://api.notion.com/v1/blocks/{page_id}/children").json()
|
||||||
|
|
||||||
|
def get_page_content(self, page_id):
|
||||||
|
try:
|
||||||
|
page = self.get_page(page_id)
|
||||||
|
content = self.get_page_children(page_id)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting page {page_id}: {e}")
|
||||||
|
return None, None
|
||||||
|
properties = page["properties"]
|
||||||
|
title_field = "Title" if "Title" in properties else "title"
|
||||||
|
title = page["properties"][title_field]["title"][0]["text"]["content"]
|
||||||
|
return title, content
|
||||||
|
|
||||||
|
def update_entries_with_ids(self, current_entries, previous_entries):
|
||||||
|
# Identify, mark and merge any new entries with previous entries
|
||||||
|
with timer("Identify new or updated entries", logger):
|
||||||
|
if not previous_entries:
|
||||||
|
entries_with_ids = list(enumerate(current_entries))
|
||||||
|
else:
|
||||||
|
entries_with_ids = TextToJsonl.mark_entries_for_update(
|
||||||
|
current_entries, previous_entries, key="compiled", logger=logger
|
||||||
|
)
|
||||||
|
|
||||||
|
with timer("Write Notion entries to JSONL file", logger):
|
||||||
|
# Process Each Entry from all Notion entries
|
||||||
|
entries = list(map(lambda entry: entry[1], entries_with_ids))
|
||||||
|
jsonl_data = TextToJsonl.convert_text_maps_to_jsonl(entries)
|
||||||
|
|
||||||
|
# Compress JSONL formatted Data
|
||||||
|
if self.config.compressed_jsonl.suffix == ".gz":
|
||||||
|
compress_jsonl_data(jsonl_data, self.config.compressed_jsonl)
|
||||||
|
elif self.config.compressed_jsonl.suffix == ".jsonl":
|
||||||
|
dump_jsonl(jsonl_data, self.config.compressed_jsonl)
|
||||||
|
|
||||||
|
return entries_with_ids
|
|
@ -62,7 +62,7 @@ class TextToJsonl(ABC):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def mark_entries_for_update(
|
def mark_entries_for_update(
|
||||||
current_entries: List[Entry], previous_entries: List[Entry], key="compiled", logger=None
|
current_entries: List[Entry], previous_entries: List[Entry], key="compiled", logger: logging.Logger = None
|
||||||
) -> List[Tuple[int, Entry]]:
|
) -> List[Tuple[int, Entry]]:
|
||||||
# Hash all current and previous entries to identify new entries
|
# Hash all current and previous entries to identify new entries
|
||||||
with timer("Hash previous, current entries", logger):
|
with timer("Hash previous, current entries", logger):
|
||||||
|
@ -90,3 +90,8 @@ class TextToJsonl(ABC):
|
||||||
entries_with_ids = existing_entries_sorted + new_entries
|
entries_with_ids = existing_entries_sorted + new_entries
|
||||||
|
|
||||||
return entries_with_ids
|
return entries_with_ids
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def convert_text_maps_to_jsonl(entries: List[Entry]) -> str:
|
||||||
|
# Convert each entry to JSON and write to JSONL file
|
||||||
|
return "".join([f"{entry.to_json()}\n" for entry in entries])
|
||||||
|
|
|
@ -28,6 +28,7 @@ from khoj.utils.rawconfig import (
|
||||||
TextContentConfig,
|
TextContentConfig,
|
||||||
ConversationProcessorConfig,
|
ConversationProcessorConfig,
|
||||||
GithubContentConfig,
|
GithubContentConfig,
|
||||||
|
NotionContentConfig,
|
||||||
)
|
)
|
||||||
from khoj.utils.state import SearchType
|
from khoj.utils.state import SearchType
|
||||||
from khoj.utils import state, constants
|
from khoj.utils import state, constants
|
||||||
|
@ -45,6 +46,11 @@ logger = logging.getLogger(__name__)
|
||||||
# If it's a demo instance, prevent updating any of the configuration.
|
# If it's a demo instance, prevent updating any of the configuration.
|
||||||
if not state.demo:
|
if not state.demo:
|
||||||
|
|
||||||
|
def _initialize_config():
|
||||||
|
if state.config is None:
|
||||||
|
state.config = FullConfig()
|
||||||
|
state.config.search_type = SearchConfig.parse_obj(constants.default_config["search-type"])
|
||||||
|
|
||||||
@api.get("/config/data", response_model=FullConfig)
|
@api.get("/config/data", response_model=FullConfig)
|
||||||
def get_config_data():
|
def get_config_data():
|
||||||
return state.config
|
return state.config
|
||||||
|
@ -59,9 +65,7 @@ if not state.demo:
|
||||||
|
|
||||||
@api.post("/config/data/content_type/github", status_code=200)
|
@api.post("/config/data/content_type/github", status_code=200)
|
||||||
async def set_content_config_github_data(updated_config: Union[GithubContentConfig, None]):
|
async def set_content_config_github_data(updated_config: Union[GithubContentConfig, None]):
|
||||||
if not state.config:
|
_initialize_config()
|
||||||
state.config = FullConfig()
|
|
||||||
state.config.search_type = SearchConfig.parse_obj(constants.default_config["search-type"])
|
|
||||||
|
|
||||||
if not state.config.content_type:
|
if not state.config.content_type:
|
||||||
state.config.content_type = ContentConfig(**{"github": updated_config})
|
state.config.content_type = ContentConfig(**{"github": updated_config})
|
||||||
|
@ -74,6 +78,21 @@ if not state.demo:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {"status": "error", "message": str(e)}
|
return {"status": "error", "message": str(e)}
|
||||||
|
|
||||||
|
@api.post("/config/data/content_type/notion", status_code=200)
|
||||||
|
async def set_content_config_notion_data(updated_config: Union[NotionContentConfig, None]):
|
||||||
|
_initialize_config()
|
||||||
|
|
||||||
|
if not state.config.content_type:
|
||||||
|
state.config.content_type = ContentConfig(**{"notion": updated_config})
|
||||||
|
else:
|
||||||
|
state.config.content_type.notion = updated_config
|
||||||
|
|
||||||
|
try:
|
||||||
|
save_config_to_file_updated_state()
|
||||||
|
return {"status": "ok"}
|
||||||
|
except Exception as e:
|
||||||
|
return {"status": "error", "message": str(e)}
|
||||||
|
|
||||||
@api.post("/delete/config/data/content_type/{content_type}", status_code=200)
|
@api.post("/delete/config/data/content_type/{content_type}", status_code=200)
|
||||||
async def remove_content_config_data(content_type: str):
|
async def remove_content_config_data(content_type: str):
|
||||||
if not state.config or not state.config.content_type:
|
if not state.config or not state.config.content_type:
|
||||||
|
@ -84,6 +103,8 @@ if not state.demo:
|
||||||
|
|
||||||
if content_type == "github":
|
if content_type == "github":
|
||||||
state.model.github_search = None
|
state.model.github_search = None
|
||||||
|
elif content_type == "notion":
|
||||||
|
state.model.notion_search = None
|
||||||
elif content_type == "plugins":
|
elif content_type == "plugins":
|
||||||
state.model.plugin_search = None
|
state.model.plugin_search = None
|
||||||
elif content_type == "pdf":
|
elif content_type == "pdf":
|
||||||
|
@ -114,9 +135,7 @@ if not state.demo:
|
||||||
|
|
||||||
@api.post("/config/data/content_type/{content_type}", status_code=200)
|
@api.post("/config/data/content_type/{content_type}", status_code=200)
|
||||||
async def set_content_config_data(content_type: str, updated_config: Union[TextContentConfig, None]):
|
async def set_content_config_data(content_type: str, updated_config: Union[TextContentConfig, None]):
|
||||||
if not state.config:
|
_initialize_config()
|
||||||
state.config = FullConfig()
|
|
||||||
state.config.search_type = SearchConfig.parse_obj(constants.default_config["search-type"])
|
|
||||||
|
|
||||||
if not state.config.content_type:
|
if not state.config.content_type:
|
||||||
state.config.content_type = ContentConfig(**{content_type: updated_config})
|
state.config.content_type = ContentConfig(**{content_type: updated_config})
|
||||||
|
@ -131,9 +150,8 @@ if not state.demo:
|
||||||
|
|
||||||
@api.post("/config/data/processor/conversation", status_code=200)
|
@api.post("/config/data/processor/conversation", status_code=200)
|
||||||
async def set_processor_conversation_config_data(updated_config: Union[ConversationProcessorConfig, None]):
|
async def set_processor_conversation_config_data(updated_config: Union[ConversationProcessorConfig, None]):
|
||||||
if not state.config:
|
_initialize_config()
|
||||||
state.config = FullConfig()
|
|
||||||
state.config.search_type = SearchConfig.parse_obj(constants.default_config["search-type"])
|
|
||||||
state.config.processor = ProcessorConfig(conversation=updated_config)
|
state.config.processor = ProcessorConfig(conversation=updated_config)
|
||||||
state.processor_config = configure_processor(state.config.processor)
|
state.processor_config = configure_processor(state.config.processor)
|
||||||
try:
|
try:
|
||||||
|
@ -312,6 +330,20 @@ async def search(
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if (t == SearchType.Notion or t == SearchType.All) and state.model.notion_search:
|
||||||
|
# query notion pages
|
||||||
|
search_futures += [
|
||||||
|
executor.submit(
|
||||||
|
text_search.query,
|
||||||
|
user_query,
|
||||||
|
state.model.notion_search,
|
||||||
|
question_embedding=encoded_asymmetric_query,
|
||||||
|
rank_results=r or False,
|
||||||
|
score_threshold=score_threshold,
|
||||||
|
dedupe=dedupe or True,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
# Query across each requested content types in parallel
|
# Query across each requested content types in parallel
|
||||||
with timer("Query took", logger):
|
with timer("Query took", logger):
|
||||||
for search_future in concurrent.futures.as_completed(search_futures):
|
for search_future in concurrent.futures.as_completed(search_futures):
|
||||||
|
|
|
@ -63,6 +63,28 @@ if not state.demo:
|
||||||
"content_type_github_input.html", context={"request": request, "current_config": current_config}
|
"content_type_github_input.html", context={"request": request, "current_config": current_config}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@web_client.get("/config/content_type/notion", response_class=HTMLResponse)
|
||||||
|
def notion_config_page(request: Request):
|
||||||
|
default_copy = constants.default_config.copy()
|
||||||
|
default_notion = default_copy["content-type"]["notion"] # type: ignore
|
||||||
|
|
||||||
|
default_config = TextContentConfig(
|
||||||
|
compressed_jsonl=default_notion["compressed-jsonl"],
|
||||||
|
embeddings_file=default_notion["embeddings-file"],
|
||||||
|
)
|
||||||
|
|
||||||
|
current_config = (
|
||||||
|
state.config.content_type.notion
|
||||||
|
if state.config and state.config.content_type and state.config.content_type.notion
|
||||||
|
else default_config
|
||||||
|
)
|
||||||
|
|
||||||
|
current_config = json.loads(current_config.json())
|
||||||
|
|
||||||
|
return templates.TemplateResponse(
|
||||||
|
"content_type_notion_input.html", context={"request": request, "current_config": current_config}
|
||||||
|
)
|
||||||
|
|
||||||
@web_client.get("/config/content_type/{content_type}", response_class=HTMLResponse)
|
@web_client.get("/config/content_type/{content_type}", response_class=HTMLResponse)
|
||||||
def content_config_page(request: Request, content_type: str):
|
def content_config_page(request: Request, content_type: str):
|
||||||
if content_type not in VALID_TEXT_CONTENT_TYPES:
|
if content_type not in VALID_TEXT_CONTENT_TYPES:
|
||||||
|
|
|
@ -15,7 +15,7 @@ from khoj.utils import state
|
||||||
from khoj.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model, timer
|
from khoj.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model, timer
|
||||||
from khoj.utils.config import TextSearchModel
|
from khoj.utils.config import TextSearchModel
|
||||||
from khoj.utils.models import BaseEncoder
|
from khoj.utils.models import BaseEncoder
|
||||||
from khoj.utils.rawconfig import SearchResponse, TextSearchConfig, TextContentConfig, Entry
|
from khoj.utils.rawconfig import SearchResponse, TextSearchConfig, TextConfigBase, Entry
|
||||||
from khoj.utils.jsonl import load_jsonl
|
from khoj.utils.jsonl import load_jsonl
|
||||||
|
|
||||||
|
|
||||||
|
@ -159,7 +159,11 @@ def collate_results(hits, entries: List[Entry], count=5) -> List[SearchResponse]
|
||||||
{
|
{
|
||||||
"entry": entries[hit["corpus_id"]].raw,
|
"entry": entries[hit["corpus_id"]].raw,
|
||||||
"score": f"{hit.get('cross-score') or hit.get('score')}",
|
"score": f"{hit.get('cross-score') or hit.get('score')}",
|
||||||
"additional": {"file": entries[hit["corpus_id"]].file, "compiled": entries[hit["corpus_id"]].compiled},
|
"additional": {
|
||||||
|
"file": entries[hit["corpus_id"]].file,
|
||||||
|
"compiled": entries[hit["corpus_id"]].compiled,
|
||||||
|
"heading": entries[hit["corpus_id"]].heading,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
for hit in hits[0:count]
|
for hit in hits[0:count]
|
||||||
|
@ -168,7 +172,7 @@ def collate_results(hits, entries: List[Entry], count=5) -> List[SearchResponse]
|
||||||
|
|
||||||
def setup(
|
def setup(
|
||||||
text_to_jsonl: Type[TextToJsonl],
|
text_to_jsonl: Type[TextToJsonl],
|
||||||
config: TextContentConfig,
|
config: TextConfigBase,
|
||||||
search_config: TextSearchConfig,
|
search_config: TextSearchConfig,
|
||||||
regenerate: bool,
|
regenerate: bool,
|
||||||
filters: List[BaseFilter] = [],
|
filters: List[BaseFilter] = [],
|
||||||
|
@ -186,7 +190,8 @@ def setup(
|
||||||
# Extract Updated Entries
|
# Extract Updated Entries
|
||||||
entries = extract_entries(config.compressed_jsonl)
|
entries = extract_entries(config.compressed_jsonl)
|
||||||
if is_none_or_empty(entries):
|
if is_none_or_empty(entries):
|
||||||
raise ValueError(f"No valid entries found in specified files: {config.input_files} or {config.input_filter}")
|
config_params = ", ".join([f"{key}={value}" for key, value in config.dict().items()])
|
||||||
|
raise ValueError(f"No valid entries found in specified files: {config_params}")
|
||||||
top_k = min(len(entries), top_k) # top_k hits can't be more than the total entries in corpus
|
top_k = min(len(entries), top_k) # top_k hits can't be more than the total entries in corpus
|
||||||
|
|
||||||
# Compute or Load Embeddings
|
# Compute or Load Embeddings
|
||||||
|
|
|
@ -3,7 +3,7 @@ from __future__ import annotations # to avoid quoting type hints
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Dict, List
|
from typing import TYPE_CHECKING, Dict, List, Union
|
||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
import torch
|
import torch
|
||||||
|
@ -23,6 +23,7 @@ class SearchType(str, Enum):
|
||||||
Image = "image"
|
Image = "image"
|
||||||
Pdf = "pdf"
|
Pdf = "pdf"
|
||||||
Github = "github"
|
Github = "github"
|
||||||
|
Notion = "notion"
|
||||||
|
|
||||||
|
|
||||||
class ProcessorType(str, Enum):
|
class ProcessorType(str, Enum):
|
||||||
|
@ -58,12 +59,13 @@ class ImageSearchModel:
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SearchModels:
|
class SearchModels:
|
||||||
org_search: TextSearchModel = None
|
org_search: Union[TextSearchModel, None] = None
|
||||||
markdown_search: TextSearchModel = None
|
markdown_search: Union[TextSearchModel, None] = None
|
||||||
pdf_search: TextSearchModel = None
|
pdf_search: Union[TextSearchModel, None] = None
|
||||||
image_search: ImageSearchModel = None
|
image_search: Union[ImageSearchModel, None] = None
|
||||||
github_search: TextSearchModel = None
|
github_search: Union[TextSearchModel, None] = None
|
||||||
plugin_search: Dict[str, TextSearchModel] = None
|
notion_search: Union[TextSearchModel, None] = None
|
||||||
|
plugin_search: Union[Dict[str, TextSearchModel], None] = None
|
||||||
|
|
||||||
|
|
||||||
class ConversationProcessorConfigModel:
|
class ConversationProcessorConfigModel:
|
||||||
|
@ -78,4 +80,4 @@ class ConversationProcessorConfigModel:
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ProcessorConfigModel:
|
class ProcessorConfigModel:
|
||||||
conversation: ConversationProcessorConfigModel = None
|
conversation: Union[ConversationProcessorConfigModel, None] = None
|
||||||
|
|
|
@ -41,6 +41,11 @@ default_config = {
|
||||||
"compressed-jsonl": "~/.khoj/content/github/github.jsonl.gz",
|
"compressed-jsonl": "~/.khoj/content/github/github.jsonl.gz",
|
||||||
"embeddings-file": "~/.khoj/content/github/github_embeddings.pt",
|
"embeddings-file": "~/.khoj/content/github/github_embeddings.pt",
|
||||||
},
|
},
|
||||||
|
"notion": {
|
||||||
|
"token": None,
|
||||||
|
"compressed-jsonl": "~/.khoj/content/notion/notion.jsonl.gz",
|
||||||
|
"embeddings-file": "~/.khoj/content/notion/notion_embeddings.pt",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
"search-type": {
|
"search-type": {
|
||||||
"symmetric": {
|
"symmetric": {
|
||||||
|
|
|
@ -52,6 +52,10 @@ class GithubContentConfig(TextConfigBase):
|
||||||
repos: List[GithubRepoConfig]
|
repos: List[GithubRepoConfig]
|
||||||
|
|
||||||
|
|
||||||
|
class NotionContentConfig(TextConfigBase):
|
||||||
|
token: str
|
||||||
|
|
||||||
|
|
||||||
class ImageContentConfig(ConfigBase):
|
class ImageContentConfig(ConfigBase):
|
||||||
input_directories: Optional[List[Path]]
|
input_directories: Optional[List[Path]]
|
||||||
input_filter: Optional[List[str]]
|
input_filter: Optional[List[str]]
|
||||||
|
@ -77,6 +81,7 @@ class ContentConfig(ConfigBase):
|
||||||
pdf: Optional[TextContentConfig]
|
pdf: Optional[TextContentConfig]
|
||||||
github: Optional[GithubContentConfig]
|
github: Optional[GithubContentConfig]
|
||||||
plugins: Optional[Dict[str, TextContentConfig]]
|
plugins: Optional[Dict[str, TextContentConfig]]
|
||||||
|
notion: Optional[NotionContentConfig]
|
||||||
|
|
||||||
|
|
||||||
class TextSearchConfig(ConfigBase):
|
class TextSearchConfig(ConfigBase):
|
||||||
|
@ -148,4 +153,9 @@ class Entry:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, dictionary: dict):
|
def from_dict(cls, dictionary: dict):
|
||||||
return cls(raw=dictionary["raw"], compiled=dictionary["compiled"], file=dictionary.get("file", None))
|
return cls(
|
||||||
|
raw=dictionary["raw"],
|
||||||
|
compiled=dictionary["compiled"],
|
||||||
|
file=dictionary.get("file", None),
|
||||||
|
heading=dictionary.get("heading", None),
|
||||||
|
)
|
||||||
|
|
Loading…
Reference in a new issue