Accept current changes to include issues in rendering flow

This commit is contained in:
sabaimran 2023-06-29 12:25:29 -07:00
commit 6edc32f2f4
19 changed files with 304 additions and 193 deletions

View file

@ -73,15 +73,15 @@ khoj = "khoj.main:run"
[project.optional-dependencies] [project.optional-dependencies]
test = [ test = [
"pytest >= 7.1.2", "pytest >= 7.1.2",
"freezegun >= 1.2.0",
"factory-boy >= 3.2.1",
"trio >= 0.22.0",
] ]
dev = [ dev = [
"khoj-assistant[test]", "khoj-assistant[test]",
"mypy >= 1.0.1", "mypy >= 1.0.1",
"black >= 23.1.0", "black >= 23.1.0",
"pre-commit >= 3.0.4", "pre-commit >= 3.0.4",
"freezegun >= 1.2.0",
"factory-boy==3.2.1",
"Faker==18.10.1",
] ]
[tool.hatch.version] [tool.hatch.version]

View file

@ -598,9 +598,19 @@ CONFIG is json obtained from Khoj config API."
"Convert JSON-RESPONSE, QUERY from API to text entries." "Convert JSON-RESPONSE, QUERY from API to text entries."
(thread-last json-response (thread-last json-response
;; extract and render entries from API response ;; extract and render entries from API response
(mapcar (lambda (args) (format "%s\n\n" (cdr (assoc 'entry args))))) (mapcar (lambda (json-response-item)
(thread-last
;; Extract pdf entry from each item in json response
(cdr (assoc 'entry json-response-item))
(format "%s\n\n")
;; Standardize results to 2nd level heading for consistent rendering
(replace-regexp-in-string "^\*+" "")
;; Also strip leading markdown heading markers (#) so entries render consistently
(replace-regexp-in-string "^\#+" "")
;; Format entries as org entry string
(format "** %s"))))
;; Set query as heading in rendered results buffer ;; Set query as heading in rendered results buffer
(format "# Query: %s\n\n%s\n" query) (format "* %s\n%s\n" query)
;; remove leading (, ) or SPC from extracted entries string ;; remove leading (, ) or SPC from extracted entries string
(replace-regexp-in-string "^[\(\) ]" "") (replace-regexp-in-string "^[\(\) ]" "")
;; remove trailing (, ) or SPC from extracted entries string ;; remove trailing (, ) or SPC from extracted entries string
@ -674,7 +684,8 @@ Render results in BUFFER-NAME using QUERY, CONTENT-TYPE."
((equal content-type "ledger") (khoj--extract-entries-as-ledger json-response query)) ((equal content-type "ledger") (khoj--extract-entries-as-ledger json-response query))
((equal content-type "image") (khoj--extract-entries-as-images json-response query)) ((equal content-type "image") (khoj--extract-entries-as-images json-response query))
(t (khoj--extract-entries json-response query)))) (t (khoj--extract-entries json-response query))))
(cond ((or (equal content-type "pdf") (cond ((or (equal content-type "all")
(equal content-type "pdf")
(equal content-type "org")) (equal content-type "org"))
(progn (visual-line-mode) (progn (visual-line-mode)
(org-mode) (org-mode)
@ -1003,7 +1014,7 @@ Paragraph only starts at first text after blank line."
;; set content type to: last used > based on current buffer > default type ;; set content type to: last used > based on current buffer > default type
:init-value (lambda (obj) (oset obj value (format "--content-type=%s" (or khoj--content-type (khoj--buffer-name-to-content-type (buffer-name)))))) :init-value (lambda (obj) (oset obj value (format "--content-type=%s" (or khoj--content-type (khoj--buffer-name-to-content-type (buffer-name))))))
;; dynamically set choices to content types enabled on khoj backend ;; dynamically set choices to content types enabled on khoj backend
:choices (or (ignore-errors (mapcar #'symbol-name (khoj--get-enabled-content-types))) '("org" "markdown" "pdf" "ledger" "music" "image"))) :choices (or (ignore-errors (mapcar #'symbol-name (khoj--get-enabled-content-types))) '("all" "org" "markdown" "pdf" "ledger" "music" "image")))
(transient-define-suffix khoj--search-command (&optional args) (transient-define-suffix khoj--search-command (&optional args)
(interactive (list (transient-args transient-current-command))) (interactive (list (transient-args transient-current-command)))

View file

@ -3,6 +3,7 @@ import sys
import logging import logging
import json import json
from enum import Enum from enum import Enum
from typing import Optional
import requests import requests
# External Packages # External Packages
@ -36,7 +37,7 @@ def configure_server(args, required=False):
logger.error(f"Exiting as Khoj is not configured.\nConfigure it via GUI or by editing {state.config_file}.") logger.error(f"Exiting as Khoj is not configured.\nConfigure it via GUI or by editing {state.config_file}.")
sys.exit(1) sys.exit(1)
else: else:
logger.warn( logger.warning(
f"Khoj is not configured.\nConfigure it via khoj GUI, plugins or by editing {state.config_file}." f"Khoj is not configured.\nConfigure it via khoj GUI, plugins or by editing {state.config_file}."
) )
return return
@ -78,16 +79,20 @@ def configure_search_types(config: FullConfig):
core_search_types = {e.name: e.value for e in SearchType} core_search_types = {e.name: e.value for e in SearchType}
# Extract configured plugin search types # Extract configured plugin search types
plugin_search_types = {} plugin_search_types = {}
if config.content_type.plugins: if config.content_type and config.content_type.plugins:
plugin_search_types = {plugin_type: plugin_type for plugin_type in config.content_type.plugins.keys()} plugin_search_types = {plugin_type: plugin_type for plugin_type in config.content_type.plugins.keys()}
# Dynamically generate search type enum by merging core search types with configured plugin search types # Dynamically generate search type enum by merging core search types with configured plugin search types
return Enum("SearchType", merge_dicts(core_search_types, plugin_search_types)) return Enum("SearchType", merge_dicts(core_search_types, plugin_search_types))
def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, t: state.SearchType = None): def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, t: Optional[state.SearchType] = None):
if config is None or config.content_type is None or config.search_type is None:
logger.warning("🚨 No Content or Search type is configured.")
return
# Initialize Org Notes Search # Initialize Org Notes Search
if (t == state.SearchType.Org or t == None) and config.content_type.org: if (t == state.SearchType.Org or t == None) and config.content_type.org and config.search_type.asymmetric:
logger.info("🦄 Setting up search for orgmode notes") logger.info("🦄 Setting up search for orgmode notes")
# Extract Entries, Generate Notes Embeddings # Extract Entries, Generate Notes Embeddings
model.org_search = text_search.setup( model.org_search = text_search.setup(
@ -99,7 +104,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
) )
# Initialize Org Music Search # Initialize Org Music Search
if (t == state.SearchType.Music or t == None) and config.content_type.music: if (t == state.SearchType.Music or t == None) and config.content_type.music and config.search_type.asymmetric:
logger.info("🎺 Setting up search for org-music") logger.info("🎺 Setting up search for org-music")
# Extract Entries, Generate Music Embeddings # Extract Entries, Generate Music Embeddings
model.music_search = text_search.setup( model.music_search = text_search.setup(
@ -111,7 +116,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
) )
# Initialize Markdown Search # Initialize Markdown Search
if (t == state.SearchType.Markdown or t == None) and config.content_type.markdown: if (t == state.SearchType.Markdown or t == None) and config.content_type.markdown and config.search_type.asymmetric:
logger.info("💎 Setting up search for markdown notes") logger.info("💎 Setting up search for markdown notes")
# Extract Entries, Generate Markdown Embeddings # Extract Entries, Generate Markdown Embeddings
model.markdown_search = text_search.setup( model.markdown_search = text_search.setup(
@ -123,7 +128,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
) )
# Initialize Ledger Search # Initialize Ledger Search
if (t == state.SearchType.Ledger or t == None) and config.content_type.ledger: if (t == state.SearchType.Ledger or t == None) and config.content_type.ledger and config.search_type.symmetric:
logger.info("💸 Setting up search for ledger") logger.info("💸 Setting up search for ledger")
# Extract Entries, Generate Ledger Embeddings # Extract Entries, Generate Ledger Embeddings
model.ledger_search = text_search.setup( model.ledger_search = text_search.setup(
@ -135,7 +140,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
) )
# Initialize PDF Search # Initialize PDF Search
if (t == state.SearchType.Pdf or t == None) and config.content_type.pdf: if (t == state.SearchType.Pdf or t == None) and config.content_type.pdf and config.search_type.asymmetric:
logger.info("🖨️ Setting up search for pdf") logger.info("🖨️ Setting up search for pdf")
# Extract Entries, Generate PDF Embeddings # Extract Entries, Generate PDF Embeddings
model.pdf_search = text_search.setup( model.pdf_search = text_search.setup(
@ -147,14 +152,14 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
) )
# Initialize Image Search # Initialize Image Search
if (t == state.SearchType.Image or t == None) and config.content_type.image: if (t == state.SearchType.Image or t == None) and config.content_type.image and config.search_type.image:
logger.info("🌄 Setting up search for images") logger.info("🌄 Setting up search for images")
# Extract Entries, Generate Image Embeddings # Extract Entries, Generate Image Embeddings
model.image_search = image_search.setup( model.image_search = image_search.setup(
config.content_type.image, search_config=config.search_type.image, regenerate=regenerate config.content_type.image, search_config=config.search_type.image, regenerate=regenerate
) )
if (t == state.SearchType.Github or t == None) and config.content_type.github: if (t == state.SearchType.Github or t == None) and config.content_type.github and config.search_type.asymmetric:
logger.info("🐙 Setting up search for github") logger.info("🐙 Setting up search for github")
# Extract Entries, Generate Github Embeddings # Extract Entries, Generate Github Embeddings
model.github_search = text_search.setup( model.github_search = text_search.setup(

View file

@ -14,11 +14,13 @@
<script> <script>
function render_image(item) { function render_image(item) {
return ` return `
<div class="results-image">
<a href="${item.entry}" class="image-link"> <a href="${item.entry}" class="image-link">
<img id=${item.score} src="${item.entry}?${Math.random()}" <img id=${item.score} src="${item.entry}?${Math.random()}"
title="Effective Score: ${item.score}, Meta: ${item.additional.metadata_score}, Image: ${item.additional.image_score}" title="Effective Score: ${item.score}, Meta: ${item.additional.metadata_score}, Image: ${item.additional.image_score}"
class="image"> class="image">
</a>` </a>
</div>`;
} }
function render_org(query, data, classPrefix="") { function render_org(query, data, classPrefix="") {
@ -28,33 +30,33 @@
var orgParser = new Org.Parser(); var orgParser = new Org.Parser();
var orgDocument = orgParser.parse(orgCode); var orgDocument = orgParser.parse(orgCode);
var orgHTMLDocument = orgDocument.convert(Org.ConverterHTML, { htmlClassPrefix: classPrefix }); var orgHTMLDocument = orgDocument.convert(Org.ConverterHTML, { htmlClassPrefix: classPrefix });
return orgHTMLDocument.toString(); return `<div class="results-org">` + orgHTMLDocument.toString() + `</div>`;
} }
function render_markdown(query, data) { function render_markdown(query, data) {
var md = window.markdownit(); var md = window.markdownit();
return md.render(data.map(function (item) { return data.map(function (item) {
if (item.additional.file.startsWith("http")) { if (item.additional.file.startsWith("http")) {
lines = item.entry.split("\n"); lines = item.entry.split("\n");
return `${lines[0]}\t[*](${item.additional.file})\n${lines.slice(1).join("\n")}`; return md.render(`${lines[0]}\t[*](${item.additional.file})\n${lines.slice(1).join("\n")}`);
} }
return `${item.entry}`; return `<div class="results-markdown">` + md.render(`${item.entry}`) + `</div>`;
}).join("\n")); }).join("\n");
} }
function render_ledger(query, data) { function render_ledger(query, data) {
return `<div id="results-ledger">` + data.map(function (item) { return data.map(function (item) {
return `<p>${item.entry}</p>` return `<div class="results-ledger">` + `<p>${item.entry}</p>` + `</div>`;
}).join("\n") + `</div>`; }).join("\n");
} }
function render_pdf(query, data) { function render_pdf(query, data) {
return `<div id="results-pdf">` + data.map(function (item) { return data.map(function (item) {
let compiled_lines = item.additional.compiled.split("\n"); let compiled_lines = item.additional.compiled.split("\n");
let filename = compiled_lines.shift(); let filename = compiled_lines.shift();
let text_match = compiled_lines.join("\n") let text_match = compiled_lines.join("\n")
return `<h2>${filename}</h2>\n<p>${text_match}</p>` return `<div class="results-pdf">` + `<h2>${filename}</h2>\n<p>${text_match}</p>` + `</div>`;
}).join("\n") + `</div>`; }).join("\n");
} }
function render_mutliple(query, data, type) { function render_mutliple(query, data, type) {
@ -83,26 +85,26 @@
return html; return html;
} }
function render_json(data, query, type) { function render_results(data, query, type) {
let results = "";
if (type === "markdown") { if (type === "markdown") {
return render_markdown(query, data); results = render_markdown(query, data);
} else if (type === "org") { } else if (type === "org") {
return render_org(query, data); results = render_org(query, data, "org-");
} else if (type === "music") { } else if (type === "music") {
return render_org(query, data, "music-"); results = render_org(query, data, "music-");
} else if (type === "image") { } else if (type === "image") {
return data.map(render_image).join(''); results = data.map(render_image).join('');
} else if (type === "ledger") { } else if (type === "ledger") {
return render_ledger(query, data); results = render_ledger(query, data);
} else if (type === "pdf") { } else if (type === "pdf") {
return render_pdf(query, data); results = render_pdf(query, data);
} else if (type == "github") { } else if (type === "github" || type === "all") {
return render_mutliple(query, data, type); results = render_mutliple(query, data, type);
} else { } else {
return `<div id="results-plugin">` results = data.map((item) => `<div class="results-plugin">` + `<p>${item.entry}</p>` + `</div>`).join("\n")
+ data.map((item) => `<p>${item.entry}</p>`).join("\n")
+ `</div>`;
} }
return `<div id="results-${type}">${results}</div>`;
} }
function search(rerank=false) { function search(rerank=false) {
@ -120,20 +122,13 @@
if (rerank) if (rerank)
setQueryFieldInUrl(query); setQueryFieldInUrl(query);
// Generate Backend API URL to execute Search
url = type === "image"
? `/api/search?q=${encodeURIComponent(query)}&t=${type}&n=${results_count}&client=web`
: `/api/search?q=${encodeURIComponent(query)}&t=${type}&n=${results_count}&r=${rerank}&client=web`;
// Execute Search and Render Results // Execute Search and Render Results
url = createRequestUrl(query, type, results_count, rerank);
fetch(url) fetch(url)
.then(response => response.json()) .then(response => response.json())
.then(data => { .then(data => {
console.log(data); console.log(data);
document.getElementById("results").innerHTML = document.getElementById("results").innerHTML = render_results(data, query, type);
`<div id=results-${type}>`
+ render_json(data, query, type)
+ `</div>`;
}); });
} }
@ -144,7 +139,7 @@
.then(data => { .then(data => {
console.log(data); console.log(data);
document.getElementById("results").innerHTML = document.getElementById("results").innerHTML =
render_json(data); render_results(data);
}); });
} }
@ -180,6 +175,18 @@
}); });
} }
function createRequestUrl(query, type, results_count, rerank) {
    // Build the backend search API URL from its individual query-string parts.
    // The query text is URI-encoded; result count and client are always sent.
    const params = [
        `q=${encodeURIComponent(query)}`,
        `n=${results_count}`,
        `client=web`,
    ];

    // Searching across 'all' types is requested by omitting the type parameter
    if (type !== 'all') {
        params.push(`t=${type}`);
    }

    // Rerank is only supported by text types, so it is left out for image search
    if (type !== "image") {
        params.push(`r=${rerank}`);
    }

    return `/api/search?${params.join("&")}`;
}
function setTypeFieldInUrl(type) { function setTypeFieldInUrl(type) {
var url = new URL(window.location.href); var url = new URL(window.location.href);
url.searchParams.set("t", type.value); url.searchParams.set("t", type.value);
@ -309,7 +316,7 @@
margin: 0px; margin: 0px;
line-height: 20px; line-height: 20px;
} }
#results-image { .results-image {
display: grid; display: grid;
grid-template-columns: repeat(3, 1fr); grid-template-columns: repeat(3, 1fr);
} }
@ -324,27 +331,28 @@
#json { #json {
white-space: pre-wrap; white-space: pre-wrap;
} }
#results-pdf, .results-pdf,
#results-plugin, .results-plugin,
#results-ledger { .results-ledger {
text-align: left; text-align: left;
white-space: pre-line; white-space: pre-line;
} }
#results-markdown, #results-github { .results-markdown,
.results-github {
text-align: left; text-align: left;
} }
#results-music, .results-music,
#results-org { .results-org {
text-align: left; text-align: left;
white-space: pre-line; white-space: pre-line;
} }
#results-music h3, .results-music h3,
#results-org h3 { .results-org h3 {
margin: 20px 0 0 0; margin: 20px 0 0 0;
font-size: larger; font-size: larger;
} }
span.music-task-status, span.music-task-status,
span.task-status { span.org-task-status {
color: white; color: white;
padding: 3.5px 3.5px 0; padding: 3.5px 3.5px 0;
margin-right: 5px; margin-right: 5px;
@ -353,15 +361,15 @@
font-size: medium; font-size: medium;
} }
span.music-task-status.todo, span.music-task-status.todo,
span.task-status.todo { span.org-task-status.todo {
background-color: #3b82f6 background-color: #3b82f6
} }
span.music-task-status.done, span.music-task-status.done,
span.task-status.done { span.org-task-status.done {
background-color: #22c55e; background-color: #22c55e;
} }
span.music-task-tag, span.music-task-tag,
span.task-tag { span.org-task-tag {
color: white; color: white;
padding: 3.5px 3.5px 0; padding: 3.5px 3.5px 0;
margin-right: 5px; margin-right: 5px;

View file

@ -113,7 +113,7 @@ def extract_questions(text, model="text-davinci-003", conversation_log={}, api_k
.replace("', '", '", "') .replace("', '", '", "')
) )
except json.decoder.JSONDecodeError: except json.decoder.JSONDecodeError:
logger.warn(f"GPT returned invalid JSON. Falling back to using user message as search query.\n{response}") logger.warning(f"GPT returned invalid JSON. Falling back to using user message as search query.\n{response}")
questions = [text] questions = [text]
logger.debug(f"Extracted Questions by GPT: {questions}") logger.debug(f"Extracted Questions by GPT: {questions}")
return questions return questions

View file

@ -92,7 +92,7 @@ class MarkdownToJsonl(TextToJsonl):
} }
if any(files_with_non_markdown_extensions): if any(files_with_non_markdown_extensions):
logger.warn( logger.warning(
f"[Warning] There maybe non markdown-mode files in the input set: {files_with_non_markdown_extensions}" f"[Warning] There maybe non markdown-mode files in the input set: {files_with_non_markdown_extensions}"
) )

