Mirror of https://github.com/khoj-ai/khoj.git (synced 2024-11-23 23:48:56 +01:00)

Commit 6edc32f2f4: Accept current changes to include issues in rendering flow

19 changed files with 304 additions and 193 deletions
@@ -73,15 +73,15 @@ khoj = "khoj.main:run"
 [project.optional-dependencies]
 test = [
     "pytest >= 7.1.2",
+    "freezegun >= 1.2.0",
+    "factory-boy >= 3.2.1",
+    "trio >= 0.22.0",
 ]
 dev = [
     "khoj-assistant[test]",
     "mypy >= 1.0.1",
     "black >= 23.1.0",
     "pre-commit >= 3.0.4",
-    "freezegun >= 1.2.0",
-    "factory-boy==3.2.1",
-    "Faker==18.10.1",
 ]

 [tool.hatch.version]
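The test extra now carries the tooling the dev extra previously pinned: freezegun for clock control, factory-boy for test fixtures (Faker comes in transitively as its dependency, so the explicit pin can go), and trio for async test runs. A minimal sketch of what these extras enable in a test module (hypothetical factory and test, not khoj code):

    from datetime import datetime

    import factory
    from freezegun import freeze_time

    class UserFactory(factory.Factory):
        # Build plain dicts; a real project would point Meta.model at an ORM class
        class Meta:
            model = dict

        name = factory.Sequence(lambda n: f"user{n}")

    @freeze_time("2023-01-01")
    def test_frozen_clock():
        assert datetime.now().year == 2023
        assert UserFactory()["name"].startswith("user")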
@@ -598,9 +598,19 @@ CONFIG is json obtained from Khoj config API."
   "Convert JSON-RESPONSE, QUERY from API to text entries."
   (thread-last json-response
                ;; extract and render entries from API response
-               (mapcar (lambda (args) (format "%s\n\n" (cdr (assoc 'entry args)))))
+               (mapcar (lambda (json-response-item)
+                         (thread-last
+                           ;; Extract pdf entry from each item in json response
+                           (cdr (assoc 'entry json-response-item))
+                           (format "%s\n\n")
+                           ;; Standardize results to 2nd level heading for consistent rendering
+                           (replace-regexp-in-string "^\*+" "")
+                           ;; Standardize results to 2nd level heading for consistent rendering
+                           (replace-regexp-in-string "^\#+" "")
+                           ;; Format entries as org entry string
+                           (format "** %s"))))
                ;; Set query as heading in rendered results buffer
-               (format "# Query: %s\n\n%s\n" query)
+               (format "* %s\n%s\n" query)
                ;; remove leading (, ) or SPC from extracted entries string
                (replace-regexp-in-string "^[\(\) ]" "")
                ;; remove trailing (, ) or SPC from extracted entries string
@@ -674,7 +684,8 @@ Render results in BUFFER-NAME using QUERY, CONTENT-TYPE."
            ((equal content-type "ledger") (khoj--extract-entries-as-ledger json-response query))
            ((equal content-type "image") (khoj--extract-entries-as-images json-response query))
            (t (khoj--extract-entries json-response query))))
-    (cond ((or (equal content-type "pdf")
+    (cond ((or (equal content-type "all")
+               (equal content-type "pdf")
                (equal content-type "org"))
            (progn (visual-line-mode)
                   (org-mode)
@@ -1003,7 +1014,7 @@ Paragraph only starts at first text after blank line."
   ;; set content type to: last used > based on current buffer > default type
   :init-value (lambda (obj) (oset obj value (format "--content-type=%s" (or khoj--content-type (khoj--buffer-name-to-content-type (buffer-name))))))
   ;; dynamically set choices to content types enabled on khoj backend
-  :choices (or (ignore-errors (mapcar #'symbol-name (khoj--get-enabled-content-types))) '("org" "markdown" "pdf" "ledger" "music" "image")))
+  :choices (or (ignore-errors (mapcar #'symbol-name (khoj--get-enabled-content-types))) '("all" "org" "markdown" "pdf" "ledger" "music" "image")))

 (transient-define-suffix khoj--search-command (&optional args)
   (interactive (list (transient-args transient-current-command)))
@@ -3,6 +3,7 @@ import sys
 import logging
 import json
 from enum import Enum
+from typing import Optional
 import requests

 # External Packages
@@ -36,7 +37,7 @@ def configure_server(args, required=False):
         logger.error(f"Exiting as Khoj is not configured.\nConfigure it via GUI or by editing {state.config_file}.")
         sys.exit(1)
     else:
-        logger.warn(
+        logger.warning(
             f"Khoj is not configured.\nConfigure it via khoj GUI, plugins or by editing {state.config_file}."
         )
         return
@@ -78,16 +79,20 @@ def configure_search_types(config: FullConfig):
     core_search_types = {e.name: e.value for e in SearchType}
     # Extract configured plugin search types
     plugin_search_types = {}
-    if config.content_type.plugins:
+    if config.content_type and config.content_type.plugins:
         plugin_search_types = {plugin_type: plugin_type for plugin_type in config.content_type.plugins.keys()}

     # Dynamically generate search type enum by merging core search types with configured plugin search types
     return Enum("SearchType", merge_dicts(core_search_types, plugin_search_types))


-def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, t: state.SearchType = None):
+def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, t: Optional[state.SearchType] = None):
+    if config is None or config.content_type is None or config.search_type is None:
+        logger.warning("🚨 No Content or Search type is configured.")
+        return
+
     # Initialize Org Notes Search
-    if (t == state.SearchType.Org or t == None) and config.content_type.org:
+    if (t == state.SearchType.Org or t == None) and config.content_type.org and config.search_type.asymmetric:
         logger.info("🦄 Setting up search for orgmode notes")
         # Extract Entries, Generate Notes Embeddings
         model.org_search = text_search.setup(
@@ -99,7 +104,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
     )

     # Initialize Org Music Search
-    if (t == state.SearchType.Music or t == None) and config.content_type.music:
+    if (t == state.SearchType.Music or t == None) and config.content_type.music and config.search_type.asymmetric:
         logger.info("🎺 Setting up search for org-music")
         # Extract Entries, Generate Music Embeddings
         model.music_search = text_search.setup(
@@ -111,7 +116,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
     )

     # Initialize Markdown Search
-    if (t == state.SearchType.Markdown or t == None) and config.content_type.markdown:
+    if (t == state.SearchType.Markdown or t == None) and config.content_type.markdown and config.search_type.asymmetric:
         logger.info("💎 Setting up search for markdown notes")
         # Extract Entries, Generate Markdown Embeddings
         model.markdown_search = text_search.setup(
@@ -123,7 +128,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
     )

     # Initialize Ledger Search
-    if (t == state.SearchType.Ledger or t == None) and config.content_type.ledger:
+    if (t == state.SearchType.Ledger or t == None) and config.content_type.ledger and config.search_type.symmetric:
         logger.info("💸 Setting up search for ledger")
         # Extract Entries, Generate Ledger Embeddings
         model.ledger_search = text_search.setup(
@@ -135,7 +140,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
     )

     # Initialize PDF Search
-    if (t == state.SearchType.Pdf or t == None) and config.content_type.pdf:
+    if (t == state.SearchType.Pdf or t == None) and config.content_type.pdf and config.search_type.asymmetric:
         logger.info("🖨️ Setting up search for pdf")
         # Extract Entries, Generate PDF Embeddings
         model.pdf_search = text_search.setup(
@@ -147,14 +152,14 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
     )

     # Initialize Image Search
-    if (t == state.SearchType.Image or t == None) and config.content_type.image:
+    if (t == state.SearchType.Image or t == None) and config.content_type.image and config.search_type.image:
         logger.info("🌄 Setting up search for images")
         # Extract Entries, Generate Image Embeddings
         model.image_search = image_search.setup(
             config.content_type.image, search_config=config.search_type.image, regenerate=regenerate
         )

-    if (t == state.SearchType.Github or t == None) and config.content_type.github:
+    if (t == state.SearchType.Github or t == None) and config.content_type.github and config.search_type.asymmetric:
         logger.info("🐙 Setting up search for github")
         # Extract Entries, Generate Github Embeddings
         model.github_search = text_search.setup(
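The None-guards added above matter because a partially filled config would otherwise raise AttributeError deep inside setup (e.g. on config.content_type.org when content_type is None). A minimal sketch of that early-return pattern (hypothetical, simplified config objects; not part of this diff):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class Config:
        content_type: Optional[dict] = None
        search_type: Optional[dict] = None

    def configure(config: Optional[Config]) -> Optional[str]:
        # Bail out early instead of raising AttributeError on config.content_type["org"]
        if config is None or config.content_type is None or config.search_type is None:
            return None
        return "configured"

    assert configure(None) is None        # no config at all
    assert configure(Config()) is None    # config without content/search types
    assert configure(Config({}, {})) == "configured"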
@@ -14,11 +14,13 @@
     <script>
         function render_image(item) {
             return `
+            <div class="results-image">
             <a href="${item.entry}" class="image-link">
                 <img id=${item.score} src="${item.entry}?${Math.random()}"
                     title="Effective Score: ${item.score}, Meta: ${item.additional.metadata_score}, Image: ${item.additional.image_score}"
                     class="image">
-            </a>`
+            </a>
+            </div>`;
         }

         function render_org(query, data, classPrefix="") {
@@ -28,33 +30,33 @@
             var orgParser = new Org.Parser();
             var orgDocument = orgParser.parse(orgCode);
             var orgHTMLDocument = orgDocument.convert(Org.ConverterHTML, { htmlClassPrefix: classPrefix });
-            return orgHTMLDocument.toString();
+            return `<div class="results-org">` + orgHTMLDocument.toString() + `</div>`;
         }

         function render_markdown(query, data) {
             var md = window.markdownit();
-            return md.render(data.map(function (item) {
+            return data.map(function (item) {
                 if (item.additional.file.startsWith("http")) {
                     lines = item.entry.split("\n");
-                    return `${lines[0]}\t[*](${item.additional.file})\n${lines.slice(1).join("\n")}`;
+                    return md.render(`${lines[0]}\t[*](${item.additional.file})\n${lines.slice(1).join("\n")}`);
                 }
-                return `${item.entry}`;
+                return `<div class="results-markdown">` + md.render(`${item.entry}`) + `</div>`;
-            }).join("\n"));
+            }).join("\n");
         }

         function render_ledger(query, data) {
-            return `<div id="results-ledger">` + data.map(function (item) {
+            return data.map(function (item) {
-                return `<p>${item.entry}</p>`
+                return `<div class="results-ledger">` + `<p>${item.entry}</p>` + `</div>`;
-            }).join("\n") + `</div>`;
+            }).join("\n");
         }

         function render_pdf(query, data) {
-            return `<div id="results-pdf">` + data.map(function (item) {
+            return data.map(function (item) {
                 let compiled_lines = item.additional.compiled.split("\n");
                 let filename = compiled_lines.shift();
                 let text_match = compiled_lines.join("\n")
-                return `<h2>${filename}</h2>\n<p>${text_match}</p>`
+                return `<div class="results-pdf">` + `<h2>${filename}</h2>\n<p>${text_match}</p>` + `</div>`;
-            }).join("\n") + `</div>`;
+            }).join("\n");
         }

         function render_mutliple(query, data, type) {
@@ -83,26 +85,26 @@
             return html;
         }

-        function render_json(data, query, type) {
+        function render_results(data, query, type) {
+            let results = "";
             if (type === "markdown") {
-                return render_markdown(query, data);
+                results = render_markdown(query, data);
             } else if (type === "org") {
-                return render_org(query, data);
+                results = render_org(query, data, "org-");
             } else if (type === "music") {
-                return render_org(query, data, "music-");
+                results = render_org(query, data, "music-");
             } else if (type === "image") {
-                return data.map(render_image).join('');
+                results = data.map(render_image).join('');
             } else if (type === "ledger") {
-                return render_ledger(query, data);
+                results = render_ledger(query, data);
             } else if (type === "pdf") {
-                return render_pdf(query, data);
+                results = render_pdf(query, data);
-            } else if (type == "github") {
+            } else if (type === "github" || type === "all") {
-                return render_mutliple(query, data, type);
+                results = render_mutliple(query, data, type);
             } else {
-                return `<div id="results-plugin">`
+                results = data.map((item) => `<div class="results-plugin">` + `<p>${item.entry}</p>` + `</div>`).join("\n")
-                    + data.map((item) => `<p>${item.entry}</p>`).join("\n")
-                    + `</div>`;
             }
+            return `<div id="results-${type}">${results}</div>`;
         }

         function search(rerank=false) {
@@ -120,20 +122,13 @@
             if (rerank)
                 setQueryFieldInUrl(query);

-            // Generate Backend API URL to execute Search
-            url = type === "image"
-                ? `/api/search?q=${encodeURIComponent(query)}&t=${type}&n=${results_count}&client=web`
-                : `/api/search?q=${encodeURIComponent(query)}&t=${type}&n=${results_count}&r=${rerank}&client=web`;
-
             // Execute Search and Render Results
+            url = createRequestUrl(query, type, results_count, rerank);
             fetch(url)
                 .then(response => response.json())
                 .then(data => {
                     console.log(data);
-                    document.getElementById("results").innerHTML =
+                    document.getElementById("results").innerHTML = render_results(data, query, type);
-                        `<div id=results-${type}>`
-                        + render_json(data, query, type)
-                        + `</div>`;
                 });
         }

@@ -144,7 +139,7 @@
                 .then(data => {
                     console.log(data);
                     document.getElementById("results").innerHTML =
-                        render_json(data);
+                        render_results(data);
                 });
         }

@@ -180,6 +175,18 @@
             });
         }

+        function createRequestUrl(query, type, results_count, rerank) {
+            // Generate Backend API URL to execute Search
+            let url = `/api/search?q=${encodeURIComponent(query)}&n=${results_count}&client=web`;
+            // If type is not 'all', append type to URL
+            if (type !== 'all')
+                url += `&t=${type}`;
+            // Rerank is only supported by text types
+            if (type !== "image")
+                url += `&r=${rerank}`;
+            return url;
+        }
+
         function setTypeFieldInUrl(type) {
             var url = new URL(window.location.href);
             url.searchParams.set("t", type.value);
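The new createRequestUrl helper omits the t parameter when the type is 'all' (so the backend falls through to its SearchType.All default) and only appends r for text types. A hedged sketch of the URLs it produces and of exercising the endpoint from Python (assumes a khoj server on a local port; adjust the address to your setup):

    import requests

    # URLs as createRequestUrl would build them (illustrative)
    all_types = "http://localhost:8000/api/search?q=notes&n=5&client=web&r=false"
    org_only = "http://localhost:8000/api/search?q=notes&n=5&client=web&t=org&r=false"
    images = "http://localhost:8000/api/search?q=sunset&n=5&client=web&t=image"  # no rerank param for images

    response = requests.get(org_only, timeout=10)
    print(response.json())  # list of SearchResponse objects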
@@ -309,7 +316,7 @@
             margin: 0px;
             line-height: 20px;
         }
-        #results-image {
+        .results-image {
             display: grid;
             grid-template-columns: repeat(3, 1fr);
         }
@@ -324,27 +331,28 @@
         #json {
             white-space: pre-wrap;
         }
-        #results-pdf,
+        .results-pdf,
-        #results-plugin,
+        .results-plugin,
-        #results-ledger {
+        .results-ledger {
             text-align: left;
             white-space: pre-line;
         }
-        #results-markdown, #results-github {
+        .results-markdown,
+        .results-github {
             text-align: left;
         }
-        #results-music,
+        .results-music,
-        #results-org {
+        .results-org {
             text-align: left;
             white-space: pre-line;
         }
-        #results-music h3,
+        .results-music h3,
-        #results-org h3 {
+        .results-org h3 {
             margin: 20px 0 0 0;
             font-size: larger;
         }
         span.music-task-status,
-        span.task-status {
+        span.org-task-status {
             color: white;
             padding: 3.5px 3.5px 0;
             margin-right: 5px;
@@ -353,15 +361,15 @@
             font-size: medium;
         }
         span.music-task-status.todo,
-        span.task-status.todo {
+        span.org-task-status.todo {
             background-color: #3b82f6
         }
         span.music-task-status.done,
-        span.task-status.done {
+        span.org-task-status.done {
             background-color: #22c55e;
         }
         span.music-task-tag,
-        span.task-tag {
+        span.org-task-tag {
             color: white;
             padding: 3.5px 3.5px 0;
             margin-right: 5px;
@@ -113,7 +113,7 @@ def extract_questions(text, model="text-davinci-003", conversation_log={}, api_k
             .replace("', '", '", "')
         )
     except json.decoder.JSONDecodeError:
-        logger.warn(f"GPT returned invalid JSON. Falling back to using user message as search query.\n{response}")
+        logger.warning(f"GPT returned invalid JSON. Falling back to using user message as search query.\n{response}")
         questions = [text]
     logger.debug(f"Extracted Questions by GPT: {questions}")
     return questions
@@ -92,7 +92,7 @@ class MarkdownToJsonl(TextToJsonl):
         }

         if any(files_with_non_markdown_extensions):
-            logger.warn(
+            logger.warning(
                 f"[Warning] There maybe non markdown-mode files in the input set: {files_with_non_markdown_extensions}"
             )

@@ -88,7 +88,7 @@ class OrgToJsonl(TextToJsonl):

         files_with_non_org_extensions = {org_file for org_file in all_org_files if not org_file.endswith(".org")}
         if any(files_with_non_org_extensions):
-            logger.warn(f"There maybe non org-mode files in the input set: {files_with_non_org_extensions}")
+            logger.warning(f"There maybe non org-mode files in the input set: {files_with_non_org_extensions}")

         logger.debug(f"Processing files: {all_org_files}")

@@ -83,7 +83,9 @@ class PdfToJsonl(TextToJsonl):
         files_with_non_pdf_extensions = {pdf_file for pdf_file in all_pdf_files if not pdf_file.endswith(".pdf")}

         if any(files_with_non_pdf_extensions):
-            logger.warn(f"[Warning] There maybe non pdf-mode files in the input set: {files_with_non_pdf_extensions}")
+            logger.warning(
+                f"[Warning] There maybe non pdf-mode files in the input set: {files_with_non_pdf_extensions}"
+            )

         logger.debug(f"Processing files: {all_pdf_files}")

@@ -1,5 +1,7 @@
 # Standard Packages
+import concurrent.futures
 import math
+import time
 import yaml
 import logging
 from datetime import datetime
@@ -8,12 +10,17 @@ from typing import List, Optional, Union
 # External Packages
 from fastapi import APIRouter
 from fastapi import HTTPException
+from sentence_transformers import util

 # Internal Packages
 from khoj.configure import configure_processor, configure_search
 from khoj.processor.conversation.gpt import converse, extract_questions
 from khoj.processor.conversation.utils import message_to_log, message_to_prompt
 from khoj.search_type import image_search, text_search
+from khoj.search_filter.date_filter import DateFilter
+from khoj.search_filter.file_filter import FileFilter
+from khoj.search_filter.word_filter import WordFilter
+from khoj.utils.config import TextSearchModel
 from khoj.utils.helpers import log_telemetry, timer
 from khoj.utils.rawconfig import (
     ContentConfig,
@@ -58,6 +65,7 @@ def get_config_types():
             and getattr(state.model, f"{search_type.value}_search") is not None
         )
         or ("plugins" in configured_content_types and search_type.name in configured_content_types["plugins"])
+        or search_type == SearchType.All
     ]

@@ -125,24 +133,31 @@ async def set_processor_conversation_config_data(updated_config: ConversationPro


 @api.get("/search", response_model=List[SearchResponse])
-def search(
+async def search(
     q: str,
     n: Optional[int] = 5,
-    t: Optional[SearchType] = None,
+    t: Optional[SearchType] = SearchType.All,
     r: Optional[bool] = False,
     score_threshold: Optional[Union[float, None]] = None,
     dedupe: Optional[bool] = True,
     client: Optional[str] = None,
 ):
+    start_time = time.time()
+
+    # Run validation checks
     results: List[SearchResponse] = []
     if q is None or q == "":
-        logger.warn(f"No query param (q) passed in API call to initiate search")
+        logger.warning(f"No query param (q) passed in API call to initiate search")
+        return results
+    if not state.model or not any(state.model.__dict__.values()):
+        logger.warning(f"No search models loaded. Configure a search model before initiating search")
         return results

     # initialize variables
     user_query = q.strip()
     results_count = n
     score_threshold = score_threshold if score_threshold is not None else -math.inf
+    search_futures: List[concurrent.futures.Future] = []

     # return cached results, if available
     query_cache_key = f"{user_query}-{n}-{t}-{r}-{score_threshold}-{dedupe}"
@@ -150,105 +165,146 @@ def search(
         logger.debug(f"Return response from query cache")
         return state.query_cache[query_cache_key]

-    if (t == SearchType.Org or t == None) and state.model.org_search:
-        # query org-mode notes
-        with timer("Query took", logger):
-            hits, entries = text_search.query(
-                user_query, state.model.org_search, rank_results=r, score_threshold=score_threshold, dedupe=dedupe
-            )
-
-        # collate and return results
-        with timer("Collating results took", logger):
-            results = text_search.collate_results(hits, entries, results_count)
-
-    elif (t == SearchType.Markdown or t == None) and state.model.markdown_search:
-        # query markdown files
-        with timer("Query took", logger):
-            hits, entries = text_search.query(
-                user_query, state.model.markdown_search, rank_results=r, score_threshold=score_threshold, dedupe=dedupe
-            )
-
-        # collate and return results
-        with timer("Collating results took", logger):
-            results = text_search.collate_results(hits, entries, results_count)
-
-    elif (t == SearchType.Pdf or t == None) and state.model.pdf_search:
-        # query pdf files
-        with timer("Query took", logger):
-            hits, entries = text_search.query(
-                user_query, state.model.pdf_search, rank_results=r, score_threshold=score_threshold, dedupe=dedupe
-            )
-
-        # collate and return results
-        with timer("Collating results took", logger):
-            results = text_search.collate_results(hits, entries, results_count)
-
-    elif (t == SearchType.Github or t == None) and state.model.github_search:
-        # query github embeddings
-        with timer("Query took", logger):
-            hits, entries = text_search.query(
-                user_query, state.model.github_search, rank_results=r, score_threshold=score_threshold, dedupe=dedupe
-            )
-
-        # collate and return results
-        with timer("Collating results took", logger):
-            results = text_search.collate_results(hits, entries, results_count)
-
-    elif (t == SearchType.Ledger or t == None) and state.model.ledger_search:
-        # query transactions
-        with timer("Query took", logger):
-            hits, entries = text_search.query(
-                user_query, state.model.ledger_search, rank_results=r, score_threshold=score_threshold, dedupe=dedupe
-            )
-
-        # collate and return results
-        with timer("Collating results took", logger):
-            results = text_search.collate_results(hits, entries, results_count)
-
-    elif (t == SearchType.Music or t == None) and state.model.music_search:
-        # query music library
-        with timer("Query took", logger):
-            hits, entries = text_search.query(
-                user_query, state.model.music_search, rank_results=r, score_threshold=score_threshold, dedupe=dedupe
-            )
-
-        # collate and return results
-        with timer("Collating results took", logger):
-            results = text_search.collate_results(hits, entries, results_count)
-
-    elif (t == SearchType.Image or t == None) and state.model.image_search:
-        # query images
-        with timer("Query took", logger):
-            hits = image_search.query(
-                user_query, results_count, state.model.image_search, score_threshold=score_threshold
-            )
-        output_directory = constants.web_directory / "images"
-
-        # collate and return results
-        with timer("Collating results took", logger):
-            results = image_search.collate_results(
-                hits,
-                image_names=state.model.image_search.image_names,
-                output_directory=output_directory,
-                image_files_url="/static/images",
-                count=results_count,
-            )
-
-    elif (t in SearchType or t == None) and state.model.plugin_search:
-        # query specified plugin type
-        with timer("Query took", logger):
-            hits, entries = text_search.query(
-                user_query,
-                # Get plugin search model for specified search type, or the first one if none specified
-                state.model.plugin_search.get(t.value) or next(iter(state.model.plugin_search.values())),
-                rank_results=r,
-                score_threshold=score_threshold,
-                dedupe=dedupe,
-            )
-
-        # collate and return results
-        with timer("Collating results took", logger):
-            results = text_search.collate_results(hits, entries, results_count)
+    # Encode query with filter terms removed
+    defiltered_query = user_query
+    for filter in [DateFilter(), WordFilter(), FileFilter()]:
+        defiltered_query = filter.defilter(user_query)
+
+    encoded_asymmetric_query = None
+    if t == SearchType.All or (t != SearchType.Ledger and t != SearchType.Image):
+        text_search_models: List[TextSearchModel] = [
+            model
+            for model_name, model in state.model.__dict__.items()
+            if isinstance(model, TextSearchModel) and model_name != "ledger_search"
+        ]
+        if text_search_models:
+            with timer("Encoding query took", logger=logger):
+                encoded_asymmetric_query = util.normalize_embeddings(
+                    text_search_models[0].bi_encoder.encode(
+                        [defiltered_query],
+                        convert_to_tensor=True,
+                        device=state.device,
+                    )
+                )
+
+    with concurrent.futures.ThreadPoolExecutor() as executor:
+        if (t == SearchType.Org or t == SearchType.All) and state.model.org_search:
+            # query org-mode notes
+            search_futures += [
+                executor.submit(
+                    text_search.query,
+                    user_query,
+                    state.model.org_search,
+                    question_embedding=encoded_asymmetric_query,
+                    rank_results=r or False,
+                    score_threshold=score_threshold,
+                    dedupe=dedupe or True,
+                )
+            ]
+
+        if (t == SearchType.Markdown or t == SearchType.All) and state.model.markdown_search:
+            # query markdown notes
+            search_futures += [
+                executor.submit(
+                    text_search.query,
+                    user_query,
+                    state.model.markdown_search,
+                    question_embedding=encoded_asymmetric_query,
+                    rank_results=r or False,
+                    score_threshold=score_threshold,
+                    dedupe=dedupe or True,
+                )
+            ]
+
+        if (t == SearchType.Pdf or t == SearchType.All) and state.model.pdf_search:
+            # query pdf files
+            search_futures += [
+                executor.submit(
+                    text_search.query,
+                    user_query,
+                    state.model.pdf_search,
+                    question_embedding=encoded_asymmetric_query,
+                    rank_results=r or False,
+                    score_threshold=score_threshold,
+                    dedupe=dedupe or True,
+                )
+            ]
+
+        if (t == SearchType.Ledger) and state.model.ledger_search:
+            # query transactions
+            search_futures += [
+                executor.submit(
+                    text_search.query,
+                    user_query,
+                    state.model.ledger_search,
+                    rank_results=r or False,
+                    score_threshold=score_threshold,
+                    dedupe=dedupe or True,
+                )
+            ]
+
+        if (t == SearchType.Music or t == SearchType.All) and state.model.music_search:
+            # query music library
+            search_futures += [
+                executor.submit(
+                    text_search.query,
+                    user_query,
+                    state.model.music_search,
+                    question_embedding=encoded_asymmetric_query,
+                    rank_results=r or False,
+                    score_threshold=score_threshold,
+                    dedupe=dedupe or True,
+                )
+            ]
+
+        if (t == SearchType.Image) and state.model.image_search:
+            # query images
+            search_futures += [
+                executor.submit(
+                    image_search.query,
+                    user_query,
+                    results_count,
+                    state.model.image_search,
+                    score_threshold=score_threshold,
+                )
+            ]
+
+        if (t == SearchType.All or t in SearchType) and state.model.plugin_search:
+            # query specified plugin type
+            search_futures += [
+                executor.submit(
+                    text_search.query,
+                    user_query,
+                    # Get plugin search model for specified search type, or the first one if none specified
+                    state.model.plugin_search.get(t.value) or next(iter(state.model.plugin_search.values())),
+                    question_embedding=encoded_asymmetric_query,
+                    rank_results=r or False,
+                    score_threshold=score_threshold,
+                    dedupe=dedupe or True,
+                )
+            ]
+
+        # Query across each requested content types in parallel
+        with timer("Query took", logger):
+            for search_future in concurrent.futures.as_completed(search_futures):
+                if t == SearchType.Image:
+                    hits = await search_future.result()
+                    output_directory = constants.web_directory / "images"
+                    # Collate results
+                    results += image_search.collate_results(
+                        hits,
+                        image_names=state.model.image_search.image_names,
+                        output_directory=output_directory,
+                        image_files_url="/static/images",
+                        count=results_count or 5,
+                    )
+                else:
+                    hits, entries = await search_future.result()
+                    # Collate results
+                    results += text_search.collate_results(hits, entries, results_count or 5)
+
+            # Sort results across all content types
+            results.sort(key=lambda x: float(x.score), reverse=True)

     # Cache results
     state.query_cache[query_cache_key] = results
@@ -260,6 +316,9 @@ def search(
     ]
     state.previous_query = user_query

+    end_time = time.time()
+    logger.debug(f"🔍 Search took: {end_time - start_time:.3f} seconds")
+
     return results

@@ -267,7 +326,7 @@ def search(
 def update(t: Optional[SearchType] = None, force: Optional[bool] = False, client: Optional[str] = None):
     try:
         state.search_index_lock.acquire()
-        state.model = configure_search(state.model, state.config, regenerate=force, t=t)
+        state.model = configure_search(state.model, state.config, regenerate=force or False, t=t)
         state.search_index_lock.release()
     except ValueError as e:
         logger.error(e)
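The rewritten endpoint fans one query out across all configured content types and merges hits as each search finishes. A minimal, self-contained sketch of that fan-out pattern (hypothetical stand-in search functions, not khoj's models; the endpoint's submitted callables are async, so it additionally awaits each future's result):

    import concurrent.futures
    import time

    def search_org(q):
        # Stand-in for text_search.query over org notes
        time.sleep(0.2)
        return [("org note", 0.9)]

    def search_markdown(q):
        # Stand-in for text_search.query over markdown notes
        time.sleep(0.1)
        return [("md note", 0.7)]

    results = []
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = [executor.submit(fn, "my query") for fn in (search_org, search_markdown)]
        # Collect whichever search finishes first, then the rest
        for future in concurrent.futures.as_completed(futures):
            results += future.result()

    # Sort merged hits across content types by score, like the endpoint does
    results.sort(key=lambda hit: hit[1], reverse=True)
    print(results)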
@@ -18,3 +18,7 @@ class BaseFilter(ABC):
     @abstractmethod
     def apply(self, query: str, entries: List[Entry]) -> Tuple[str, Set[int]]:
         ...
+
+    @abstractmethod
+    def defilter(self, query: str) -> str:
+        ...
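With defilter now part of the abstract contract, every concrete filter must strip its own syntax from the query so the remaining natural-language text can be embedded. A small sketch of a conforming filter (hypothetical hashtag syntax, not one of khoj's filters):

    import re
    from abc import ABC, abstractmethod

    class BaseFilter(ABC):
        @abstractmethod
        def apply(self, query: str, entries: list): ...

        @abstractmethod
        def defilter(self, query: str) -> str: ...

    class HashtagFilter(BaseFilter):
        """Hypothetical filter keeping entries that contain #tags from the query."""

        tag_regex = r"#\w+"

        def apply(self, query, entries):
            tags = set(re.findall(self.tag_regex, query))
            kept = {i for i, e in enumerate(entries) if tags & set(re.findall(self.tag_regex, e))}
            return self.defilter(query), kept

        def defilter(self, query: str) -> str:
            # Strip filter terms so only the natural-language query is embedded
            return re.sub(self.tag_regex, "", query).strip()

    f = HashtagFilter()
    print(f.apply("plans #travel", ["book flights #travel", "tax notes"]))  # ('plans', {0})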
@@ -49,6 +49,12 @@ class DateFilter(BaseFilter):
         "Check if query contains date filters"
         return self.extract_date_range(raw_query) is not None

+    def defilter(self, query):
+        # remove date range filter from query
+        query = re.sub(rf"\s+{self.date_regex}", " ", query)
+        query = re.sub(r"\s{2,}", " ", query).strip()  # remove multiple spaces
+        return query
+
     def apply(self, query, entries):
         "Find entries containing any dates that fall within date range specified in query"
         # extract date range specified in date filter of query
@@ -59,9 +65,7 @@ class DateFilter(BaseFilter):
         if query_daterange is None:
             return query, set(range(len(entries)))

-        # remove date range filter from query
-        query = re.sub(rf"\s+{self.date_regex}", " ", query)
-        query = re.sub(r"\s{2,}", " ", query).strip()  # remove multiple spaces
+        query = self.defilter(query)

         # return results from cache if exists
         cache_key = tuple(query_daterange)
@@ -28,6 +28,9 @@ class FileFilter(BaseFilter):
     def can_filter(self, raw_query):
         return re.search(self.file_filter_regex, raw_query) is not None

+    def defilter(self, query: str) -> str:
+        return re.sub(self.file_filter_regex, "", query).strip()
+
     def apply(self, query, entries):
         # Extract file filters from raw query
         with timer("Extract files_to_search from query", logger):
@@ -44,8 +47,10 @@ class FileFilter(BaseFilter):
             else:
                 files_to_search += [file]

+        # Remove filter terms from original query
+        query = self.defilter(query)
+
         # Return item from cache if exists
-        query = re.sub(self.file_filter_regex, "", query).strip()
         cache_key = tuple(files_to_search)
         if cache_key in self.cache:
             logger.debug(f"Return file filter results from cache")
@@ -43,13 +43,16 @@ class WordFilter(BaseFilter):

         return len(required_words) != 0 or len(blocked_words) != 0

+    def defilter(self, query: str) -> str:
+        return re.sub(self.blocked_regex, "", re.sub(self.required_regex, "", query)).strip()
+
     def apply(self, query, entries):
         "Find entries containing required and not blocked words specified in query"
         # Separate natural query from required, blocked words filters
         with timer("Extract required, blocked filters from query", logger):
             required_words = set([word.lower() for word in re.findall(self.required_regex, query)])
             blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, query)])
-            query = re.sub(self.blocked_regex, "", re.sub(self.required_regex, "", query)).strip()
+            query = self.defilter(query)

         if len(required_words) == 0 and len(blocked_words) == 0:
             return query, set(range(len(entries)))
@@ -143,7 +143,7 @@ def extract_metadata(image_name):
     return image_processed_metadata


-def query(raw_query, count, model: ImageSearchModel, score_threshold: float = -math.inf):
+async def query(raw_query, count, model: ImageSearchModel, score_threshold: float = -math.inf):
     # Set query to image content if query is of form file:/path/to/file.png
     if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file():
         query_imagepath = resolve_absolute_path(pathlib.Path(raw_query[5:]), strict=True)
@@ -2,7 +2,7 @@
 import logging
 import math
 from pathlib import Path
-from typing import List, Tuple, Type
+from typing import List, Tuple, Type, Union

 # External Packages
 import torch
@@ -102,9 +102,10 @@ def compute_embeddings(
     return corpus_embeddings


-def query(
+async def query(
     raw_query: str,
     model: TextSearchModel,
+    question_embedding: Union[torch.Tensor, None] = None,
     rank_results: bool = False,
     score_threshold: float = -math.inf,
     dedupe: bool = True,
@@ -124,9 +125,10 @@ def query(
         return hits, entries

     # Encode the query using the bi-encoder
-    with timer("Query Encode Time", logger, state.device):
-        question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
-        question_embedding = util.normalize_embeddings(question_embedding)
+    if question_embedding is None:
+        with timer("Query Encode Time", logger, state.device):
+            question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
+            question_embedding = util.normalize_embeddings(question_embedding)

     # Find relevant entries for the query
     with timer("Search Time", logger, state.device):
@@ -179,7 +181,7 @@ def setup(
     previous_entries = (
         extract_entries(config.compressed_jsonl) if config.compressed_jsonl.exists() and not regenerate else None
     )
-    entries_with_indices = text_to_jsonl(config).process(previous_entries)
+    entries_with_indices = text_to_jsonl(config).process(previous_entries or [])

     # Extract Updated Entries
     entries = extract_entries(config.compressed_jsonl)
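Accepting a precomputed question_embedding lets the API encode the query once and reuse it across every text content type instead of re-encoding per search. A hedged sketch of that reuse with sentence-transformers (the model name is an assumption for illustration, not pinned by this diff):

    from sentence_transformers import SentenceTransformer, util

    bi_encoder = SentenceTransformer("multi-qa-MiniLM-L6-cos-v1")  # assumed bi-encoder

    # Encode and normalize the query once...
    question_embedding = bi_encoder.encode(["how to install khoj"], convert_to_tensor=True)
    question_embedding = util.normalize_embeddings(question_embedding)

    corpus = ["install khoj with pip", "bake sourdough bread"]
    corpus_embeddings = util.normalize_embeddings(bi_encoder.encode(corpus, convert_to_tensor=True))

    # ...then reuse the same query embedding against any number of corpora
    hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=1)
    print(hits)  # [[{'corpus_id': 0, 'score': ...}]]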
@@ -17,6 +17,7 @@ if TYPE_CHECKING:


 class SearchType(str, Enum):
+    All = "all"
     Org = "org"
     Ledger = "ledger"
     Music = "music"
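A quick illustration of how a str-valued Enum with an All member behaves, mirroring the checks in the API code (standalone sketch with a trimmed member list):

    from enum import Enum

    class SearchType(str, Enum):
        All = "all"
        Org = "org"
        Ledger = "ledger"

    t = SearchType("all")                 # FastAPI coerces the t query param this way
    print(t == SearchType.All)            # True
    print(SearchType.Org in SearchType)   # True; `t in SearchType` holds for any member
    print([e.value for e in SearchType])  # ['all', 'org', 'ledger']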
@@ -34,7 +34,7 @@ def test_search_with_invalid_content_type(client):

 # ----------------------------------------------------------------------------------------------------
 def test_search_with_valid_content_type(client):
-    for content_type in ["org", "markdown", "ledger", "image", "music", "pdf", "plugin1"]:
+    for content_type in ["all", "org", "markdown", "ledger", "image", "music", "pdf", "plugin1"]:
         # Act
         response = client.get(f"/api/search?q=random&t={content_type}")
         # Assert
@@ -84,7 +84,7 @@ def test_get_configured_types_via_api(client):

     # Assert
     assert response.status_code == 200
-    assert response.json() == ["org", "image", "plugin1"]
+    assert response.json() == ["all", "org", "image", "plugin1"]


 # ----------------------------------------------------------------------------------------------------
@@ -102,7 +102,7 @@ def test_get_configured_types_with_only_plugin_content_config(content_config):

     # Assert
     assert response.status_code == 200
-    assert response.json() == ["plugin1"]
+    assert response.json() == ["all", "plugin1"]


 # ----------------------------------------------------------------------------------------------------
@@ -137,7 +137,7 @@ def test_get_configured_types_with_no_content_config():

     # Assert
     assert response.status_code == 200
-    assert response.json() == []
+    assert response.json() == ["all"]


 # ----------------------------------------------------------------------------------------------------
@@ -3,6 +3,9 @@ import logging
 from pathlib import Path
 from PIL import Image

+# External Packages
+import pytest
+
 # Internal Packages
 from khoj.utils.state import model
 from khoj.utils.constants import web_directory
@@ -48,7 +51,8 @@ def test_image_metadata(content_config: ContentConfig):


 # ----------------------------------------------------------------------------------------------------
-def test_image_search(content_config: ContentConfig, search_config: SearchConfig):
+@pytest.mark.anyio
+async def test_image_search(content_config: ContentConfig, search_config: SearchConfig):
     # Arrange
     output_directory = resolve_absolute_path(web_directory)
     model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False)
@@ -60,7 +64,7 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig

     # Act
     for query, expected_image_name in query_expected_image_pairs:
-        hits = image_search.query(query, count=1, model=model.image_search)
+        hits = await image_search.query(query, count=1, model=model.image_search)

         results = image_search.collate_results(
             hits,
@@ -83,7 +87,8 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig


 # ----------------------------------------------------------------------------------------------------
-def test_image_search_query_truncated(content_config: ContentConfig, search_config: SearchConfig, caplog):
+@pytest.mark.anyio
+async def test_image_search_query_truncated(content_config: ContentConfig, search_config: SearchConfig, caplog):
     # Arrange
     model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False)
     max_words_supported = 10
@@ -93,7 +98,7 @@ def test_image_search_query_truncated(content_config: ContentConfig, search_conf
     # Act
     try:
         with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"):
-            image_search.query(query, count=1, model=model.image_search)
+            await image_search.query(query, count=1, model=model.image_search)
         # Assert
     except RuntimeError as e:
         if "The size of tensor a (102) must match the size of tensor b (77)" in str(e):
@@ -102,7 +107,8 @@ def test_image_search_query_truncated(content_config: ContentConfig, search_conf


 # ----------------------------------------------------------------------------------------------------
-def test_image_search_by_filepath(content_config: ContentConfig, search_config: SearchConfig, caplog):
+@pytest.mark.anyio
+async def test_image_search_by_filepath(content_config: ContentConfig, search_config: SearchConfig, caplog):
     # Arrange
     output_directory = resolve_absolute_path(web_directory)
     model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False)
@@ -113,7 +119,7 @@ def test_image_search_by_filepath(content_config: ContentConfig, search_config:

     # Act
     with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"):
-        hits = image_search.query(query, count=1, model=model.image_search)
+        hits = await image_search.query(query, count=1, model=model.image_search)

         results = image_search.collate_results(
             hits,
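The @pytest.mark.anyio markers rely on anyio's pytest plugin, which selects the async backend through an anyio_backend fixture. A hedged sketch of the conftest setup these tests assume (the repo's actual conftest may differ):

    # conftest.py — minimal setup for @pytest.mark.anyio tests
    import pytest

    @pytest.fixture
    def anyio_backend():
        # Run each async test on asyncio; "trio" also works with the new trio test extra installed
        return "asyncio"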
@@ -72,13 +72,14 @@ def test_text_content_index_only_updates_on_changes(content_config: ContentConfi


 # ----------------------------------------------------------------------------------------------------
-def test_asymmetric_search(content_config: ContentConfig, search_config: SearchConfig):
+@pytest.mark.anyio
+async def test_asymmetric_search(content_config: ContentConfig, search_config: SearchConfig):
     # Arrange
     model.notes_search = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True)
     query = "How to git install application?"

     # Act
-    hits, entries = text_search.query(query, model=model.notes_search, rank_results=True)
+    hits, entries = await text_search.query(query, model=model.notes_search, rank_results=True)

     results = text_search.collate_results(hits, entries, count=1)
