mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 15:38:55 +01:00
Accept current changes to include issues in rendering flow
This commit is contained in:
commit
6edc32f2f4
19 changed files with 304 additions and 193 deletions
|
@ -73,15 +73,15 @@ khoj = "khoj.main:run"
|
|||
[project.optional-dependencies]
|
||||
test = [
|
||||
"pytest >= 7.1.2",
|
||||
"freezegun >= 1.2.0",
|
||||
"factory-boy >= 3.2.1",
|
||||
"trio >= 0.22.0",
|
||||
]
|
||||
dev = [
|
||||
"khoj-assistant[test]",
|
||||
"mypy >= 1.0.1",
|
||||
"black >= 23.1.0",
|
||||
"pre-commit >= 3.0.4",
|
||||
"freezegun >= 1.2.0",
|
||||
"factory-boy==3.2.1",
|
||||
"Faker==18.10.1",
|
||||
]
|
||||
|
||||
[tool.hatch.version]
|
||||
|
|
|
@ -598,9 +598,19 @@ CONFIG is json obtained from Khoj config API."
|
|||
"Convert JSON-RESPONSE, QUERY from API to text entries."
|
||||
(thread-last json-response
|
||||
;; extract and render entries from API response
|
||||
(mapcar (lambda (args) (format "%s\n\n" (cdr (assoc 'entry args)))))
|
||||
(mapcar (lambda (json-response-item)
|
||||
(thread-last
|
||||
;; Extract pdf entry from each item in json response
|
||||
(cdr (assoc 'entry json-response-item))
|
||||
(format "%s\n\n")
|
||||
;; Standardize results to 2nd level heading for consistent rendering
|
||||
(replace-regexp-in-string "^\*+" "")
|
||||
;; Standardize results to 2nd level heading for consistent rendering
|
||||
(replace-regexp-in-string "^\#+" "")
|
||||
;; Format entries as org entry string
|
||||
(format "** %s"))))
|
||||
;; Set query as heading in rendered results buffer
|
||||
(format "# Query: %s\n\n%s\n" query)
|
||||
(format "* %s\n%s\n" query)
|
||||
;; remove leading (, ) or SPC from extracted entries string
|
||||
(replace-regexp-in-string "^[\(\) ]" "")
|
||||
;; remove trailing (, ) or SPC from extracted entries string
|
||||
|
@ -674,7 +684,8 @@ Render results in BUFFER-NAME using QUERY, CONTENT-TYPE."
|
|||
((equal content-type "ledger") (khoj--extract-entries-as-ledger json-response query))
|
||||
((equal content-type "image") (khoj--extract-entries-as-images json-response query))
|
||||
(t (khoj--extract-entries json-response query))))
|
||||
(cond ((or (equal content-type "pdf")
|
||||
(cond ((or (equal content-type "all")
|
||||
(equal content-type "pdf")
|
||||
(equal content-type "org"))
|
||||
(progn (visual-line-mode)
|
||||
(org-mode)
|
||||
|
@ -1003,7 +1014,7 @@ Paragraph only starts at first text after blank line."
|
|||
;; set content type to: last used > based on current buffer > default type
|
||||
:init-value (lambda (obj) (oset obj value (format "--content-type=%s" (or khoj--content-type (khoj--buffer-name-to-content-type (buffer-name))))))
|
||||
;; dynamically set choices to content types enabled on khoj backend
|
||||
:choices (or (ignore-errors (mapcar #'symbol-name (khoj--get-enabled-content-types))) '("org" "markdown" "pdf" "ledger" "music" "image")))
|
||||
:choices (or (ignore-errors (mapcar #'symbol-name (khoj--get-enabled-content-types))) '("all" "org" "markdown" "pdf" "ledger" "music" "image")))
|
||||
|
||||
(transient-define-suffix khoj--search-command (&optional args)
|
||||
(interactive (list (transient-args transient-current-command)))
|
||||
|
|
|
@ -3,6 +3,7 @@ import sys
|
|||
import logging
|
||||
import json
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
import requests
|
||||
|
||||
# External Packages
|
||||
|
@ -36,7 +37,7 @@ def configure_server(args, required=False):
|
|||
logger.error(f"Exiting as Khoj is not configured.\nConfigure it via GUI or by editing {state.config_file}.")
|
||||
sys.exit(1)
|
||||
else:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"Khoj is not configured.\nConfigure it via khoj GUI, plugins or by editing {state.config_file}."
|
||||
)
|
||||
return
|
||||
|
@ -78,16 +79,20 @@ def configure_search_types(config: FullConfig):
|
|||
core_search_types = {e.name: e.value for e in SearchType}
|
||||
# Extract configured plugin search types
|
||||
plugin_search_types = {}
|
||||
if config.content_type.plugins:
|
||||
if config.content_type and config.content_type.plugins:
|
||||
plugin_search_types = {plugin_type: plugin_type for plugin_type in config.content_type.plugins.keys()}
|
||||
|
||||
# Dynamically generate search type enum by merging core search types with configured plugin search types
|
||||
return Enum("SearchType", merge_dicts(core_search_types, plugin_search_types))
|
||||
|
||||
|
||||
def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, t: state.SearchType = None):
|
||||
def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, t: Optional[state.SearchType] = None):
|
||||
if config is None or config.content_type is None or config.search_type is None:
|
||||
logger.warning("🚨 No Content or Search type is configured.")
|
||||
return
|
||||
|
||||
# Initialize Org Notes Search
|
||||
if (t == state.SearchType.Org or t == None) and config.content_type.org:
|
||||
if (t == state.SearchType.Org or t == None) and config.content_type.org and config.search_type.asymmetric:
|
||||
logger.info("🦄 Setting up search for orgmode notes")
|
||||
# Extract Entries, Generate Notes Embeddings
|
||||
model.org_search = text_search.setup(
|
||||
|
@ -99,7 +104,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
|
|||
)
|
||||
|
||||
# Initialize Org Music Search
|
||||
if (t == state.SearchType.Music or t == None) and config.content_type.music:
|
||||
if (t == state.SearchType.Music or t == None) and config.content_type.music and config.search_type.asymmetric:
|
||||
logger.info("🎺 Setting up search for org-music")
|
||||
# Extract Entries, Generate Music Embeddings
|
||||
model.music_search = text_search.setup(
|
||||
|
@ -111,7 +116,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
|
|||
)
|
||||
|
||||
# Initialize Markdown Search
|
||||
if (t == state.SearchType.Markdown or t == None) and config.content_type.markdown:
|
||||
if (t == state.SearchType.Markdown or t == None) and config.content_type.markdown and config.search_type.asymmetric:
|
||||
logger.info("💎 Setting up search for markdown notes")
|
||||
# Extract Entries, Generate Markdown Embeddings
|
||||
model.markdown_search = text_search.setup(
|
||||
|
@ -123,7 +128,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
|
|||
)
|
||||
|
||||
# Initialize Ledger Search
|
||||
if (t == state.SearchType.Ledger or t == None) and config.content_type.ledger:
|
||||
if (t == state.SearchType.Ledger or t == None) and config.content_type.ledger and config.search_type.symmetric:
|
||||
logger.info("💸 Setting up search for ledger")
|
||||
# Extract Entries, Generate Ledger Embeddings
|
||||
model.ledger_search = text_search.setup(
|
||||
|
@ -135,7 +140,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
|
|||
)
|
||||
|
||||
# Initialize PDF Search
|
||||
if (t == state.SearchType.Pdf or t == None) and config.content_type.pdf:
|
||||
if (t == state.SearchType.Pdf or t == None) and config.content_type.pdf and config.search_type.asymmetric:
|
||||
logger.info("🖨️ Setting up search for pdf")
|
||||
# Extract Entries, Generate PDF Embeddings
|
||||
model.pdf_search = text_search.setup(
|
||||
|
@ -147,14 +152,14 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
|
|||
)
|
||||
|
||||
# Initialize Image Search
|
||||
if (t == state.SearchType.Image or t == None) and config.content_type.image:
|
||||
if (t == state.SearchType.Image or t == None) and config.content_type.image and config.search_type.image:
|
||||
logger.info("🌄 Setting up search for images")
|
||||
# Extract Entries, Generate Image Embeddings
|
||||
model.image_search = image_search.setup(
|
||||
config.content_type.image, search_config=config.search_type.image, regenerate=regenerate
|
||||
)
|
||||
|
||||
if (t == state.SearchType.Github or t == None) and config.content_type.github:
|
||||
if (t == state.SearchType.Github or t == None) and config.content_type.github and config.search_type.asymmetric:
|
||||
logger.info("🐙 Setting up search for github")
|
||||
# Extract Entries, Generate Github Embeddings
|
||||
model.github_search = text_search.setup(
|
||||
|
|
|
@ -14,11 +14,13 @@
|
|||
<script>
|
||||
function render_image(item) {
|
||||
return `
|
||||
<div class="results-image">
|
||||
<a href="${item.entry}" class="image-link">
|
||||
<img id=${item.score} src="${item.entry}?${Math.random()}"
|
||||
title="Effective Score: ${item.score}, Meta: ${item.additional.metadata_score}, Image: ${item.additional.image_score}"
|
||||
class="image">
|
||||
</a>`
|
||||
</a>
|
||||
</div>`;
|
||||
}
|
||||
|
||||
function render_org(query, data, classPrefix="") {
|
||||
|
@ -28,33 +30,33 @@
|
|||
var orgParser = new Org.Parser();
|
||||
var orgDocument = orgParser.parse(orgCode);
|
||||
var orgHTMLDocument = orgDocument.convert(Org.ConverterHTML, { htmlClassPrefix: classPrefix });
|
||||
return orgHTMLDocument.toString();
|
||||
return `<div class="results-org">` + orgHTMLDocument.toString() + `</div>`;
|
||||
}
|
||||
|
||||
function render_markdown(query, data) {
|
||||
var md = window.markdownit();
|
||||
return md.render(data.map(function (item) {
|
||||
return data.map(function (item) {
|
||||
if (item.additional.file.startsWith("http")) {
|
||||
lines = item.entry.split("\n");
|
||||
return `${lines[0]}\t[*](${item.additional.file})\n${lines.slice(1).join("\n")}`;
|
||||
return md.render(`${lines[0]}\t[*](${item.additional.file})\n${lines.slice(1).join("\n")}`);
|
||||
}
|
||||
return `${item.entry}`;
|
||||
}).join("\n"));
|
||||
return `<div class="results-markdown">` + md.render(`${item.entry}`) + `</div>`;
|
||||
}).join("\n");
|
||||
}
|
||||
|
||||
function render_ledger(query, data) {
|
||||
return `<div id="results-ledger">` + data.map(function (item) {
|
||||
return `<p>${item.entry}</p>`
|
||||
}).join("\n") + `</div>`;
|
||||
return data.map(function (item) {
|
||||
return `<div class="results-ledger">` + `<p>${item.entry}</p>` + `</div>`;
|
||||
}).join("\n");
|
||||
}
|
||||
|
||||
function render_pdf(query, data) {
|
||||
return `<div id="results-pdf">` + data.map(function (item) {
|
||||
return data.map(function (item) {
|
||||
let compiled_lines = item.additional.compiled.split("\n");
|
||||
let filename = compiled_lines.shift();
|
||||
let text_match = compiled_lines.join("\n")
|
||||
return `<h2>${filename}</h2>\n<p>${text_match}</p>`
|
||||
}).join("\n") + `</div>`;
|
||||
return `<div class="results-pdf">` + `<h2>${filename}</h2>\n<p>${text_match}</p>` + `</div>`;
|
||||
}).join("\n");
|
||||
}
|
||||
|
||||
function render_mutliple(query, data, type) {
|
||||
|
@ -83,26 +85,26 @@
|
|||
return html;
|
||||
}
|
||||
|
||||
function render_json(data, query, type) {
|
||||
function render_results(data, query, type) {
|
||||
let results = "";
|
||||
if (type === "markdown") {
|
||||
return render_markdown(query, data);
|
||||
results = render_markdown(query, data);
|
||||
} else if (type === "org") {
|
||||
return render_org(query, data);
|
||||
results = render_org(query, data, "org-");
|
||||
} else if (type === "music") {
|
||||
return render_org(query, data, "music-");
|
||||
results = render_org(query, data, "music-");
|
||||
} else if (type === "image") {
|
||||
return data.map(render_image).join('');
|
||||
results = data.map(render_image).join('');
|
||||
} else if (type === "ledger") {
|
||||
return render_ledger(query, data);
|
||||
results = render_ledger(query, data);
|
||||
} else if (type === "pdf") {
|
||||
return render_pdf(query, data);
|
||||
} else if (type == "github") {
|
||||
return render_mutliple(query, data, type);
|
||||
results = render_pdf(query, data);
|
||||
} else if (type === "github" || type === "all") {
|
||||
results = render_mutliple(query, data, type);
|
||||
} else {
|
||||
return `<div id="results-plugin">`
|
||||
+ data.map((item) => `<p>${item.entry}</p>`).join("\n")
|
||||
+ `</div>`;
|
||||
results = data.map((item) => `<div class="results-plugin">` + `<p>${item.entry}</p>` + `</div>`).join("\n")
|
||||
}
|
||||
return `<div id="results-${type}">${results}</div>`;
|
||||
}
|
||||
|
||||
function search(rerank=false) {
|
||||
|
@ -120,20 +122,13 @@
|
|||
if (rerank)
|
||||
setQueryFieldInUrl(query);
|
||||
|
||||
// Generate Backend API URL to execute Search
|
||||
url = type === "image"
|
||||
? `/api/search?q=${encodeURIComponent(query)}&t=${type}&n=${results_count}&client=web`
|
||||
: `/api/search?q=${encodeURIComponent(query)}&t=${type}&n=${results_count}&r=${rerank}&client=web`;
|
||||
|
||||
// Execute Search and Render Results
|
||||
url = createRequestUrl(query, type, results_count, rerank);
|
||||
fetch(url)
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
console.log(data);
|
||||
document.getElementById("results").innerHTML =
|
||||
`<div id=results-${type}>`
|
||||
+ render_json(data, query, type)
|
||||
+ `</div>`;
|
||||
document.getElementById("results").innerHTML = render_results(data, query, type);
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -144,7 +139,7 @@
|
|||
.then(data => {
|
||||
console.log(data);
|
||||
document.getElementById("results").innerHTML =
|
||||
render_json(data);
|
||||
render_results(data);
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -180,6 +175,18 @@
|
|||
});
|
||||
}
|
||||
|
||||
function createRequestUrl(query, type, results_count, rerank) {
|
||||
// Generate Backend API URL to execute Search
|
||||
let url = `/api/search?q=${encodeURIComponent(query)}&n=${results_count}&client=web`;
|
||||
// If type is not 'all', append type to URL
|
||||
if (type !== 'all')
|
||||
url += `&t=${type}`;
|
||||
// Rerank is only supported by text types
|
||||
if (type !== "image")
|
||||
url += `&r=${rerank}`;
|
||||
return url;
|
||||
}
|
||||
|
||||
function setTypeFieldInUrl(type) {
|
||||
var url = new URL(window.location.href);
|
||||
url.searchParams.set("t", type.value);
|
||||
|
@ -309,7 +316,7 @@
|
|||
margin: 0px;
|
||||
line-height: 20px;
|
||||
}
|
||||
#results-image {
|
||||
.results-image {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(3, 1fr);
|
||||
}
|
||||
|
@ -324,27 +331,28 @@
|
|||
#json {
|
||||
white-space: pre-wrap;
|
||||
}
|
||||
#results-pdf,
|
||||
#results-plugin,
|
||||
#results-ledger {
|
||||
.results-pdf,
|
||||
.results-plugin,
|
||||
.results-ledger {
|
||||
text-align: left;
|
||||
white-space: pre-line;
|
||||
}
|
||||
#results-markdown, #results-github {
|
||||
.results-markdown,
|
||||
.results-github {
|
||||
text-align: left;
|
||||
}
|
||||
#results-music,
|
||||
#results-org {
|
||||
.results-music,
|
||||
.results-org {
|
||||
text-align: left;
|
||||
white-space: pre-line;
|
||||
}
|
||||
#results-music h3,
|
||||
#results-org h3 {
|
||||
.results-music h3,
|
||||
.results-org h3 {
|
||||
margin: 20px 0 0 0;
|
||||
font-size: larger;
|
||||
}
|
||||
span.music-task-status,
|
||||
span.task-status {
|
||||
span.org-task-status {
|
||||
color: white;
|
||||
padding: 3.5px 3.5px 0;
|
||||
margin-right: 5px;
|
||||
|
@ -353,15 +361,15 @@
|
|||
font-size: medium;
|
||||
}
|
||||
span.music-task-status.todo,
|
||||
span.task-status.todo {
|
||||
span.org-task-status.todo {
|
||||
background-color: #3b82f6
|
||||
}
|
||||
span.music-task-status.done,
|
||||
span.task-status.done {
|
||||
span.org-task-status.done {
|
||||
background-color: #22c55e;
|
||||
}
|
||||
span.music-task-tag,
|
||||
span.task-tag {
|
||||
span.org-task-tag {
|
||||
color: white;
|
||||
padding: 3.5px 3.5px 0;
|
||||
margin-right: 5px;
|
||||
|
|
|
@ -113,7 +113,7 @@ def extract_questions(text, model="text-davinci-003", conversation_log={}, api_k
|
|||
.replace("', '", '", "')
|
||||
)
|
||||
except json.decoder.JSONDecodeError:
|
||||
logger.warn(f"GPT returned invalid JSON. Falling back to using user message as search query.\n{response}")
|
||||
logger.warning(f"GPT returned invalid JSON. Falling back to using user message as search query.\n{response}")
|
||||
questions = [text]
|
||||
logger.debug(f"Extracted Questions by GPT: {questions}")
|
||||
return questions
|
||||
|
|
|
@ -92,7 +92,7 @@ class MarkdownToJsonl(TextToJsonl):
|
|||
}
|
||||
|
||||
if any(files_with_non_markdown_extensions):
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"[Warning] There maybe non markdown-mode files in the input set: {files_with_non_markdown_extensions}"
|
||||
)
|
||||
|
||||
|
|
|
@ -88,7 +88,7 @@ class OrgToJsonl(TextToJsonl):
|
|||
|
||||
files_with_non_org_extensions = {org_file for org_file in all_org_files if not org_file.endswith(".org")}
|
||||
if any(files_with_non_org_extensions):
|
||||
logger.warn(f"There maybe non org-mode files in the input set: {files_with_non_org_extensions}")
|
||||
logger.warning(f"There maybe non org-mode files in the input set: {files_with_non_org_extensions}")
|
||||
|
||||
logger.debug(f"Processing files: {all_org_files}")
|
||||
|
||||
|
|
|
@ -83,7 +83,9 @@ class PdfToJsonl(TextToJsonl):
|
|||
files_with_non_pdf_extensions = {pdf_file for pdf_file in all_pdf_files if not pdf_file.endswith(".pdf")}
|
||||
|
||||
if any(files_with_non_pdf_extensions):
|
||||
logger.warn(f"[Warning] There maybe non pdf-mode files in the input set: {files_with_non_pdf_extensions}")
|
||||
logger.warning(
|
||||
f"[Warning] There maybe non pdf-mode files in the input set: {files_with_non_pdf_extensions}"
|
||||
)
|
||||
|
||||
logger.debug(f"Processing files: {all_pdf_files}")
|
||||
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
# Standard Packages
|
||||
import concurrent.futures
|
||||
import math
|
||||
import time
|
||||
import yaml
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
@ -8,12 +10,17 @@ from typing import List, Optional, Union
|
|||
# External Packages
|
||||
from fastapi import APIRouter
|
||||
from fastapi import HTTPException
|
||||
from sentence_transformers import util
|
||||
|
||||
# Internal Packages
|
||||
from khoj.configure import configure_processor, configure_search
|
||||
from khoj.processor.conversation.gpt import converse, extract_questions
|
||||
from khoj.processor.conversation.utils import message_to_log, message_to_prompt
|
||||
from khoj.search_type import image_search, text_search
|
||||
from khoj.search_filter.date_filter import DateFilter
|
||||
from khoj.search_filter.file_filter import FileFilter
|
||||
from khoj.search_filter.word_filter import WordFilter
|
||||
from khoj.utils.config import TextSearchModel
|
||||
from khoj.utils.helpers import log_telemetry, timer
|
||||
from khoj.utils.rawconfig import (
|
||||
ContentConfig,
|
||||
|
@ -58,6 +65,7 @@ def get_config_types():
|
|||
and getattr(state.model, f"{search_type.value}_search") is not None
|
||||
)
|
||||
or ("plugins" in configured_content_types and search_type.name in configured_content_types["plugins"])
|
||||
or search_type == SearchType.All
|
||||
]
|
||||
|
||||
|
||||
|
@ -125,24 +133,31 @@ async def set_processor_conversation_config_data(updated_config: ConversationPro
|
|||
|
||||
|
||||
@api.get("/search", response_model=List[SearchResponse])
|
||||
def search(
|
||||
async def search(
|
||||
q: str,
|
||||
n: Optional[int] = 5,
|
||||
t: Optional[SearchType] = None,
|
||||
t: Optional[SearchType] = SearchType.All,
|
||||
r: Optional[bool] = False,
|
||||
score_threshold: Optional[Union[float, None]] = None,
|
||||
dedupe: Optional[bool] = True,
|
||||
client: Optional[str] = None,
|
||||
):
|
||||
start_time = time.time()
|
||||
|
||||
# Run validation checks
|
||||
results: List[SearchResponse] = []
|
||||
if q is None or q == "":
|
||||
logger.warn(f"No query param (q) passed in API call to initiate search")
|
||||
logger.warning(f"No query param (q) passed in API call to initiate search")
|
||||
return results
|
||||
if not state.model or not any(state.model.__dict__.values()):
|
||||
logger.warning(f"No search models loaded. Configure a search model before initiating search")
|
||||
return results
|
||||
|
||||
# initialize variables
|
||||
user_query = q.strip()
|
||||
results_count = n
|
||||
score_threshold = score_threshold if score_threshold is not None else -math.inf
|
||||
search_futures: List[concurrent.futures.Future] = []
|
||||
|
||||
# return cached results, if available
|
||||
query_cache_key = f"{user_query}-{n}-{t}-{r}-{score_threshold}-{dedupe}"
|
||||
|
@ -150,105 +165,146 @@ def search(
|
|||
logger.debug(f"Return response from query cache")
|
||||
return state.query_cache[query_cache_key]
|
||||
|
||||
if (t == SearchType.Org or t == None) and state.model.org_search:
|
||||
# query org-mode notes
|
||||
# Encode query with filter terms removed
|
||||
defiltered_query = user_query
|
||||
for filter in [DateFilter(), WordFilter(), FileFilter()]:
|
||||
defiltered_query = filter.defilter(user_query)
|
||||
|
||||
encoded_asymmetric_query = None
|
||||
if t == SearchType.All or (t != SearchType.Ledger and t != SearchType.Image):
|
||||
text_search_models: List[TextSearchModel] = [
|
||||
model
|
||||
for model_name, model in state.model.__dict__.items()
|
||||
if isinstance(model, TextSearchModel) and model_name != "ledger_search"
|
||||
]
|
||||
if text_search_models:
|
||||
with timer("Encoding query took", logger=logger):
|
||||
encoded_asymmetric_query = util.normalize_embeddings(
|
||||
text_search_models[0].bi_encoder.encode(
|
||||
[defiltered_query],
|
||||
convert_to_tensor=True,
|
||||
device=state.device,
|
||||
)
|
||||
)
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
if (t == SearchType.Org or t == SearchType.All) and state.model.org_search:
|
||||
# query org-mode notes
|
||||
search_futures += [
|
||||
executor.submit(
|
||||
text_search.query,
|
||||
user_query,
|
||||
state.model.org_search,
|
||||
question_embedding=encoded_asymmetric_query,
|
||||
rank_results=r or False,
|
||||
score_threshold=score_threshold,
|
||||
dedupe=dedupe or True,
|
||||
)
|
||||
]
|
||||
|
||||
if (t == SearchType.Markdown or t == SearchType.All) and state.model.markdown_search:
|
||||
# query markdown notes
|
||||
search_futures += [
|
||||
executor.submit(
|
||||
text_search.query,
|
||||
user_query,
|
||||
state.model.markdown_search,
|
||||
question_embedding=encoded_asymmetric_query,
|
||||
rank_results=r or False,
|
||||
score_threshold=score_threshold,
|
||||
dedupe=dedupe or True,
|
||||
)
|
||||
]
|
||||
|
||||
if (t == SearchType.Pdf or t == SearchType.All) and state.model.pdf_search:
|
||||
# query pdf files
|
||||
search_futures += [
|
||||
executor.submit(
|
||||
text_search.query,
|
||||
user_query,
|
||||
state.model.pdf_search,
|
||||
question_embedding=encoded_asymmetric_query,
|
||||
rank_results=r or False,
|
||||
score_threshold=score_threshold,
|
||||
dedupe=dedupe or True,
|
||||
)
|
||||
]
|
||||
|
||||
if (t == SearchType.Ledger) and state.model.ledger_search:
|
||||
# query transactions
|
||||
search_futures += [
|
||||
executor.submit(
|
||||
text_search.query,
|
||||
user_query,
|
||||
state.model.ledger_search,
|
||||
rank_results=r or False,
|
||||
score_threshold=score_threshold,
|
||||
dedupe=dedupe or True,
|
||||
)
|
||||
]
|
||||
|
||||
if (t == SearchType.Music or t == SearchType.All) and state.model.music_search:
|
||||
# query music library
|
||||
search_futures += [
|
||||
executor.submit(
|
||||
text_search.query,
|
||||
user_query,
|
||||
state.model.music_search,
|
||||
question_embedding=encoded_asymmetric_query,
|
||||
rank_results=r or False,
|
||||
score_threshold=score_threshold,
|
||||
dedupe=dedupe or True,
|
||||
)
|
||||
]
|
||||
|
||||
if (t == SearchType.Image) and state.model.image_search:
|
||||
# query images
|
||||
search_futures += [
|
||||
executor.submit(
|
||||
image_search.query,
|
||||
user_query,
|
||||
results_count,
|
||||
state.model.image_search,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
]
|
||||
|
||||
if (t == SearchType.All or t in SearchType) and state.model.plugin_search:
|
||||
# query specified plugin type
|
||||
search_futures += [
|
||||
executor.submit(
|
||||
text_search.query,
|
||||
user_query,
|
||||
# Get plugin search model for specified search type, or the first one if none specified
|
||||
state.model.plugin_search.get(t.value) or next(iter(state.model.plugin_search.values())),
|
||||
question_embedding=encoded_asymmetric_query,
|
||||
rank_results=r or False,
|
||||
score_threshold=score_threshold,
|
||||
dedupe=dedupe or True,
|
||||
)
|
||||
]
|
||||
|
||||
# Query across each requested content types in parallel
|
||||
with timer("Query took", logger):
|
||||
hits, entries = text_search.query(
|
||||
user_query, state.model.org_search, rank_results=r, score_threshold=score_threshold, dedupe=dedupe
|
||||
)
|
||||
for search_future in concurrent.futures.as_completed(search_futures):
|
||||
if t == SearchType.Image:
|
||||
hits = await search_future.result()
|
||||
output_directory = constants.web_directory / "images"
|
||||
# Collate results
|
||||
results += image_search.collate_results(
|
||||
hits,
|
||||
image_names=state.model.image_search.image_names,
|
||||
output_directory=output_directory,
|
||||
image_files_url="/static/images",
|
||||
count=results_count or 5,
|
||||
)
|
||||
else:
|
||||
hits, entries = await search_future.result()
|
||||
# Collate results
|
||||
results += text_search.collate_results(hits, entries, results_count or 5)
|
||||
|
||||
# collate and return results
|
||||
with timer("Collating results took", logger):
|
||||
results = text_search.collate_results(hits, entries, results_count)
|
||||
|
||||
elif (t == SearchType.Markdown or t == None) and state.model.markdown_search:
|
||||
# query markdown files
|
||||
with timer("Query took", logger):
|
||||
hits, entries = text_search.query(
|
||||
user_query, state.model.markdown_search, rank_results=r, score_threshold=score_threshold, dedupe=dedupe
|
||||
)
|
||||
|
||||
# collate and return results
|
||||
with timer("Collating results took", logger):
|
||||
results = text_search.collate_results(hits, entries, results_count)
|
||||
|
||||
elif (t == SearchType.Pdf or t == None) and state.model.pdf_search:
|
||||
# query pdf files
|
||||
with timer("Query took", logger):
|
||||
hits, entries = text_search.query(
|
||||
user_query, state.model.pdf_search, rank_results=r, score_threshold=score_threshold, dedupe=dedupe
|
||||
)
|
||||
|
||||
# collate and return results
|
||||
with timer("Collating results took", logger):
|
||||
results = text_search.collate_results(hits, entries, results_count)
|
||||
|
||||
elif (t == SearchType.Github or t == None) and state.model.github_search:
|
||||
# query github embeddings
|
||||
with timer("Query took", logger):
|
||||
hits, entries = text_search.query(
|
||||
user_query, state.model.github_search, rank_results=r, score_threshold=score_threshold, dedupe=dedupe
|
||||
)
|
||||
|
||||
# collate and return results
|
||||
with timer("Collating results took", logger):
|
||||
results = text_search.collate_results(hits, entries, results_count)
|
||||
|
||||
elif (t == SearchType.Ledger or t == None) and state.model.ledger_search:
|
||||
# query transactions
|
||||
with timer("Query took", logger):
|
||||
hits, entries = text_search.query(
|
||||
user_query, state.model.ledger_search, rank_results=r, score_threshold=score_threshold, dedupe=dedupe
|
||||
)
|
||||
|
||||
# collate and return results
|
||||
with timer("Collating results took", logger):
|
||||
results = text_search.collate_results(hits, entries, results_count)
|
||||
|
||||
elif (t == SearchType.Music or t == None) and state.model.music_search:
|
||||
# query music library
|
||||
with timer("Query took", logger):
|
||||
hits, entries = text_search.query(
|
||||
user_query, state.model.music_search, rank_results=r, score_threshold=score_threshold, dedupe=dedupe
|
||||
)
|
||||
|
||||
# collate and return results
|
||||
with timer("Collating results took", logger):
|
||||
results = text_search.collate_results(hits, entries, results_count)
|
||||
|
||||
elif (t == SearchType.Image or t == None) and state.model.image_search:
|
||||
# query images
|
||||
with timer("Query took", logger):
|
||||
hits = image_search.query(
|
||||
user_query, results_count, state.model.image_search, score_threshold=score_threshold
|
||||
)
|
||||
output_directory = constants.web_directory / "images"
|
||||
|
||||
# collate and return results
|
||||
with timer("Collating results took", logger):
|
||||
results = image_search.collate_results(
|
||||
hits,
|
||||
image_names=state.model.image_search.image_names,
|
||||
output_directory=output_directory,
|
||||
image_files_url="/static/images",
|
||||
count=results_count,
|
||||
)
|
||||
|
||||
elif (t in SearchType or t == None) and state.model.plugin_search:
|
||||
# query specified plugin type
|
||||
with timer("Query took", logger):
|
||||
hits, entries = text_search.query(
|
||||
user_query,
|
||||
# Get plugin search model for specified search type, or the first one if none specified
|
||||
state.model.plugin_search.get(t.value) or next(iter(state.model.plugin_search.values())),
|
||||
rank_results=r,
|
||||
score_threshold=score_threshold,
|
||||
dedupe=dedupe,
|
||||
)
|
||||
|
||||
# collate and return results
|
||||
with timer("Collating results took", logger):
|
||||
results = text_search.collate_results(hits, entries, results_count)
|
||||
# Sort results across all content types
|
||||
results.sort(key=lambda x: float(x.score), reverse=True)
|
||||
|
||||
# Cache results
|
||||
state.query_cache[query_cache_key] = results
|
||||
|
@ -260,6 +316,9 @@ def search(
|
|||
]
|
||||
state.previous_query = user_query
|
||||
|
||||
end_time = time.time()
|
||||
logger.debug(f"🔍 Search took: {end_time - start_time:.3f} seconds")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
|
@ -267,7 +326,7 @@ def search(
|
|||
def update(t: Optional[SearchType] = None, force: Optional[bool] = False, client: Optional[str] = None):
|
||||
try:
|
||||
state.search_index_lock.acquire()
|
||||
state.model = configure_search(state.model, state.config, regenerate=force, t=t)
|
||||
state.model = configure_search(state.model, state.config, regenerate=force or False, t=t)
|
||||
state.search_index_lock.release()
|
||||
except ValueError as e:
|
||||
logger.error(e)
|
||||
|
|
|
@ -18,3 +18,7 @@ class BaseFilter(ABC):
|
|||
@abstractmethod
|
||||
def apply(self, query: str, entries: List[Entry]) -> Tuple[str, Set[int]]:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def defilter(self, query: str) -> str:
|
||||
...
|
||||
|
|
|
@ -49,6 +49,12 @@ class DateFilter(BaseFilter):
|
|||
"Check if query contains date filters"
|
||||
return self.extract_date_range(raw_query) is not None
|
||||
|
||||
def defilter(self, query):
|
||||
# remove date range filter from query
|
||||
query = re.sub(rf"\s+{self.date_regex}", " ", query)
|
||||
query = re.sub(r"\s{2,}", " ", query).strip() # remove multiple spaces
|
||||
return query
|
||||
|
||||
def apply(self, query, entries):
|
||||
"Find entries containing any dates that fall within date range specified in query"
|
||||
# extract date range specified in date filter of query
|
||||
|
@ -59,9 +65,7 @@ class DateFilter(BaseFilter):
|
|||
if query_daterange is None:
|
||||
return query, set(range(len(entries)))
|
||||
|
||||
# remove date range filter from query
|
||||
query = re.sub(rf"\s+{self.date_regex}", " ", query)
|
||||
query = re.sub(r"\s{2,}", " ", query).strip() # remove multiple spaces
|
||||
query = self.defilter(query)
|
||||
|
||||
# return results from cache if exists
|
||||
cache_key = tuple(query_daterange)
|
||||
|
|
|
@ -28,6 +28,9 @@ class FileFilter(BaseFilter):
|
|||
def can_filter(self, raw_query):
|
||||
return re.search(self.file_filter_regex, raw_query) is not None
|
||||
|
||||
def defilter(self, query: str) -> str:
|
||||
return re.sub(self.file_filter_regex, "", query).strip()
|
||||
|
||||
def apply(self, query, entries):
|
||||
# Extract file filters from raw query
|
||||
with timer("Extract files_to_search from query", logger):
|
||||
|
@ -44,8 +47,10 @@ class FileFilter(BaseFilter):
|
|||
else:
|
||||
files_to_search += [file]
|
||||
|
||||
# Remove filter terms from original query
|
||||
query = self.defilter(query)
|
||||
|
||||
# Return item from cache if exists
|
||||
query = re.sub(self.file_filter_regex, "", query).strip()
|
||||
cache_key = tuple(files_to_search)
|
||||
if cache_key in self.cache:
|
||||
logger.debug(f"Return file filter results from cache")
|
||||
|
|
|
@ -43,13 +43,16 @@ class WordFilter(BaseFilter):
|
|||
|
||||
return len(required_words) != 0 or len(blocked_words) != 0
|
||||
|
||||
def defilter(self, query: str) -> str:
|
||||
return re.sub(self.blocked_regex, "", re.sub(self.required_regex, "", query)).strip()
|
||||
|
||||
def apply(self, query, entries):
|
||||
"Find entries containing required and not blocked words specified in query"
|
||||
# Separate natural query from required, blocked words filters
|
||||
with timer("Extract required, blocked filters from query", logger):
|
||||
required_words = set([word.lower() for word in re.findall(self.required_regex, query)])
|
||||
blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, query)])
|
||||
query = re.sub(self.blocked_regex, "", re.sub(self.required_regex, "", query)).strip()
|
||||
query = self.defilter(query)
|
||||
|
||||
if len(required_words) == 0 and len(blocked_words) == 0:
|
||||
return query, set(range(len(entries)))
|
||||
|
|
|
@ -143,7 +143,7 @@ def extract_metadata(image_name):
|
|||
return image_processed_metadata
|
||||
|
||||
|
||||
def query(raw_query, count, model: ImageSearchModel, score_threshold: float = -math.inf):
|
||||
async def query(raw_query, count, model: ImageSearchModel, score_threshold: float = -math.inf):
|
||||
# Set query to image content if query is of form file:/path/to/file.png
|
||||
if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file():
|
||||
query_imagepath = resolve_absolute_path(pathlib.Path(raw_query[5:]), strict=True)
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
import logging
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, Type
|
||||
from typing import List, Tuple, Type, Union
|
||||
|
||||
# External Packages
|
||||
import torch
|
||||
|
@ -102,9 +102,10 @@ def compute_embeddings(
|
|||
return corpus_embeddings
|
||||
|
||||
|
||||
def query(
|
||||
async def query(
|
||||
raw_query: str,
|
||||
model: TextSearchModel,
|
||||
question_embedding: Union[torch.Tensor, None] = None,
|
||||
rank_results: bool = False,
|
||||
score_threshold: float = -math.inf,
|
||||
dedupe: bool = True,
|
||||
|
@ -124,9 +125,10 @@ def query(
|
|||
return hits, entries
|
||||
|
||||
# Encode the query using the bi-encoder
|
||||
with timer("Query Encode Time", logger, state.device):
|
||||
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
|
||||
question_embedding = util.normalize_embeddings(question_embedding)
|
||||
if question_embedding is None:
|
||||
with timer("Query Encode Time", logger, state.device):
|
||||
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
|
||||
question_embedding = util.normalize_embeddings(question_embedding)
|
||||
|
||||
# Find relevant entries for the query
|
||||
with timer("Search Time", logger, state.device):
|
||||
|
@ -179,7 +181,7 @@ def setup(
|
|||
previous_entries = (
|
||||
extract_entries(config.compressed_jsonl) if config.compressed_jsonl.exists() and not regenerate else None
|
||||
)
|
||||
entries_with_indices = text_to_jsonl(config).process(previous_entries)
|
||||
entries_with_indices = text_to_jsonl(config).process(previous_entries or [])
|
||||
|
||||
# Extract Updated Entries
|
||||
entries = extract_entries(config.compressed_jsonl)
|
||||
|
|
|
@ -17,6 +17,7 @@ if TYPE_CHECKING:
|
|||
|
||||
|
||||
class SearchType(str, Enum):
|
||||
All = "all"
|
||||
Org = "org"
|
||||
Ledger = "ledger"
|
||||
Music = "music"
|
||||
|
|
|
@ -34,7 +34,7 @@ def test_search_with_invalid_content_type(client):
|
|||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_search_with_valid_content_type(client):
|
||||
for content_type in ["org", "markdown", "ledger", "image", "music", "pdf", "plugin1"]:
|
||||
for content_type in ["all", "org", "markdown", "ledger", "image", "music", "pdf", "plugin1"]:
|
||||
# Act
|
||||
response = client.get(f"/api/search?q=random&t={content_type}")
|
||||
# Assert
|
||||
|
@ -84,7 +84,7 @@ def test_get_configured_types_via_api(client):
|
|||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
assert response.json() == ["org", "image", "plugin1"]
|
||||
assert response.json() == ["all", "org", "image", "plugin1"]
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
|
@ -102,7 +102,7 @@ def test_get_configured_types_with_only_plugin_content_config(content_config):
|
|||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
assert response.json() == ["plugin1"]
|
||||
assert response.json() == ["all", "plugin1"]
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
|
@ -137,7 +137,7 @@ def test_get_configured_types_with_no_content_config():
|
|||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
assert response.json() == []
|
||||
assert response.json() == ["all"]
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
|
|
|
@ -3,6 +3,9 @@ import logging
|
|||
from pathlib import Path
|
||||
from PIL import Image
|
||||
|
||||
# External Packages
|
||||
import pytest
|
||||
|
||||
# Internal Packages
|
||||
from khoj.utils.state import model
|
||||
from khoj.utils.constants import web_directory
|
||||
|
@ -48,7 +51,8 @@ def test_image_metadata(content_config: ContentConfig):
|
|||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_image_search(content_config: ContentConfig, search_config: SearchConfig):
|
||||
@pytest.mark.anyio
|
||||
async def test_image_search(content_config: ContentConfig, search_config: SearchConfig):
|
||||
# Arrange
|
||||
output_directory = resolve_absolute_path(web_directory)
|
||||
model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False)
|
||||
|
@ -60,7 +64,7 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig
|
|||
|
||||
# Act
|
||||
for query, expected_image_name in query_expected_image_pairs:
|
||||
hits = image_search.query(query, count=1, model=model.image_search)
|
||||
hits = await image_search.query(query, count=1, model=model.image_search)
|
||||
|
||||
results = image_search.collate_results(
|
||||
hits,
|
||||
|
@ -83,7 +87,8 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig
|
|||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_image_search_query_truncated(content_config: ContentConfig, search_config: SearchConfig, caplog):
|
||||
@pytest.mark.anyio
|
||||
async def test_image_search_query_truncated(content_config: ContentConfig, search_config: SearchConfig, caplog):
|
||||
# Arrange
|
||||
model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False)
|
||||
max_words_supported = 10
|
||||
|
@ -93,7 +98,7 @@ def test_image_search_query_truncated(content_config: ContentConfig, search_conf
|
|||
# Act
|
||||
try:
|
||||
with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"):
|
||||
image_search.query(query, count=1, model=model.image_search)
|
||||
await image_search.query(query, count=1, model=model.image_search)
|
||||
# Assert
|
||||
except RuntimeError as e:
|
||||
if "The size of tensor a (102) must match the size of tensor b (77)" in str(e):
|
||||
|
@ -102,7 +107,8 @@ def test_image_search_query_truncated(content_config: ContentConfig, search_conf
|
|||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_image_search_by_filepath(content_config: ContentConfig, search_config: SearchConfig, caplog):
|
||||
@pytest.mark.anyio
|
||||
async def test_image_search_by_filepath(content_config: ContentConfig, search_config: SearchConfig, caplog):
|
||||
# Arrange
|
||||
output_directory = resolve_absolute_path(web_directory)
|
||||
model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False)
|
||||
|
@ -113,7 +119,7 @@ def test_image_search_by_filepath(content_config: ContentConfig, search_config:
|
|||
|
||||
# Act
|
||||
with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"):
|
||||
hits = image_search.query(query, count=1, model=model.image_search)
|
||||
hits = await image_search.query(query, count=1, model=model.image_search)
|
||||
|
||||
results = image_search.collate_results(
|
||||
hits,
|
||||
|
|
|
@ -72,13 +72,14 @@ def test_text_content_index_only_updates_on_changes(content_config: ContentConfi
|
|||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_asymmetric_search(content_config: ContentConfig, search_config: SearchConfig):
|
||||
@pytest.mark.anyio
|
||||
async def test_asymmetric_search(content_config: ContentConfig, search_config: SearchConfig):
|
||||
# Arrange
|
||||
model.notes_search = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True)
|
||||
query = "How to git install application?"
|
||||
|
||||
# Act
|
||||
hits, entries = text_search.query(query, model=model.notes_search, rank_results=True)
|
||||
hits, entries = await text_search.query(query, model=model.notes_search, rank_results=True)
|
||||
|
||||
results = text_search.collate_results(hits, entries, count=1)
|
||||
|
||||
|
|
Loading…
Reference in a new issue