View file

@ -88,7 +88,7 @@ class OrgToJsonl(TextToJsonl):
files_with_non_org_extensions = {org_file for org_file in all_org_files if not org_file.endswith(".org")} files_with_non_org_extensions = {org_file for org_file in all_org_files if not org_file.endswith(".org")}
if any(files_with_non_org_extensions): if any(files_with_non_org_extensions):
logger.warn(f"There maybe non org-mode files in the input set: {files_with_non_org_extensions}") logger.warning(f"There maybe non org-mode files in the input set: {files_with_non_org_extensions}")
logger.debug(f"Processing files: {all_org_files}") logger.debug(f"Processing files: {all_org_files}")

View file

@ -83,7 +83,9 @@ class PdfToJsonl(TextToJsonl):
files_with_non_pdf_extensions = {pdf_file for pdf_file in all_pdf_files if not pdf_file.endswith(".pdf")} files_with_non_pdf_extensions = {pdf_file for pdf_file in all_pdf_files if not pdf_file.endswith(".pdf")}
if any(files_with_non_pdf_extensions): if any(files_with_non_pdf_extensions):
logger.warn(f"[Warning] There maybe non pdf-mode files in the input set: {files_with_non_pdf_extensions}") logger.warning(
f"[Warning] There maybe non pdf-mode files in the input set: {files_with_non_pdf_extensions}"
)
logger.debug(f"Processing files: {all_pdf_files}") logger.debug(f"Processing files: {all_pdf_files}")

View file

@ -1,5 +1,7 @@
# Standard Packages # Standard Packages
import concurrent.futures
import math import math
import time
import yaml import yaml
import logging import logging
from datetime import datetime from datetime import datetime
@ -8,12 +10,17 @@ from typing import List, Optional, Union
# External Packages # External Packages
from fastapi import APIRouter from fastapi import APIRouter
from fastapi import HTTPException from fastapi import HTTPException
from sentence_transformers import util
# Internal Packages # Internal Packages
from khoj.configure import configure_processor, configure_search from khoj.configure import configure_processor, configure_search
from khoj.processor.conversation.gpt import converse, extract_questions from khoj.processor.conversation.gpt import converse, extract_questions
from khoj.processor.conversation.utils import message_to_log, message_to_prompt from khoj.processor.conversation.utils import message_to_log, message_to_prompt
from khoj.search_type import image_search, text_search from khoj.search_type import image_search, text_search
from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.file_filter import FileFilter
from khoj.search_filter.word_filter import WordFilter
from khoj.utils.config import TextSearchModel
from khoj.utils.helpers import log_telemetry, timer from khoj.utils.helpers import log_telemetry, timer
from khoj.utils.rawconfig import ( from khoj.utils.rawconfig import (
ContentConfig, ContentConfig,
@ -58,6 +65,7 @@ def get_config_types():
and getattr(state.model, f"{search_type.value}_search") is not None and getattr(state.model, f"{search_type.value}_search") is not None
) )
or ("plugins" in configured_content_types and search_type.name in configured_content_types["plugins"]) or ("plugins" in configured_content_types and search_type.name in configured_content_types["plugins"])
or search_type == SearchType.All
] ]
@ -125,24 +133,31 @@ async def set_processor_conversation_config_data(updated_config: ConversationPro
@api.get("/search", response_model=List[SearchResponse]) @api.get("/search", response_model=List[SearchResponse])
def search( async def search(
q: str, q: str,
n: Optional[int] = 5, n: Optional[int] = 5,
t: Optional[SearchType] = None, t: Optional[SearchType] = SearchType.All,
r: Optional[bool] = False, r: Optional[bool] = False,
score_threshold: Optional[Union[float, None]] = None, score_threshold: Optional[Union[float, None]] = None,
dedupe: Optional[bool] = True, dedupe: Optional[bool] = True,
client: Optional[str] = None, client: Optional[str] = None,
): ):
start_time = time.time()
# Run validation checks
results: List[SearchResponse] = [] results: List[SearchResponse] = []
if q is None or q == "": if q is None or q == "":
logger.warn(f"No query param (q) passed in API call to initiate search") logger.warning(f"No query param (q) passed in API call to initiate search")
return results
if not state.model or not any(state.model.__dict__.values()):
logger.warning(f"No search models loaded. Configure a search model before initiating search")
return results return results
# initialize variables # initialize variables
user_query = q.strip() user_query = q.strip()
results_count = n results_count = n
score_threshold = score_threshold if score_threshold is not None else -math.inf score_threshold = score_threshold if score_threshold is not None else -math.inf
search_futures: List[concurrent.futures.Future] = []
# return cached results, if available # return cached results, if available
query_cache_key = f"{user_query}-{n}-{t}-{r}-{score_threshold}-{dedupe}" query_cache_key = f"{user_query}-{n}-{t}-{r}-{score_threshold}-{dedupe}"
@ -150,105 +165,146 @@ def search(
logger.debug(f"Return response from query cache") logger.debug(f"Return response from query cache")
return state.query_cache[query_cache_key] return state.query_cache[query_cache_key]
if (t == SearchType.Org or t == None) and state.model.org_search: # Encode query with filter terms removed
# query org-mode notes defiltered_query = user_query
for filter in [DateFilter(), WordFilter(), FileFilter()]:
defiltered_query = filter.defilter(user_query)
encoded_asymmetric_query = None
if t == SearchType.All or (t != SearchType.Ledger and t != SearchType.Image):
text_search_models: List[TextSearchModel] = [
model
for model_name, model in state.model.__dict__.items()
if isinstance(model, TextSearchModel) and model_name != "ledger_search"
]
if text_search_models:
with timer("Encoding query took", logger=logger):
encoded_asymmetric_query = util.normalize_embeddings(
text_search_models[0].bi_encoder.encode(
[defiltered_query],
convert_to_tensor=True,
device=state.device,
)
)
with concurrent.futures.ThreadPoolExecutor() as executor:
if (t == SearchType.Org or t == SearchType.All) and state.model.org_search:
# query org-mode notes
search_futures += [
executor.submit(
text_search.query,
user_query,
state.model.org_search,
question_embedding=encoded_asymmetric_query,
rank_results=r or False,
score_threshold=score_threshold,
dedupe=dedupe or True,
)
]
if (t == SearchType.Markdown or t == SearchType.All) and state.model.markdown_search:
# query markdown notes
search_futures += [
executor.submit(
text_search.query,
user_query,
state.model.markdown_search,
question_embedding=encoded_asymmetric_query,
rank_results=r or False,
score_threshold=score_threshold,
dedupe=dedupe or True,
)
]
if (t == SearchType.Pdf or t == SearchType.All) and state.model.pdf_search:
# query pdf files
search_futures += [
executor.submit(
text_search.query,
user_query,
state.model.pdf_search,
question_embedding=encoded_asymmetric_query,
rank_results=r or False,
score_threshold=score_threshold,
dedupe=dedupe or True,
)
]
if (t == SearchType.Ledger) and state.model.ledger_search:
# query transactions
search_futures += [
executor.submit(
text_search.query,
user_query,
state.model.ledger_search,
rank_results=r or False,
score_threshold=score_threshold,
dedupe=dedupe or True,
)
]
if (t == SearchType.Music or t == SearchType.All) and state.model.music_search:
# query music library
search_futures += [
executor.submit(
text_search.query,
user_query,
state.model.music_search,
question_embedding=encoded_asymmetric_query,
rank_results=r or False,
score_threshold=score_threshold,
dedupe=dedupe or True,
)
]
if (t == SearchType.Image) and state.model.image_search:
# query images
search_futures += [
executor.submit(
image_search.query,
user_query,
results_count,
state.model.image_search,
score_threshold=score_threshold,
)
]
if (t == SearchType.All or t in SearchType) and state.model.plugin_search:
# query specified plugin type
search_futures += [
executor.submit(
text_search.query,
user_query,
# Get plugin search model for specified search type, or the first one if none specified
state.model.plugin_search.get(t.value) or next(iter(state.model.plugin_search.values())),
question_embedding=encoded_asymmetric_query,
rank_results=r or False,
score_threshold=score_threshold,
dedupe=dedupe or True,
)
]
# Query across each requested content types in parallel
with timer("Query took", logger): with timer("Query took", logger):
hits, entries = text_search.query( for search_future in concurrent.futures.as_completed(search_futures):
user_query, state.model.org_search, rank_results=r, score_threshold=score_threshold, dedupe=dedupe if t == SearchType.Image:
) hits = await search_future.result()
output_directory = constants.web_directory / "images"
# Collate results
results += image_search.collate_results(
hits,
image_names=state.model.image_search.image_names,
output_directory=output_directory,
image_files_url="/static/images",
count=results_count or 5,
)
else:
hits, entries = await search_future.result()
# Collate results
results += text_search.collate_results(hits, entries, results_count or 5)
# collate and return results # Sort results across all content types
with timer("Collating results took", logger): results.sort(key=lambda x: float(x.score), reverse=True)
results = text_search.collate_results(hits, entries, results_count)
elif (t == SearchType.Markdown or t == None) and state.model.markdown_search:
# query markdown files
with timer("Query took", logger):
hits, entries = text_search.query(
user_query, state.model.markdown_search, rank_results=r, score_threshold=score_threshold, dedupe=dedupe
)
# collate and return results
with timer("Collating results took", logger):
results = text_search.collate_results(hits, entries, results_count)
elif (t == SearchType.Pdf or t == None) and state.model.pdf_search:
# query pdf files
with timer("Query took", logger):
hits, entries = text_search.query(
user_query, state.model.pdf_search, rank_results=r, score_threshold=score_threshold, dedupe=dedupe
)
# collate and return results
with timer("Collating results took", logger):
results = text_search.collate_results(hits, entries, results_count)
elif (t == SearchType.Github or t == None) and state.model.github_search:
# query github embeddings
with timer("Query took", logger):
hits, entries = text_search.query(
user_query, state.model.github_search, rank_results=r, score_threshold=score_threshold, dedupe=dedupe
)
# collate and return results
with timer("Collating results took", logger):
results = text_search.collate_results(hits, entries, results_count)
elif (t == SearchType.Ledger or t == None) and state.model.ledger_search:
# query transactions
with timer("Query took", logger):
hits, entries = text_search.query(
user_query, state.model.ledger_search, rank_results=r, score_threshold=score_threshold, dedupe=dedupe
)
# collate and return results
with timer("Collating results took", logger):
results = text_search.collate_results(hits, entries, results_count)
elif (t == SearchType.Music or t == None) and state.model.music_search:
# query music library
with timer("Query took", logger):
hits, entries = text_search.query(
user_query, state.model.music_search, rank_results=r, score_threshold=score_threshold, dedupe=dedupe
)
# collate and return results
with timer("Collating results took", logger):
results = text_search.collate_results(hits, entries, results_count)
elif (t == SearchType.Image or t == None) and state.model.image_search:
# query images
with timer("Query took", logger):
hits = image_search.query(
user_query, results_count, state.model.image_search, score_threshold=score_threshold
)
output_directory = constants.web_directory / "images"
# collate and return results
with timer("Collating results took", logger):
results = image_search.collate_results(
hits,
image_names=state.model.image_search.image_names,
output_directory=output_directory,
image_files_url="/static/images",
count=results_count,
)
elif (t in SearchType or t == None) and state.model.plugin_search:
# query specified plugin type
with timer("Query took", logger):
hits, entries = text_search.query(
user_query,
# Get plugin search model for specified search type, or the first one if none specified
state.model.plugin_search.get(t.value) or next(iter(state.model.plugin_search.values())),
rank_results=r,
score_threshold=score_threshold,
dedupe=dedupe,
)
# collate and return results
with timer("Collating results took", logger):
results = text_search.collate_results(hits, entries, results_count)
# Cache results # Cache results
state.query_cache[query_cache_key] = results state.query_cache[query_cache_key] = results
@ -260,6 +316,9 @@ def search(
] ]
state.previous_query = user_query state.previous_query = user_query
end_time = time.time()
logger.debug(f"🔍 Search took: {end_time - start_time:.3f} seconds")
return results return results
@ -267,7 +326,7 @@ def search(
def update(t: Optional[SearchType] = None, force: Optional[bool] = False, client: Optional[str] = None): def update(t: Optional[SearchType] = None, force: Optional[bool] = False, client: Optional[str] = None):
try: try:
state.search_index_lock.acquire() state.search_index_lock.acquire()
state.model = configure_search(state.model, state.config, regenerate=force, t=t) state.model = configure_search(state.model, state.config, regenerate=force or False, t=t)
state.search_index_lock.release() state.search_index_lock.release()
except ValueError as e: except ValueError as e:
logger.error(e) logger.error(e)

View file

@ -18,3 +18,7 @@ class BaseFilter(ABC):
@abstractmethod @abstractmethod
def apply(self, query: str, entries: List[Entry]) -> Tuple[str, Set[int]]: def apply(self, query: str, entries: List[Entry]) -> Tuple[str, Set[int]]:
... ...
@abstractmethod
def defilter(self, query: str) -> str:
    """Return `query` with this filter's terms removed; subclasses implement it."""

View file

@ -49,6 +49,12 @@ class DateFilter(BaseFilter):
"Check if query contains date filters" "Check if query contains date filters"
return self.extract_date_range(raw_query) is not None return self.extract_date_range(raw_query) is not None
def defilter(self, query):
    """Return `query` with every date filter term removed.

    A date filter term is whatever `self.date_regex` matches. It is stripped
    when it starts the query or follows whitespace, so a query that leads with
    its date filter is cleaned the same way as one carrying it mid-sentence.

    Args:
        query: Raw query text, possibly containing date filter terms.

    Returns:
        The query with date filters removed, runs of whitespace collapsed to a
        single space and surrounding whitespace trimmed.
    """
    # Anchor on start-of-string as well as on preceding whitespace: requiring
    # whitespace alone leaves a leading date filter in the returned query
    query = re.sub(rf"(?:^|\s+){self.date_regex}", " ", query)
    # Removing a filter can leave adjacent spaces behind; squash them
    return re.sub(r"\s{2,}", " ", query).strip()
def apply(self, query, entries): def apply(self, query, entries):
"Find entries containing any dates that fall within date range specified in query" "Find entries containing any dates that fall within date range specified in query"
# extract date range specified in date filter of query # extract date range specified in date filter of query
@ -59,9 +65,7 @@ class DateFilter(BaseFilter):
if query_daterange is None: if query_daterange is None:
return query, set(range(len(entries))) return query, set(range(len(entries)))
# remove date range filter from query query = self.defilter(query)
query = re.sub(rf"\s+{self.date_regex}", " ", query)
query = re.sub(r"\s{2,}", " ", query).strip() # remove multiple spaces
# return results from cache if exists # return results from cache if exists
cache_key = tuple(query_daterange) cache_key = tuple(query_daterange)

View file

@ -28,6 +28,9 @@ class FileFilter(BaseFilter):
def can_filter(self, raw_query): def can_filter(self, raw_query):
return re.search(self.file_filter_regex, raw_query) is not None return re.search(self.file_filter_regex, raw_query) is not None
def defilter(self, query: str) -> str:
    """Strip file filter terms from `query` and trim surrounding whitespace."""
    without_file_filters = re.sub(self.file_filter_regex, "", query)
    return without_file_filters.strip()
def apply(self, query, entries): def apply(self, query, entries):
# Extract file filters from raw query # Extract file filters from raw query
with timer("Extract files_to_search from query", logger): with timer("Extract files_to_search from query", logger):
@ -44,8 +47,10 @@ class FileFilter(BaseFilter):
else: else:
files_to_search += [file] files_to_search += [file]
# Remove filter terms from original query
query = self.defilter(query)
# Return item from cache if exists # Return item from cache if exists
query = re.sub(self.file_filter_regex, "", query).strip()
cache_key = tuple(files_to_search) cache_key = tuple(files_to_search)
if cache_key in self.cache: if cache_key in self.cache:
logger.debug(f"Return file filter results from cache") logger.debug(f"Return file filter results from cache")

View file

@ -43,13 +43,16 @@ class WordFilter(BaseFilter):
return len(required_words) != 0 or len(blocked_words) != 0 return len(required_words) != 0 or len(blocked_words) != 0
def defilter(self, query: str) -> str:
    """Strip required and blocked word filter terms from `query`, then trim it."""
    # Drop required-word terms first, then blocked-word terms
    without_required = re.sub(self.required_regex, "", query)
    without_blocked = re.sub(self.blocked_regex, "", without_required)
    return without_blocked.strip()
def apply(self, query, entries): def apply(self, query, entries):
"Find entries containing required and not blocked words specified in query" "Find entries containing required and not blocked words specified in query"
# Separate natural query from required, blocked words filters # Separate natural query from required, blocked words filters
with timer("Extract required, blocked filters from query", logger): with timer("Extract required, blocked filters from query", logger):
required_words = set([word.lower() for word in re.findall(self.required_regex, query)]) required_words = set([word.lower() for word in re.findall(self.required_regex, query)])
blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, query)]) blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, query)])
query = re.sub(self.blocked_regex, "", re.sub(self.required_regex, "", query)).strip() query = self.defilter(query)
if len(required_words) == 0 and len(blocked_words) == 0: if len(required_words) == 0 and len(blocked_words) == 0:
return query, set(range(len(entries))) return query, set(range(len(entries)))

View file

@ -143,7 +143,7 @@ def extract_metadata(image_name):
return image_processed_metadata return image_processed_metadata
def query(raw_query, count, model: ImageSearchModel, score_threshold: float = -math.inf): async def query(raw_query, count, model: ImageSearchModel, score_threshold: float = -math.inf):
# Set query to image content if query is of form file:/path/to/file.png # Set query to image content if query is of form file:/path/to/file.png
if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file(): if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file():
query_imagepath = resolve_absolute_path(pathlib.Path(raw_query[5:]), strict=True) query_imagepath = resolve_absolute_path(pathlib.Path(raw_query[5:]), strict=True)

View file

@ -2,7 +2,7 @@
import logging import logging
import math import math
from pathlib import Path from pathlib import Path
from typing import List, Tuple, Type from typing import List, Tuple, Type, Union
# External Packages # External Packages
import torch import torch
@ -102,9 +102,10 @@ def compute_embeddings(
return corpus_embeddings return corpus_embeddings
def query( async def query(
raw_query: str, raw_query: str,
model: TextSearchModel, model: TextSearchModel,
question_embedding: Union[torch.Tensor, None] = None,
rank_results: bool = False, rank_results: bool = False,
score_threshold: float = -math.inf, score_threshold: float = -math.inf,
dedupe: bool = True, dedupe: bool = True,
@ -124,9 +125,10 @@ def query(
return hits, entries return hits, entries
# Encode the query using the bi-encoder # Encode the query using the bi-encoder
with timer("Query Encode Time", logger, state.device): if question_embedding is None:
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device) with timer("Query Encode Time", logger, state.device):
question_embedding = util.normalize_embeddings(question_embedding) question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
question_embedding = util.normalize_embeddings(question_embedding)
# Find relevant entries for the query # Find relevant entries for the query
with timer("Search Time", logger, state.device): with timer("Search Time", logger, state.device):
@ -179,7 +181,7 @@ def setup(
previous_entries = ( previous_entries = (
extract_entries(config.compressed_jsonl) if config.compressed_jsonl.exists() and not regenerate else None extract_entries(config.compressed_jsonl) if config.compressed_jsonl.exists() and not regenerate else None
) )
entries_with_indices = text_to_jsonl(config).process(previous_entries) entries_with_indices = text_to_jsonl(config).process(previous_entries or [])
# Extract Updated Entries # Extract Updated Entries
entries = extract_entries(config.compressed_jsonl) entries = extract_entries(config.compressed_jsonl)

View file

@ -17,6 +17,7 @@ if TYPE_CHECKING:
class SearchType(str, Enum): class SearchType(str, Enum):
All = "all"
Org = "org" Org = "org"
Ledger = "ledger" Ledger = "ledger"
Music = "music" Music = "music"

View file

@ -34,7 +34,7 @@ def test_search_with_invalid_content_type(client):
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_search_with_valid_content_type(client): def test_search_with_valid_content_type(client):
for content_type in ["org", "markdown", "ledger", "image", "music", "pdf", "plugin1"]: for content_type in ["all", "org", "markdown", "ledger", "image", "music", "pdf", "plugin1"]:
# Act # Act
response = client.get(f"/api/search?q=random&t={content_type}") response = client.get(f"/api/search?q=random&t={content_type}")
# Assert # Assert
@ -84,7 +84,7 @@ def test_get_configured_types_via_api(client):
# Assert # Assert
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == ["org", "image", "plugin1"] assert response.json() == ["all", "org", "image", "plugin1"]
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@ -102,7 +102,7 @@ def test_get_configured_types_with_only_plugin_content_config(content_config):
# Assert # Assert
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == ["plugin1"] assert response.json() == ["all", "plugin1"]
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@ -137,7 +137,7 @@ def test_get_configured_types_with_no_content_config():
# Assert # Assert
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == [] assert response.json() == ["all"]
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------

View file

@ -3,6 +3,9 @@ import logging
from pathlib import Path from pathlib import Path
from PIL import Image from PIL import Image
# External Packages
import pytest
# Internal Packages # Internal Packages
from khoj.utils.state import model from khoj.utils.state import model
from khoj.utils.constants import web_directory from khoj.utils.constants import web_directory
@ -48,7 +51,8 @@ def test_image_metadata(content_config: ContentConfig):
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_image_search(content_config: ContentConfig, search_config: SearchConfig): @pytest.mark.anyio
async def test_image_search(content_config: ContentConfig, search_config: SearchConfig):
# Arrange # Arrange
output_directory = resolve_absolute_path(web_directory) output_directory = resolve_absolute_path(web_directory)
model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False) model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False)
@ -60,7 +64,7 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig
# Act # Act
for query, expected_image_name in query_expected_image_pairs: for query, expected_image_name in query_expected_image_pairs:
hits = image_search.query(query, count=1, model=model.image_search) hits = await image_search.query(query, count=1, model=model.image_search)
results = image_search.collate_results( results = image_search.collate_results(
hits, hits,
@ -83,7 +87,8 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_image_search_query_truncated(content_config: ContentConfig, search_config: SearchConfig, caplog): @pytest.mark.anyio
async def test_image_search_query_truncated(content_config: ContentConfig, search_config: SearchConfig, caplog):
# Arrange # Arrange
model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False) model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False)
max_words_supported = 10 max_words_supported = 10
@ -93,7 +98,7 @@ def test_image_search_query_truncated(content_config: ContentConfig, search_conf
# Act # Act
try: try:
with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"): with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"):
image_search.query(query, count=1, model=model.image_search) await image_search.query(query, count=1, model=model.image_search)
# Assert # Assert
except RuntimeError as e: except RuntimeError as e:
if "The size of tensor a (102) must match the size of tensor b (77)" in str(e): if "The size of tensor a (102) must match the size of tensor b (77)" in str(e):
@ -102,7 +107,8 @@ def test_image_search_query_truncated(content_config: ContentConfig, search_conf
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_image_search_by_filepath(content_config: ContentConfig, search_config: SearchConfig, caplog): @pytest.mark.anyio
async def test_image_search_by_filepath(content_config: ContentConfig, search_config: SearchConfig, caplog):
# Arrange # Arrange
output_directory = resolve_absolute_path(web_directory) output_directory = resolve_absolute_path(web_directory)
model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False) model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False)
@ -113,7 +119,7 @@ def test_image_search_by_filepath(content_config: ContentConfig, search_config:
# Act # Act
with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"): with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"):
hits = image_search.query(query, count=1, model=model.image_search) hits = await image_search.query(query, count=1, model=model.image_search)
results = image_search.collate_results( results = image_search.collate_results(
hits, hits,

View file

@ -72,13 +72,14 @@ def test_text_content_index_only_updates_on_changes(content_config: ContentConfi
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_asymmetric_search(content_config: ContentConfig, search_config: SearchConfig): @pytest.mark.anyio
async def test_asymmetric_search(content_config: ContentConfig, search_config: SearchConfig):
# Arrange # Arrange
model.notes_search = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True) model.notes_search = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True)
query = "How to git install application?" query = "How to git install application?"
# Act # Act
hits, entries = text_search.query(query, model=model.notes_search, rank_results=True) hits, entries = await text_search.query(query, model=model.notes_search, rank_results=True)
results = text_search.collate_results(hits, entries, count=1) results = text_search.collate_results(hits, entries, count=1)