Move to a push-first model for retrieving embeddings from local files (#457)

* Initial version: set up a file-push architecture for generating embeddings with Khoj (see the client-side sketch after this list)
* Update unit tests to work with the new application design
* Allow configure_server to be called without regenerating the index; regenerating at startup no longer works because the file-indexing API is not up in time for the server to send it a request
* Use state.host and state.port to configure the indexer URL
* On application startup, load embeddings from configuration files rather than regenerating the corpus from the file system
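
With this change, clients collect file contents locally and push them to the server's new indexing endpoint instead of the server scanning the file system itself. Below is a minimal client-side sketch of such a push, assuming a Khoj server on localhost:42110; the /indexer/batch route, the placeholder "secret" x-api-key header, the optional search_type query parameter, and the payload shape (content type mapped to {file path: file content}) all come from this diff, while the file paths and contents are purely illustrative.

```python
# Minimal sketch: push local file contents to the new /indexer/batch endpoint.
# Assumes a Khoj server at http://localhost:42110 and the placeholder "secret"
# API key used by the indexer route in this commit; paths and contents are made up.
import requests

files_payload = {
    "org": {"/home/user/notes/todo.org": "* practicing piano"},
    "markdown": {"/home/user/notes/ideas.md": "# Notes from client call"},
    "plaintext": {},
    "pdf": {},
}

response = requests.post(
    "http://localhost:42110/indexer/batch",
    json=files_payload,
    headers={"x-api-key": "secret"},
    params={"search_type": "org"},  # optional; regenerate=true can also be passed
)
response.raise_for_status()
print(response.text)  # the route returns "OK" on success
```
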
sabaimran 2023-08-31 12:55:17 -07:00 committed by GitHub
parent 92cbfef7ab
commit 4854258047
23 changed files with 990 additions and 508 deletions

View file

@ -11,87 +11,88 @@ import schedule
from fastapi.staticfiles import StaticFiles
# Internal Packages
from khoj.processor.jsonl.jsonl_to_jsonl import JsonlToJsonl
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
from khoj.processor.pdf.pdf_to_jsonl import PdfToJsonl
from khoj.processor.github.github_to_jsonl import GithubToJsonl
from khoj.processor.notion.notion_to_jsonl import NotionToJsonl
from khoj.processor.plaintext.plaintext_to_jsonl import PlaintextToJsonl
from khoj.search_type import image_search, text_search
from khoj.utils import constants, state
from khoj.utils.config import (
ContentIndex,
SearchType,
SearchModels,
ProcessorConfigModel,
ConversationProcessorConfigModel,
)
from khoj.utils.helpers import LRU, resolve_absolute_path, merge_dicts
from khoj.utils.rawconfig import FullConfig, ProcessorConfig, SearchConfig, ContentConfig, ConversationProcessorConfig
from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.word_filter import WordFilter
from khoj.search_filter.file_filter import FileFilter
from khoj.utils.helpers import resolve_absolute_path, merge_dicts
from khoj.utils.fs_syncer import collect_files
from khoj.utils.rawconfig import FullConfig, ProcessorConfig, SearchConfig, ConversationProcessorConfig
from khoj.routers.indexer import configure_content, load_content
logger = logging.getLogger(__name__)
def initialize_server(config: Optional[FullConfig], regenerate: bool, required=False):
def initialize_server(config: Optional[FullConfig], required=False):
if config is None and required:
logger.error(
f"🚨 Exiting as Khoj is not configured.\nConfigure it via http://localhost:42110/config or by editing {state.config_file}."
f"🚨 Exiting as Khoj is not configured.\nConfigure it via http://{state.host}:{state.port}/config or by editing {state.config_file}."
)
sys.exit(1)
elif config is None:
logger.warning(
f"🚨 Khoj is not configured.\nConfigure it via http://localhost:42110/config, plugins or by editing {state.config_file}."
f"🚨 Khoj is not configured.\nConfigure it via http://{state.host}:{state.port}/config, plugins or by editing {state.config_file}."
)
return None
try:
configure_server(config, regenerate)
configure_server(config, init=True)
except Exception as e:
logger.error(f"🚨 Failed to configure server on app load: {e}", exc_info=True)
def configure_server(config: FullConfig, regenerate: bool, search_type: Optional[SearchType] = None):
def configure_server(
config: FullConfig, regenerate: bool = False, search_type: Optional[SearchType] = None, init=False
):
# Update Config
state.config = config
# Initialize Processor from Config
try:
state.config_lock.acquire()
state.processor_config = configure_processor(state.config.processor)
except Exception as e:
logger.error(f"🚨 Failed to configure processor", exc_info=True)
raise e
finally:
state.config_lock.release()
# Initialize Search Models from Config
# Initialize Search Models from Config and initialize content
try:
state.config_lock.acquire()
state.SearchType = configure_search_types(state.config)
state.search_models = configure_search(state.search_models, state.config.search_type)
initialize_content(regenerate, search_type, init)
except Exception as e:
logger.error(f"🚨 Failed to configure search models", exc_info=True)
raise e
finally:
state.config_lock.release()
def initialize_content(regenerate: bool, search_type: Optional[SearchType] = None, init=False):
# Initialize Content from Config
if state.search_models:
try:
state.config_lock.acquire()
if init:
logger.info("📬 Initializing content index...")
state.content_index = load_content(state.config.content_type, state.content_index, state.search_models)
else:
logger.info("📬 Updating content index...")
all_files = collect_files(state.config.content_type)
state.content_index = configure_content(
state.content_index, state.config.content_type, state.search_models, regenerate, search_type
state.content_index,
state.config.content_type,
all_files,
state.search_models,
regenerate,
search_type,
)
except Exception as e:
logger.error(f"🚨 Failed to index content", exc_info=True)
raise e
finally:
state.config_lock.release()
def configure_routes(app):
@ -99,10 +100,12 @@ def configure_routes(app):
from khoj.routers.api import api
from khoj.routers.api_beta import api_beta
from khoj.routers.web_client import web_client
from khoj.routers.indexer import indexer
app.mount("/static", StaticFiles(directory=constants.web_directory), name="static")
app.include_router(api, prefix="/api")
app.include_router(api_beta, prefix="/api/beta")
app.include_router(indexer, prefix="/indexer")
app.include_router(web_client)
@ -111,15 +114,14 @@ if not state.demo:
@schedule.repeat(schedule.every(61).minutes)
def update_search_index():
try:
state.config_lock.acquire()
logger.info("📬 Updating content index via Scheduler")
all_files = collect_files(state.config.content_type)
state.content_index = configure_content(
state.content_index, state.config.content_type, state.search_models, regenerate=False
state.content_index, state.config.content_type, all_files, state.search_models
)
logger.info("📬 Content index updated via Scheduler")
except Exception as e:
logger.error(f"🚨 Error updating content index via Scheduler: {e}", exc_info=True)
finally:
state.config_lock.release()
def configure_search_types(config: FullConfig):
@ -154,142 +156,6 @@ def configure_search(search_models: SearchModels, search_config: Optional[Search
return search_models
def configure_content(
content_index: Optional[ContentIndex],
content_config: Optional[ContentConfig],
search_models: SearchModels,
regenerate: bool,
t: Optional[state.SearchType] = None,
) -> Optional[ContentIndex]:
# Run Validation Checks
if content_config is None:
logger.warning("🚨 No Content configuration available.")
return None
if content_index is None:
content_index = ContentIndex()
try:
# Initialize Org Notes Search
if (t == None or t.value == state.SearchType.Org.value) and content_config.org and search_models.text_search:
logger.info("🦄 Setting up search for orgmode notes")
# Extract Entries, Generate Notes Embeddings
content_index.org = text_search.setup(
OrgToJsonl,
content_config.org,
search_models.text_search.bi_encoder,
regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()],
)
# Initialize Markdown Search
if (
(t == None or t.value == state.SearchType.Markdown.value)
and content_config.markdown
and search_models.text_search
):
logger.info("💎 Setting up search for markdown notes")
# Extract Entries, Generate Markdown Embeddings
content_index.markdown = text_search.setup(
MarkdownToJsonl,
content_config.markdown,
search_models.text_search.bi_encoder,
regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()],
)
# Initialize PDF Search
if (t == None or t.value == state.SearchType.Pdf.value) and content_config.pdf and search_models.text_search:
logger.info("🖨️ Setting up search for pdf")
# Extract Entries, Generate PDF Embeddings
content_index.pdf = text_search.setup(
PdfToJsonl,
content_config.pdf,
search_models.text_search.bi_encoder,
regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()],
)
# Initialize Plaintext Search
if (
(t == None or t.value == state.SearchType.Plaintext.value)
and content_config.plaintext
and search_models.text_search
):
logger.info("📄 Setting up search for plaintext")
# Extract Entries, Generate Plaintext Embeddings
content_index.plaintext = text_search.setup(
PlaintextToJsonl,
content_config.plaintext,
search_models.text_search.bi_encoder,
regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()],
)
# Initialize Image Search
if (
(t == None or t.value == state.SearchType.Image.value)
and content_config.image
and search_models.image_search
):
logger.info("🌄 Setting up search for images")
# Extract Entries, Generate Image Embeddings
content_index.image = image_search.setup(
content_config.image, search_models.image_search.image_encoder, regenerate=regenerate
)
if (
(t == None or t.value == state.SearchType.Github.value)
and content_config.github
and search_models.text_search
):
logger.info("🐙 Setting up search for github")
# Extract Entries, Generate Github Embeddings
content_index.github = text_search.setup(
GithubToJsonl,
content_config.github,
search_models.text_search.bi_encoder,
regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()],
)
# Initialize Notion Search
if (
(t == None or t.value in state.SearchType.Notion.value)
and content_config.notion
and search_models.text_search
):
logger.info("🔌 Setting up search for notion")
content_index.notion = text_search.setup(
NotionToJsonl,
content_config.notion,
search_models.text_search.bi_encoder,
regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()],
)
# Initialize External Plugin Search
if (t == None or t in state.SearchType) and content_config.plugins and search_models.text_search:
logger.info("🔌 Setting up search for plugins")
content_index.plugins = {}
for plugin_type, plugin_config in content_config.plugins.items():
content_index.plugins[plugin_type] = text_search.setup(
JsonlToJsonl,
plugin_config,
search_models.text_search.bi_encoder,
regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()],
)
except Exception as e:
logger.error(f"🚨 Failed to setup search: {e}", exc_info=True)
raise e
# Invalidate Query Cache
state.query_cache = LRU()
return content_index
def configure_processor(
processor_config: Optional[ProcessorConfig], state_processor_config: Optional[ProcessorConfigModel] = None
):

View file

@ -75,8 +75,8 @@ def run():
poll_task_scheduler()
# Start Server
initialize_server(args.config, args.regenerate, required=False)
configure_routes(app)
initialize_server(args.config, required=False)
start_server(app, host=args.host, port=args.port, socket=args.socket)
else:
from PySide6 import QtWidgets
@ -99,7 +99,7 @@ def run():
tray.show()
# Setup Server
initialize_server(args.config, args.regenerate, required=False)
initialize_server(args.config, required=False)
configure_routes(app)
server = ServerThread(start_server_func=lambda: start_server(app, host=args.host, port=args.port), parent=gui)

View file

@ -37,7 +37,7 @@ class GithubToJsonl(TextToJsonl):
else:
return
def process(self, previous_entries=[]):
def process(self, previous_entries=[], files=None):
if self.config.pat_token is None or self.config.pat_token == "":
logger.error(f"Github PAT token is not set. Skipping github content")
raise ValueError("Github PAT token is not set. Skipping github content")

View file

@ -16,7 +16,7 @@ logger = logging.getLogger(__name__)
class JsonlToJsonl(TextToJsonl):
# Define Functions
def process(self, previous_entries=[]):
def process(self, previous_entries=[], files: dict[str, str] = {}):
# Extract required fields from config
input_jsonl_files, input_jsonl_filter, output_file = (
self.config.input_files,

View file

@ -1,5 +1,4 @@
# Standard Packages
import glob
import logging
import re
import urllib3
@ -8,7 +7,7 @@ from typing import List
# Internal Packages
from khoj.processor.text_to_jsonl import TextToJsonl
from khoj.utils.helpers import get_absolute_path, is_none_or_empty, timer
from khoj.utils.helpers import timer
from khoj.utils.constants import empty_escape_sequences
from khoj.utils.jsonl import compress_jsonl_data
from khoj.utils.rawconfig import Entry, TextContentConfig
@ -23,26 +22,14 @@ class MarkdownToJsonl(TextToJsonl):
self.config = config
# Define Functions
def process(self, previous_entries=[]):
def process(self, previous_entries=[], files=None):
# Extract required fields from config
markdown_files, markdown_file_filter, output_file = (
self.config.input_files,
self.config.input_filter,
self.config.compressed_jsonl,
)
# Input Validation
if is_none_or_empty(markdown_files) and is_none_or_empty(markdown_file_filter):
print("At least one of markdown-files or markdown-file-filter is required to be specified")
exit(1)
# Get Markdown Files to Process
markdown_files = MarkdownToJsonl.get_markdown_files(markdown_files, markdown_file_filter)
output_file = self.config.compressed_jsonl
# Extract Entries from specified Markdown files
with timer("Parse entries from Markdown files into dictionaries", logger):
current_entries = MarkdownToJsonl.convert_markdown_entries_to_maps(
*MarkdownToJsonl.extract_markdown_entries(markdown_files)
*MarkdownToJsonl.extract_markdown_entries(files)
)
# Split entries by max tokens supported by model
@ -65,36 +52,6 @@ class MarkdownToJsonl(TextToJsonl):
return entries_with_ids
@staticmethod
def get_markdown_files(markdown_files=None, markdown_file_filters=None):
"Get Markdown files to process"
absolute_markdown_files, filtered_markdown_files = set(), set()
if markdown_files:
absolute_markdown_files = {get_absolute_path(markdown_file) for markdown_file in markdown_files}
if markdown_file_filters:
filtered_markdown_files = {
filtered_file
for markdown_file_filter in markdown_file_filters
for filtered_file in glob.glob(get_absolute_path(markdown_file_filter), recursive=True)
}
all_markdown_files = sorted(absolute_markdown_files | filtered_markdown_files)
files_with_non_markdown_extensions = {
md_file
for md_file in all_markdown_files
if not md_file.endswith(".md") and not md_file.endswith(".markdown")
}
if any(files_with_non_markdown_extensions):
logger.warning(
f"[Warning] There maybe non markdown-mode files in the input set: {files_with_non_markdown_extensions}"
)
logger.debug(f"Processing files: {all_markdown_files}")
return all_markdown_files
@staticmethod
def extract_markdown_entries(markdown_files):
"Extract entries by heading from specified Markdown files"
@ -104,9 +61,8 @@ class MarkdownToJsonl(TextToJsonl):
entries = []
entry_to_file_map = []
for markdown_file in markdown_files:
with open(markdown_file, "r", encoding="utf8") as f:
try:
markdown_content = f.read()
markdown_content = markdown_files[markdown_file]
entries, entry_to_file_map = MarkdownToJsonl.process_single_markdown_file(
markdown_content, markdown_file, entries, entry_to_file_map
)

View file

@ -80,7 +80,7 @@ class NotionToJsonl(TextToJsonl):
self.body_params = {"page_size": 100}
def process(self, previous_entries=[]):
def process(self, previous_entries=[], files=None):
current_entries = []
# Get all pages

View file

@ -1,13 +1,12 @@
# Standard Packages
import glob
import logging
from pathlib import Path
from typing import Iterable, List
from typing import Iterable, List, Tuple
# Internal Packages
from khoj.processor.org_mode import orgnode
from khoj.processor.text_to_jsonl import TextToJsonl
from khoj.utils.helpers import get_absolute_path, is_none_or_empty, timer
from khoj.utils.helpers import timer
from khoj.utils.jsonl import compress_jsonl_data
from khoj.utils.rawconfig import Entry, TextContentConfig
from khoj.utils import state
@ -22,27 +21,14 @@ class OrgToJsonl(TextToJsonl):
self.config = config
# Define Functions
def process(self, previous_entries: List[Entry] = []):
def process(self, previous_entries: List[Entry] = [], files: dict[str, str] = None) -> List[Tuple[int, Entry]]:
# Extract required fields from config
org_files, org_file_filter, output_file = (
self.config.input_files,
self.config.input_filter,
self.config.compressed_jsonl,
)
output_file = self.config.compressed_jsonl
index_heading_entries = self.config.index_heading_entries
# Input Validation
if is_none_or_empty(org_files) and is_none_or_empty(org_file_filter):
print("At least one of org-files or org-file-filter is required to be specified")
exit(1)
# Get Org Files to Process
with timer("Get org files to process", logger):
org_files = OrgToJsonl.get_org_files(org_files, org_file_filter)
# Extract Entries from specified Org files
with timer("Parse entries from org files into OrgNode objects", logger):
entry_nodes, file_to_entries = self.extract_org_entries(org_files)
entry_nodes, file_to_entries = self.extract_org_entries(files)
with timer("Convert OrgNodes into list of entries", logger):
current_entries = self.convert_org_nodes_to_entries(entry_nodes, file_to_entries, index_heading_entries)
@ -67,36 +53,15 @@ class OrgToJsonl(TextToJsonl):
return entries_with_ids
@staticmethod
def get_org_files(org_files=None, org_file_filters=None):
"Get Org files to process"
absolute_org_files, filtered_org_files = set(), set()
if org_files:
absolute_org_files = {get_absolute_path(org_file) for org_file in org_files}
if org_file_filters:
filtered_org_files = {
filtered_file
for org_file_filter in org_file_filters
for filtered_file in glob.glob(get_absolute_path(org_file_filter), recursive=True)
}
all_org_files = sorted(absolute_org_files | filtered_org_files)
files_with_non_org_extensions = {org_file for org_file in all_org_files if not org_file.endswith(".org")}
if any(files_with_non_org_extensions):
logger.warning(f"There maybe non org-mode files in the input set: {files_with_non_org_extensions}")
logger.debug(f"Processing files: {all_org_files}")
return all_org_files
@staticmethod
def extract_org_entries(org_files):
def extract_org_entries(org_files: dict[str, str]):
"Extract entries from specified Org files"
entries = []
entry_to_file_map = []
entry_to_file_map: List[Tuple[orgnode.Orgnode, str]] = []
for org_file in org_files:
filename = org_file
file = org_files[org_file]
try:
org_file_entries = orgnode.makelist_with_filepath(str(org_file))
org_file_entries = orgnode.makelist(file, filename)
entry_to_file_map += zip(org_file_entries, [org_file] * len(org_file_entries))
entries.extend(org_file_entries)
except Exception as e:
@ -109,7 +74,7 @@ class OrgToJsonl(TextToJsonl):
def process_single_org_file(org_content: str, org_file: str, entries: List, entry_to_file_map: List):
# Process single org file. The org parser assumes that the file is a single org file and reads it from a buffer. We'll split the raw content of this file by new line to mimic the same behavior.
try:
org_file_entries = orgnode.makelist(org_content.split("\n"), org_file)
org_file_entries = orgnode.makelist(org_content, org_file)
entry_to_file_map += zip(org_file_entries, [org_file] * len(org_file_entries))
entries.extend(org_file_entries)
return entries, entry_to_file_map

View file

@ -65,6 +65,9 @@ def makelist(file, filename):
"""
ctr = 0
if type(file) == str:
f = file.split("\n")
else:
f = file
todos = {
@ -199,7 +202,8 @@ def makelist(file, filename):
# if we are in a heading
if heading:
# add the line to the bodytext
bodytext += line
bodytext += line.rstrip() + "\n\n" if line.strip() else ""
# bodytext += line + "\n" if line.strip() else "\n"
# else we are in the pre heading portion of the file
elif line.strip():
# so add the line to the introtext

View file

@ -1,7 +1,6 @@
# Standard Packages
import glob
import os
import logging
from pathlib import Path
from typing import List
# External Packages
@ -9,7 +8,7 @@ from langchain.document_loaders import PyPDFLoader
# Internal Packages
from khoj.processor.text_to_jsonl import TextToJsonl
from khoj.utils.helpers import get_absolute_path, is_none_or_empty, timer
from khoj.utils.helpers import timer
from khoj.utils.jsonl import compress_jsonl_data
from khoj.utils.rawconfig import Entry
@ -19,25 +18,13 @@ logger = logging.getLogger(__name__)
class PdfToJsonl(TextToJsonl):
# Define Functions
def process(self, previous_entries=[]):
def process(self, previous_entries=[], files=dict[str, str]):
# Extract required fields from config
pdf_files, pdf_file_filter, output_file = (
self.config.input_files,
self.config.input_filter,
self.config.compressed_jsonl,
)
# Input Validation
if is_none_or_empty(pdf_files) and is_none_or_empty(pdf_file_filter):
print("At least one of pdf-files or pdf-file-filter is required to be specified")
exit(1)
# Get Pdf Files to Process
pdf_files = PdfToJsonl.get_pdf_files(pdf_files, pdf_file_filter)
output_file = self.config.compressed_jsonl
# Extract Entries from specified Pdf files
with timer("Parse entries from PDF files into dictionaries", logger):
current_entries = PdfToJsonl.convert_pdf_entries_to_maps(*PdfToJsonl.extract_pdf_entries(pdf_files))
current_entries = PdfToJsonl.convert_pdf_entries_to_maps(*PdfToJsonl.extract_pdf_entries(files))
# Split entries by max tokens supported by model
with timer("Split entries by max token size supported by model", logger):
@ -59,32 +46,6 @@ class PdfToJsonl(TextToJsonl):
return entries_with_ids
@staticmethod
def get_pdf_files(pdf_files=None, pdf_file_filters=None):
"Get PDF files to process"
absolute_pdf_files, filtered_pdf_files = set(), set()
if pdf_files:
absolute_pdf_files = {get_absolute_path(pdf_file) for pdf_file in pdf_files}
if pdf_file_filters:
filtered_pdf_files = {
filtered_file
for pdf_file_filter in pdf_file_filters
for filtered_file in glob.glob(get_absolute_path(pdf_file_filter), recursive=True)
}
all_pdf_files = sorted(absolute_pdf_files | filtered_pdf_files)
files_with_non_pdf_extensions = {pdf_file for pdf_file in all_pdf_files if not pdf_file.endswith(".pdf")}
if any(files_with_non_pdf_extensions):
logger.warning(
f"[Warning] There maybe non pdf-mode files in the input set: {files_with_non_pdf_extensions}"
)
logger.debug(f"Processing files: {all_pdf_files}")
return all_pdf_files
@staticmethod
def extract_pdf_entries(pdf_files):
"""Extract entries by page from specified PDF files"""
@ -93,13 +54,19 @@ class PdfToJsonl(TextToJsonl):
entry_to_location_map = []
for pdf_file in pdf_files:
try:
loader = PyPDFLoader(pdf_file)
# Write the PDF file to a temporary file, as it is stored in byte format in the pdf_file object and the PyPDFLoader expects a file path
with open(f"{pdf_file}.pdf", "wb") as f:
f.write(pdf_files[pdf_file])
loader = PyPDFLoader(f"{pdf_file}.pdf")
pdf_entries_per_file = [page.page_content for page in loader.load()]
entry_to_location_map += zip(pdf_entries_per_file, [pdf_file] * len(pdf_entries_per_file))
entries.extend(pdf_entries_per_file)
except Exception as e:
logger.warning(f"Unable to process file: {pdf_file}. This file will not be indexed.")
logger.warning(e)
finally:
if os.path.exists(f"{pdf_file}.pdf"):
os.remove(f"{pdf_file}.pdf")
return entries, dict(entry_to_location_map)
@ -108,9 +75,9 @@ class PdfToJsonl(TextToJsonl):
"Convert each PDF entries into a dictionary"
entries = []
for parsed_entry in parsed_entries:
entry_filename = Path(entry_to_file_map[parsed_entry])
entry_filename = entry_to_file_map[parsed_entry]
# Append base filename to compiled entry for context to model
heading = f"{entry_filename.stem}\n"
heading = f"{entry_filename}\n"
compiled_entry = f"{heading}{parsed_entry}"
entries.append(
Entry(

View file

@ -1,13 +1,12 @@
# Standard Packages
import glob
import logging
from pathlib import Path
from typing import List
from typing import List, Tuple
# Internal Packages
from khoj.processor.text_to_jsonl import TextToJsonl
from khoj.utils.helpers import get_absolute_path, timer
from khoj.utils.jsonl import load_jsonl, compress_jsonl_data
from khoj.utils.helpers import timer
from khoj.utils.jsonl import compress_jsonl_data
from khoj.utils.rawconfig import Entry
@ -16,22 +15,12 @@ logger = logging.getLogger(__name__)
class PlaintextToJsonl(TextToJsonl):
# Define Functions
def process(self, previous_entries=[]):
# Extract required fields from config
input_files, input_filter, output_file = (
self.config.input_files,
self.config.input_filter,
self.config.compressed_jsonl,
)
# Get Plaintext Input Files to Process
all_input_plaintext_files = PlaintextToJsonl.get_plaintext_files(input_files, input_filter)
def process(self, previous_entries: List[Entry] = [], files: dict[str, str] = None) -> List[Tuple[int, Entry]]:
output_file = self.config.compressed_jsonl
# Extract Entries from specified plaintext files
with timer("Parse entries from plaintext files", logger):
current_entries = PlaintextToJsonl.convert_plaintext_entries_to_maps(
PlaintextToJsonl.extract_plaintext_entries(all_input_plaintext_files)
)
current_entries = PlaintextToJsonl.convert_plaintext_entries_to_maps(files)
# Split entries by max tokens supported by model
with timer("Split entries by max token size supported by model", logger):
@ -53,67 +42,11 @@ class PlaintextToJsonl(TextToJsonl):
return entries_with_ids
@staticmethod
def get_plaintext_files(plaintext_files=None, plaintext_file_filters=None):
"Get all files to process"
absolute_plaintext_files, filtered_plaintext_files = set(), set()
if plaintext_files:
absolute_plaintext_files = {get_absolute_path(jsonl_file) for jsonl_file in plaintext_files}
if plaintext_file_filters:
filtered_plaintext_files = {
filtered_file
for jsonl_file_filter in plaintext_file_filters
for filtered_file in glob.glob(get_absolute_path(jsonl_file_filter), recursive=True)
}
all_target_files = sorted(absolute_plaintext_files | filtered_plaintext_files)
files_with_no_plaintext_extensions = {
target_files for target_files in all_target_files if not PlaintextToJsonl.is_plaintextfile(target_files)
}
if any(files_with_no_plaintext_extensions):
logger.warn(f"Skipping unsupported files from plaintext indexing: {files_with_no_plaintext_extensions}")
all_target_files = list(set(all_target_files) - files_with_no_plaintext_extensions)
logger.debug(f"Processing files: {all_target_files}")
return all_target_files
@staticmethod
def is_plaintextfile(file: str):
"Check if file is plaintext file"
return file.endswith(("txt", "md", "markdown", "org", "mbox", "rst", "html", "htm", "xml"))
@staticmethod
def extract_plaintext_entries(plaintext_files: List[str]):
"Extract entries from specified plaintext files"
entry_to_file_map = []
for plaintext_file in plaintext_files:
with open(plaintext_file, "r") as f:
try:
plaintext_content = f.read()
if plaintext_file.endswith(("html", "htm", "xml")):
plaintext_content = PlaintextToJsonl.extract_html_content(plaintext_content)
entry_to_file_map.append((plaintext_content, plaintext_file))
except Exception as e:
logger.error(f"Error processing file: {plaintext_file} - {e}", exc_info=True)
return dict(entry_to_file_map)
@staticmethod
def extract_html_content(html_content: str):
"Extract content from HTML"
from bs4 import BeautifulSoup
soup = BeautifulSoup(html_content, "html.parser")
return soup.get_text(strip=True, separator="\n")
@staticmethod
def convert_plaintext_entries_to_maps(entry_to_file_map: dict) -> List[Entry]:
"Convert each plaintext entries into a dictionary"
entries = []
for entry, file in entry_to_file_map.items():
for file, entry in entry_to_file_map.items():
entries.append(
Entry(
raw=entry,

View file

@ -17,7 +17,7 @@ class TextToJsonl(ABC):
self.config = config
@abstractmethod
def process(self, previous_entries: List[Entry] = []) -> List[Tuple[int, Entry]]:
def process(self, previous_entries: List[Entry] = [], files: dict[str, str] = None) -> List[Tuple[int, Entry]]:
...
@staticmethod

View file

@ -608,7 +608,7 @@ def update(
logger.warning(error_msg)
raise HTTPException(status_code=500, detail=error_msg)
try:
configure_server(state.config, regenerate=force or False, search_type=t)
configure_server(state.config)
except Exception as e:
error_msg = f"🚨 Failed to update server via API: {e}"
logger.error(error_msg, exc_info=True)
@ -765,7 +765,7 @@ async def extract_references_and_questions(
inferred_queries: List[str] = []
if state.content_index is None:
logger.warn(
logger.warning(
"No content index loaded, so cannot extract references from knowledge base. Please configure your data sources and update the index to chat with your notes."
)
return compiled_references, inferred_queries

src/khoj/routers/indexer.py (new file, 285 lines)
View file

@ -0,0 +1,285 @@
# Standard Packages
import logging
from typing import Optional, Union
# External Packages
from fastapi import APIRouter, HTTPException, Header, Request, Body, Response
from pydantic import BaseModel
# Internal Packages
from khoj.utils import state
from khoj.processor.jsonl.jsonl_to_jsonl import JsonlToJsonl
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
from khoj.processor.pdf.pdf_to_jsonl import PdfToJsonl
from khoj.processor.github.github_to_jsonl import GithubToJsonl
from khoj.processor.notion.notion_to_jsonl import NotionToJsonl
from khoj.processor.plaintext.plaintext_to_jsonl import PlaintextToJsonl
from khoj.utils.rawconfig import ContentConfig
from khoj.search_type import text_search, image_search
from khoj.utils.config import SearchModels
from khoj.utils.helpers import LRU
from khoj.utils.rawconfig import (
ContentConfig,
)
from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.word_filter import WordFilter
from khoj.search_filter.file_filter import FileFilter
from khoj.utils.config import (
ContentIndex,
SearchModels,
)
logger = logging.getLogger(__name__)
indexer = APIRouter()
class IndexBatchRequest(BaseModel):
org: Optional[dict[str, str]]
pdf: Optional[dict[str, str]]
plaintext: Optional[dict[str, str]]
markdown: Optional[dict[str, str]]
@indexer.post("/batch")
async def index_batch(
request: Request,
x_api_key: str = Header(None),
regenerate: bool = False,
search_type: Optional[Union[state.SearchType, str]] = None,
):
if x_api_key != "secret":
raise HTTPException(status_code=401, detail="Invalid API Key")
state.config_lock.acquire()
try:
logger.info(f"Received batch indexing request")
index_batch_request_acc = ""
async for chunk in request.stream():
index_batch_request_acc += chunk.decode()
index_batch_request = IndexBatchRequest.parse_raw(index_batch_request_acc)
logger.info(f"Received batch indexing request size: {len(index_batch_request.dict())}")
# Extract required fields from config
state.content_index = configure_content(
state.content_index,
state.config.content_type,
index_batch_request.dict(),
state.search_models,
regenerate=regenerate,
t=search_type,
)
except Exception as e:
logger.error(f"Failed to process batch indexing request: {e}")
finally:
state.config_lock.release()
return Response(content="OK", status_code=200)
def configure_content(
content_index: Optional[ContentIndex],
content_config: Optional[ContentConfig],
files: Optional[dict[str, dict[str, str]]],
search_models: SearchModels,
regenerate: bool = False,
t: Optional[Union[state.SearchType, str]] = None,
) -> Optional[ContentIndex]:
# Run Validation Checks
if content_config is None:
logger.warning("🚨 No Content configuration available.")
return None
if content_index is None:
content_index = ContentIndex()
if t in [type.value for type in state.SearchType]:
t = state.SearchType(t).value
assert type(t) == str or t == None, f"Invalid search type: {t}"
if files is None:
logger.warning(f"🚨 No files to process for {t} search.")
return None
try:
# Initialize Org Notes Search
if (
(t == None or t == state.SearchType.Org.value)
and content_config.org
and search_models.text_search
and files["org"]
):
logger.info("🦄 Setting up search for orgmode notes")
# Extract Entries, Generate Notes Embeddings
content_index.org = text_search.setup(
OrgToJsonl,
files.get("org"),
content_config.org,
search_models.text_search.bi_encoder,
regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()],
)
# Initialize Markdown Search
if (
(t == None or t == state.SearchType.Markdown.value)
and content_config.markdown
and search_models.text_search
and files["markdown"]
):
logger.info("💎 Setting up search for markdown notes")
# Extract Entries, Generate Markdown Embeddings
content_index.markdown = text_search.setup(
MarkdownToJsonl,
files.get("markdown"),
content_config.markdown,
search_models.text_search.bi_encoder,
regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()],
)
# Initialize PDF Search
if (
(t == None or t == state.SearchType.Pdf.value)
and content_config.pdf
and search_models.text_search
and files["pdf"]
):
logger.info("🖨️ Setting up search for pdf")
# Extract Entries, Generate PDF Embeddings
content_index.pdf = text_search.setup(
PdfToJsonl,
files.get("pdf"),
content_config.pdf,
search_models.text_search.bi_encoder,
regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()],
)
# Initialize Plaintext Search
if (
(t == None or t == state.SearchType.Plaintext.value)
and content_config.plaintext
and search_models.text_search
and files["plaintext"]
):
logger.info("📄 Setting up search for plaintext")
# Extract Entries, Generate Plaintext Embeddings
content_index.plaintext = text_search.setup(
PlaintextToJsonl,
files.get("plaintext"),
content_config.plaintext,
search_models.text_search.bi_encoder,
regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()],
)
# Initialize Image Search
if (t == None or t == state.SearchType.Image.value) and content_config.image and search_models.image_search:
logger.info("🌄 Setting up search for images")
# Extract Entries, Generate Image Embeddings
content_index.image = image_search.setup(
content_config.image, search_models.image_search.image_encoder, regenerate=regenerate
)
if (t == None or t == state.SearchType.Github.value) and content_config.github and search_models.text_search:
logger.info("🐙 Setting up search for github")
# Extract Entries, Generate Github Embeddings
content_index.github = text_search.setup(
GithubToJsonl,
None,
content_config.github,
search_models.text_search.bi_encoder,
regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()],
)
# Initialize Notion Search
if (t == None or t in state.SearchType.Notion.value) and content_config.notion and search_models.text_search:
logger.info("🔌 Setting up search for notion")
content_index.notion = text_search.setup(
NotionToJsonl,
None,
content_config.notion,
search_models.text_search.bi_encoder,
regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()],
)
# Initialize External Plugin Search
if (t == None or t in state.SearchType) and content_config.plugins and search_models.text_search:
logger.info("🔌 Setting up search for plugins")
content_index.plugins = {}
for plugin_type, plugin_config in content_config.plugins.items():
content_index.plugins[plugin_type] = text_search.setup(
JsonlToJsonl,
None,
plugin_config,
search_models.text_search.bi_encoder,
regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()],
)
except Exception as e:
logger.error(f"🚨 Failed to setup search: {e}", exc_info=True)
raise e
# Invalidate Query Cache
state.query_cache = LRU()
return content_index
def load_content(
content_config: Optional[ContentConfig],
content_index: Optional[ContentIndex],
search_models: SearchModels,
):
logger.info(f"Loading content from existing embeddings...")
if content_config is None:
logger.warning("🚨 No Content configuration available.")
return None
if content_index is None:
content_index = ContentIndex()
if content_config.org:
logger.info("🦄 Loading orgmode notes")
content_index.org = text_search.load(content_config.org, filters=[DateFilter(), WordFilter(), FileFilter()])
if content_config.markdown:
logger.info("💎 Loading markdown notes")
content_index.markdown = text_search.load(
content_config.markdown, filters=[DateFilter(), WordFilter(), FileFilter()]
)
if content_config.pdf:
logger.info("🖨️ Loading pdf")
content_index.pdf = text_search.load(content_config.pdf, filters=[DateFilter(), WordFilter(), FileFilter()])
if content_config.plaintext:
logger.info("📄 Loading plaintext")
content_index.plaintext = text_search.load(
content_config.plaintext, filters=[DateFilter(), WordFilter(), FileFilter()]
)
if content_config.image:
logger.info("🌄 Loading images")
content_index.image = image_search.setup(
content_config.image, search_models.image_search.image_encoder, regenerate=False
)
if content_config.github:
logger.info("🐙 Loading github")
content_index.github = text_search.load(
content_config.github, filters=[DateFilter(), WordFilter(), FileFilter()]
)
if content_config.notion:
logger.info("🔌 Loading notion")
content_index.notion = text_search.load(
content_config.notion, filters=[DateFilter(), WordFilter(), FileFilter()]
)
if content_config.plugins:
logger.info("🔌 Loading plugins")
content_index.plugins = {}
for plugin_type, plugin_config in content_config.plugins.items():
content_index.plugins[plugin_type] = text_search.load(
plugin_config, filters=[DateFilter(), WordFilter(), FileFilter()]
)
state.query_cache = LRU()
return content_index

View file

@ -104,6 +104,18 @@ def compute_embeddings(
return corpus_embeddings
def load_embeddings(
embeddings_file: Path,
):
"Load pre-computed embeddings from file if exists and update them if required"
if embeddings_file.exists():
corpus_embeddings: torch.Tensor = torch.load(get_absolute_path(embeddings_file), map_location=state.device)
logger.debug(f"Loaded {len(corpus_embeddings)} text embeddings from {embeddings_file}")
return util.normalize_embeddings(corpus_embeddings)
return None
async def query(
raw_query: str,
search_model: TextSearchModel,
@ -174,6 +186,7 @@ def collate_results(hits, entries: List[Entry], count=5) -> List[SearchResponse]
def setup(
text_to_jsonl: Type[TextToJsonl],
files: dict[str, str],
config: TextConfigBase,
bi_encoder: BaseEncoder,
regenerate: bool,
@ -185,7 +198,7 @@ def setup(
previous_entries = []
if config.compressed_jsonl.exists() and not regenerate:
previous_entries = extract_entries(config.compressed_jsonl)
entries_with_indices = text_to_jsonl(config).process(previous_entries)
entries_with_indices = text_to_jsonl(config).process(previous_entries=previous_entries, files=files)
# Extract Updated Entries
entries = extract_entries(config.compressed_jsonl)
@ -205,6 +218,24 @@ def setup(
return TextContent(entries, corpus_embeddings, filters)
def load(
config: TextConfigBase,
filters: List[BaseFilter] = [],
) -> TextContent:
# Map notes in text files to (compressed) JSONL formatted file
config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl)
entries = extract_entries(config.compressed_jsonl)
# Compute or Load Embeddings
config.embeddings_file = resolve_absolute_path(config.embeddings_file)
corpus_embeddings = load_embeddings(config.embeddings_file)
for filter in filters:
filter.load(entries, regenerate=False)
return TextContent(entries, corpus_embeddings, filters)
def apply_filters(
query: str, entries: List[Entry], corpus_embeddings: torch.Tensor, filters: List[BaseFilter]
) -> Tuple[str, List[Entry], torch.Tensor]:

src/khoj/utils/fs_syncer.py (new file, 217 lines)
View file

@ -0,0 +1,217 @@
import logging
import glob
from typing import Optional
from bs4 import BeautifulSoup
from khoj.utils.helpers import get_absolute_path, is_none_or_empty
from khoj.utils.rawconfig import TextContentConfig, ContentConfig
from khoj.utils.config import SearchType
logger = logging.getLogger(__name__)
def collect_files(config: ContentConfig, search_type: Optional[SearchType] = SearchType.All):
files = {}
if search_type == SearchType.All or search_type == SearchType.Org:
files["org"] = get_org_files(config.org) if config.org else {}
if search_type == SearchType.All or search_type == SearchType.Markdown:
files["markdown"] = get_markdown_files(config.markdown) if config.markdown else {}
if search_type == SearchType.All or search_type == SearchType.Plaintext:
files["plaintext"] = get_plaintext_files(config.plaintext) if config.plaintext else {}
if search_type == SearchType.All or search_type == SearchType.Pdf:
files["pdf"] = get_pdf_files(config.pdf) if config.pdf else {}
return files
def get_plaintext_files(config: TextContentConfig) -> dict[str, str]:
def is_plaintextfile(file: str):
"Check if file is plaintext file"
return file.endswith(("txt", "md", "markdown", "org", "mbox", "rst", "html", "htm", "xml"))
def extract_html_content(html_content: str):
"Extract content from HTML"
soup = BeautifulSoup(html_content, "html.parser")
return soup.get_text(strip=True, separator="\n")
# Extract required fields from config
input_files, input_filter = (
config.input_files,
config.input_filter,
)
# Input Validation
if is_none_or_empty(input_files) and is_none_or_empty(input_filter):
logger.debug("At least one of input-files or input-file-filter is required to be specified")
return {}
"Get all files to process"
absolute_plaintext_files, filtered_plaintext_files = set(), set()
if input_files:
absolute_plaintext_files = {get_absolute_path(jsonl_file) for jsonl_file in input_files}
if input_filter:
filtered_plaintext_files = {
filtered_file
for jsonl_file_filter in input_filter
for filtered_file in glob.glob(get_absolute_path(jsonl_file_filter), recursive=True)
}
all_target_files = sorted(absolute_plaintext_files | filtered_plaintext_files)
files_with_no_plaintext_extensions = {
target_files for target_files in all_target_files if not is_plaintextfile(target_files)
}
if any(files_with_no_plaintext_extensions):
logger.warning(f"Skipping unsupported files from plaintext indexing: {files_with_no_plaintext_extensions}")
all_target_files = list(set(all_target_files) - files_with_no_plaintext_extensions)
logger.debug(f"Processing files: {all_target_files}")
filename_to_content_map = {}
for file in all_target_files:
with open(file, "r") as f:
try:
plaintext_content = f.read()
if file.endswith(("html", "htm", "xml")):
plaintext_content = extract_html_content(plaintext_content)
filename_to_content_map[file] = f.read()
except Exception as e:
logger.warning(f"Unable to read file: {file} as plaintext. Skipping file.")
logger.warning(e, exc_info=True)
return filename_to_content_map
def get_org_files(config: TextContentConfig):
# Extract required fields from config
org_files, org_file_filter = (
config.input_files,
config.input_filter,
)
# Input Validation
if is_none_or_empty(org_files) and is_none_or_empty(org_file_filter):
logger.debug("At least one of org-files or org-file-filter is required to be specified")
return {}
"Get Org files to process"
absolute_org_files, filtered_org_files = set(), set()
if org_files:
absolute_org_files = {get_absolute_path(org_file) for org_file in org_files}
if org_file_filter:
filtered_org_files = {
filtered_file
for org_file_filter in org_file_filter
for filtered_file in glob.glob(get_absolute_path(org_file_filter), recursive=True)
}
all_org_files = sorted(absolute_org_files | filtered_org_files)
files_with_non_org_extensions = {org_file for org_file in all_org_files if not org_file.endswith(".org")}
if any(files_with_non_org_extensions):
logger.warning(f"There maybe non org-mode files in the input set: {files_with_non_org_extensions}")
logger.debug(f"Processing files: {all_org_files}")
filename_to_content_map = {}
for file in all_org_files:
with open(file, "r") as f:
try:
filename_to_content_map[file] = f.read()
except Exception as e:
logger.warning(f"Unable to read file: {file} as org. Skipping file.")
logger.warning(e, exc_info=True)
return filename_to_content_map
def get_markdown_files(config: TextContentConfig):
# Extract required fields from config
markdown_files, markdown_file_filter = (
config.input_files,
config.input_filter,
)
# Input Validation
if is_none_or_empty(markdown_files) and is_none_or_empty(markdown_file_filter):
logger.debug("At least one of markdown-files or markdown-file-filter is required to be specified")
return {}
"Get Markdown files to process"
absolute_markdown_files, filtered_markdown_files = set(), set()
if markdown_files:
absolute_markdown_files = {get_absolute_path(markdown_file) for markdown_file in markdown_files}
if markdown_file_filter:
filtered_markdown_files = {
filtered_file
for markdown_file_filter in markdown_file_filter
for filtered_file in glob.glob(get_absolute_path(markdown_file_filter), recursive=True)
}
all_markdown_files = sorted(absolute_markdown_files | filtered_markdown_files)
files_with_non_markdown_extensions = {
md_file for md_file in all_markdown_files if not md_file.endswith(".md") and not md_file.endswith(".markdown")
}
if any(files_with_non_markdown_extensions):
logger.warning(
f"[Warning] There maybe non markdown-mode files in the input set: {files_with_non_markdown_extensions}"
)
logger.debug(f"Processing files: {all_markdown_files}")
filename_to_content_map = {}
for file in all_markdown_files:
with open(file, "r") as f:
try:
filename_to_content_map[file] = f.read()
except Exception as e:
logger.warning(f"Unable to read file: {file} as markdown. Skipping file.")
logger.warning(e, exc_info=True)
return filename_to_content_map
def get_pdf_files(config: TextContentConfig):
# Extract required fields from config
pdf_files, pdf_file_filter = (
config.input_files,
config.input_filter,
)
# Input Validation
if is_none_or_empty(pdf_files) and is_none_or_empty(pdf_file_filter):
logger.debug("At least one of pdf-files or pdf-file-filter is required to be specified")
return {}
"Get PDF files to process"
absolute_pdf_files, filtered_pdf_files = set(), set()
if pdf_files:
absolute_pdf_files = {get_absolute_path(pdf_file) for pdf_file in pdf_files}
if pdf_file_filter:
filtered_pdf_files = {
filtered_file
for pdf_file_filter in pdf_file_filter
for filtered_file in glob.glob(get_absolute_path(pdf_file_filter), recursive=True)
}
all_pdf_files = sorted(absolute_pdf_files | filtered_pdf_files)
files_with_non_pdf_extensions = {pdf_file for pdf_file in all_pdf_files if not pdf_file.endswith(".pdf")}
if any(files_with_non_pdf_extensions):
logger.warning(f"[Warning] There maybe non pdf-mode files in the input set: {files_with_non_pdf_extensions}")
logger.debug(f"Processing files: {all_pdf_files}")
filename_to_content_map = {}
for file in all_pdf_files:
with open(file, "rb") as f:
try:
filename_to_content_map[file] = f.read()
except Exception as e:
logger.warning(f"Unable to read file: {file} as PDF. Skipping file.")
logger.warning(e, exc_info=True)
return filename_to_content_map

View file

@ -9,6 +9,7 @@ import pytest
from khoj.main import app
from khoj.configure import configure_processor, configure_routes, configure_search_types
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
from khoj.processor.plaintext.plaintext_to_jsonl import PlaintextToJsonl
from khoj.search_type import image_search, text_search
from khoj.utils.config import SearchModels
from khoj.utils.helpers import resolve_absolute_path
@ -97,7 +98,12 @@ def content_config(tmp_path_factory, search_models: SearchModels, search_config:
filters = [DateFilter(), WordFilter(), FileFilter()]
text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False, filters=filters
OrgToJsonl,
get_sample_data("org"),
content_config.org,
search_models.text_search.bi_encoder,
regenerate=False,
filters=filters,
)
content_config.plugins = {
@ -109,6 +115,20 @@ def content_config(tmp_path_factory, search_models: SearchModels, search_config:
)
}
if os.getenv("GITHUB_PAT_TOKEN"):
content_config.github = GithubContentConfig(
pat_token=os.getenv("GITHUB_PAT_TOKEN", ""),
repos=[
GithubRepoConfig(
owner="khoj-ai",
name="lantern",
branch="master",
)
],
compressed_jsonl=content_dir.joinpath("github.jsonl.gz"),
embeddings_file=content_dir.joinpath("github_embeddings.pt"),
)
content_config.plaintext = TextContentConfig(
input_files=None,
input_filter=["tests/data/plaintext/*.txt", "tests/data/plaintext/*.md", "tests/data/plaintext/*.html"],
@ -132,6 +152,7 @@ def content_config(tmp_path_factory, search_models: SearchModels, search_config:
filters = [DateFilter(), WordFilter(), FileFilter()]
text_search.setup(
JsonlToJsonl,
None,
content_config.plugins["plugin1"],
search_models.text_search.bi_encoder,
regenerate=False,
@ -203,6 +224,7 @@ def chat_client(md_content_config: ContentConfig, search_config: SearchConfig, p
state.search_models.text_search = text_search.initialize_model(search_config.asymmetric)
state.content_index.markdown = text_search.setup(
MarkdownToJsonl,
get_sample_data("markdown"),
md_content_config.markdown,
state.search_models.text_search.bi_encoder,
regenerate=False,
@ -226,11 +248,22 @@ def client(content_config: ContentConfig, search_config: SearchConfig, processor
state.search_models.text_search = text_search.initialize_model(search_config.asymmetric)
state.search_models.image_search = image_search.initialize_model(search_config.image)
state.content_index.org = text_search.setup(
OrgToJsonl, content_config.org, state.search_models.text_search.bi_encoder, regenerate=False
OrgToJsonl,
get_sample_data("org"),
content_config.org,
state.search_models.text_search.bi_encoder,
regenerate=False,
)
state.content_index.image = image_search.setup(
content_config.image, state.search_models.image_search, regenerate=False
)
state.content_index.plaintext = text_search.setup(
PlaintextToJsonl,
get_sample_data("plaintext"),
content_config.plaintext,
state.search_models.text_search.bi_encoder,
regenerate=False,
)
state.processor_config = configure_processor(processor_config)
@ -250,8 +283,21 @@ def client_offline_chat(
# Index Markdown Content for Search
filters = [DateFilter(), WordFilter(), FileFilter()]
state.search_models.text_search = text_search.initialize_model(search_config.asymmetric)
state.search_models.image_search = image_search.initialize_model(search_config.image)
state.content_index.org = text_search.setup(
OrgToJsonl,
get_sample_data("org"),
content_config.org,
state.search_models.text_search.bi_encoder,
regenerate=False,
)
state.content_index.image = image_search.setup(
content_config.image, state.search_models.image_search, regenerate=False
)
state.content_index.markdown = text_search.setup(
MarkdownToJsonl,
get_sample_data("markdown"),
md_content_config.markdown,
state.search_models.text_search.bi_encoder,
regenerate=False,
@ -284,3 +330,69 @@ def org_config_with_only_new_file(content_config: ContentConfig, new_org_file: P
new_org_config.input_files = [f"{new_org_file}"]
new_org_config.input_filter = None
return new_org_config
@pytest.fixture(scope="function")
def sample_org_data():
return get_sample_data("org")
def get_sample_data(type):
sample_data = {
"org": {
"readme.org": """
* Khoj
/Allow natural language search on user content like notes, images using transformer based models/
All data is processed locally. User can interface with khoj app via [[./interface/emacs/khoj.el][Emacs]], API or Commandline
** Dependencies
- Python3
- [[https://docs.conda.io/en/latest/miniconda.html#latest-miniconda-installer-links][Miniconda]]
** Install
#+begin_src shell
git clone https://github.com/khoj-ai/khoj && cd khoj
conda env create -f environment.yml
conda activate khoj
#+end_src"""
},
"markdown": {
"readme.markdown": """
# Khoj
Allow natural language search on user content like notes, images using transformer based models
All data is processed locally. User can interface with khoj app via [Emacs](./interface/emacs/khoj.el), API or Commandline
## Dependencies
- Python3
- [Miniconda](https://docs.conda.io/en/latest/miniconda.html#latest-miniconda-installer-links)
## Install
```shell
git clone
conda env create -f environment.yml
conda activate khoj
```
"""
},
"plaintext": {
"readme.txt": """
Khoj
Allow natural language search on user content like notes, images using transformer based models
All data is processed locally. User can interface with khoj app via Emacs, API or Commandline
Dependencies
- Python3
- Miniconda
Install
git clone
conda env create -f environment.yml
conda activate khoj
"""
},
}
return sample_data[type]

View file

@ -11,7 +11,6 @@ from fastapi.testclient import TestClient
from khoj.main import app
from khoj.configure import configure_routes, configure_search_types
from khoj.utils import state
from khoj.utils.config import SearchModels
from khoj.utils.state import search_models, content_index, config
from khoj.search_type import text_search, image_search
from khoj.utils.rawconfig import ContentConfig, SearchConfig
@ -51,28 +50,6 @@ def test_update_with_invalid_content_type(client):
assert response.status_code == 422
# ----------------------------------------------------------------------------------------------------
def test_update_with_valid_content_type(client):
for content_type in ["all", "org", "markdown", "image", "pdf", "notion", "plugin1"]:
# Act
response = client.get(f"/api/update?t={content_type}")
# Assert
assert response.status_code == 200, f"Returned status: {response.status_code} for content type: {content_type}"
# ----------------------------------------------------------------------------------------------------
def test_update_with_github_fails_without_pat(client):
# Act
response = client.get(f"/api/update?t=github")
# Assert
assert response.status_code == 500, f"Returned status: {response.status_code} for content type: github"
assert (
response.json()["detail"]
== "🚨 Failed to update server via API: Github PAT token is not set. Skipping github content"
)
# ----------------------------------------------------------------------------------------------------
def test_regenerate_with_invalid_content_type(client):
# Act
@ -82,11 +59,29 @@ def test_regenerate_with_invalid_content_type(client):
assert response.status_code == 422
# ----------------------------------------------------------------------------------------------------
def test_index_batch(client):
# Arrange
request_body = get_sample_files_data()
headers = {"x-api-key": "secret"}
# Act
response = client.post("/indexer/batch", json=request_body, headers=headers)
# Assert
assert response.status_code == 200
# ----------------------------------------------------------------------------------------------------
def test_regenerate_with_valid_content_type(client):
for content_type in ["all", "org", "markdown", "image", "pdf", "notion", "plugin1"]:
# Arrange
request_body = get_sample_files_data()
headers = {"x-api-key": "secret"}
# Act
response = client.get(f"/api/update?force=true&t={content_type}")
response = client.post(f"/indexer/batch?search_type={content_type}", json=request_body, headers=headers)
# Assert
assert response.status_code == 200, f"Returned status: {response.status_code} for content type: {content_type}"
@ -96,12 +91,15 @@ def test_regenerate_with_github_fails_without_pat(client):
# Act
response = client.get(f"/api/update?force=true&t=github")
# Arrange
request_body = get_sample_files_data()
headers = {"x-api-key": "secret"}
# Act
response = client.post(f"/indexer/batch?search_type=github", json=request_body, headers=headers)
# Assert
assert response.status_code == 500, f"Returned status: {response.status_code} for content type: github"
assert (
response.json()["detail"]
== "🚨 Failed to update server via API: Github PAT token is not set. Skipping github content"
)
assert response.status_code == 200, f"Returned status: {response.status_code} for content type: github"
# ----------------------------------------------------------------------------------------------------
@ -111,7 +109,7 @@ def test_get_configured_types_via_api(client):
# Assert
assert response.status_code == 200
assert response.json() == ["all", "org", "image", "plugin1"]
assert response.json() == ["all", "org", "image", "plaintext", "plugin1"]
# ----------------------------------------------------------------------------------------------------
@ -194,11 +192,11 @@ def test_image_search(client, content_config: ContentConfig, search_config: Sear
# ----------------------------------------------------------------------------------------------------
def test_notes_search(client, content_config: ContentConfig, search_config: SearchConfig):
def test_notes_search(client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data):
# Arrange
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
content_index.org = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False
OrgToJsonl, sample_org_data, content_config.org, search_models.text_search.bi_encoder, regenerate=False
)
user_query = quote("How to git install application?")
@ -213,12 +211,19 @@ def test_notes_search(client, content_config: ContentConfig, search_config: Sear
# ----------------------------------------------------------------------------------------------------
def test_notes_search_with_only_filters(client, content_config: ContentConfig, search_config: SearchConfig):
def test_notes_search_with_only_filters(
client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data
):
# Arrange
filters = [WordFilter(), FileFilter()]
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
content_index.org = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False, filters=filters
OrgToJsonl,
sample_org_data,
content_config.org,
search_models.text_search.bi_encoder,
regenerate=False,
filters=filters,
)
user_query = quote('+"Emacs" file:"*.org"')
@ -233,12 +238,14 @@ def test_notes_search_with_only_filters(client, content_config: ContentConfig, s
# ----------------------------------------------------------------------------------------------------
def test_notes_search_with_include_filter(client, content_config: ContentConfig, search_config: SearchConfig):
def test_notes_search_with_include_filter(
client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data
):
# Arrange
filters = [WordFilter()]
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
content_index.org = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search, regenerate=False, filters=filters
OrgToJsonl, sample_org_data, content_config.org, search_models.text_search, regenerate=False, filters=filters
)
user_query = quote('How to git install application? +"Emacs"')
@ -253,12 +260,19 @@ def test_notes_search_with_include_filter(client, content_config: ContentConfig,
# ----------------------------------------------------------------------------------------------------
def test_notes_search_with_exclude_filter(client, content_config: ContentConfig, search_config: SearchConfig):
def test_notes_search_with_exclude_filter(
client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data
):
# Arrange
filters = [WordFilter()]
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
content_index.org = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False, filters=filters
OrgToJsonl,
sample_org_data,
content_config.org,
search_models.text_search.bi_encoder,
regenerate=False,
filters=filters,
)
user_query = quote('How to git install application? -"clone"')
@ -270,3 +284,28 @@ def test_notes_search_with_exclude_filter(client, content_config: ContentConfig,
# assert actual_data does not contain the word "clone"
search_result = response.json()[0]["entry"]
assert "clone" not in search_result
def get_sample_files_data():
return {
"org": {
"path/to/filename.org": "* practicing piano",
"path/to/filename1.org": "** top 3 reasons why I moved to SF",
"path/to/filename2.org": "* how to build a search engine",
},
"pdf": {
"path/to/filename.pdf": "Moore's law does not apply to consumer hardware",
"path/to/filename1.pdf": "The sun is a ball of helium",
"path/to/filename2.pdf": "Effect of sunshine on baseline human happiness",
},
"plaintext": {
"path/to/filename.txt": "data,column,value",
"path/to/filename1.txt": "<html>my first web page</html>",
"path/to/filename2.txt": "2021-02-02 Journal Entry",
},
"markdown": {
"path/to/filename.md": "# Notes from client call",
"path/to/filename1.md": "## Studying anthropological records from the Fatimid caliphate",
"path/to/filename2.md": "**Understanding science through the lens of art**",
},
}
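For context, a minimal client-side sketch of how file contents could be pushed to the new batch indexing endpoint, mirroring the request shape exercised by test_index_batch above. The helper name, the default base URL, and reuse of the test's "secret" API key are illustrative assumptions; only the /indexer/batch route, the x-api-key header, and the payload shape come from the tests.
# Hedged sketch: push already-read file contents to the indexer, as the tests above do.
# The base URL and API key are deployment-specific; values here are placeholders.
import requests
def push_files_for_indexing(files_by_type: dict, api_key: str = "secret", base_url: str = "http://localhost:42110"):
    # files_by_type maps a content type ("org", "markdown", "pdf", "plaintext")
    # to a dict of {file_path: file_content}, like get_sample_files_data() returns
    response = requests.post(
        f"{base_url}/indexer/batch",
        json=files_by_type,
        headers={"x-api-key": api_key},
    )
    response.raise_for_status()
    return response
# Example usage (assumed file path and content):
# push_files_for_indexing({"org": {"path/to/notes.org": "* practicing piano"}})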

View file

@ -1,9 +1,12 @@
# Standard Packages
import json
from pathlib import Path
import os
# Internal Packages
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
from khoj.utils.fs_syncer import get_markdown_files
from khoj.utils.rawconfig import TextContentConfig
def test_markdown_file_with_no_headings_to_jsonl(tmp_path):
@ -13,12 +16,14 @@ def test_markdown_file_with_no_headings_to_jsonl(tmp_path):
- Bullet point 1
- Bullet point 2
"""
markdownfile = create_file(tmp_path, entry)
expected_heading = "# " + markdownfile.stem
data = {
f"{tmp_path}": entry,
}
expected_heading = f"# {tmp_path.stem}"
# Act
# Extract Entries from specified Markdown files
entry_nodes, file_to_entries = MarkdownToJsonl.extract_markdown_entries(markdown_files=[markdownfile])
entry_nodes, file_to_entries = MarkdownToJsonl.extract_markdown_entries(markdown_files=data)
# Process Each Entry from All Notes Files
jsonl_string = MarkdownToJsonl.convert_markdown_maps_to_jsonl(
@ -41,11 +46,13 @@ def test_single_markdown_entry_to_jsonl(tmp_path):
\t\r
Body Line 1
"""
markdownfile = create_file(tmp_path, entry)
data = {
f"{tmp_path}": entry,
}
# Act
# Extract Entries from specified Markdown files
entries, entry_to_file_map = MarkdownToJsonl.extract_markdown_entries(markdown_files=[markdownfile])
entries, entry_to_file_map = MarkdownToJsonl.extract_markdown_entries(markdown_files=data)
# Process Each Entry from All Notes Files
jsonl_string = MarkdownToJsonl.convert_markdown_maps_to_jsonl(
@ -68,11 +75,13 @@ def test_multiple_markdown_entries_to_jsonl(tmp_path):
\t\r
Heading 2 Body Line 2
"""
markdownfile = create_file(tmp_path, entry)
data = {
f"{tmp_path}": entry,
}
# Act
# Extract Entries from specified Markdown files
entry_strings, entry_to_file_map = MarkdownToJsonl.extract_markdown_entries(markdown_files=[markdownfile])
entry_strings, entry_to_file_map = MarkdownToJsonl.extract_markdown_entries(markdown_files=data)
entries = MarkdownToJsonl.convert_markdown_entries_to_maps(entry_strings, entry_to_file_map)
# Process Each Entry from All Notes Files
@ -82,7 +91,7 @@ def test_multiple_markdown_entries_to_jsonl(tmp_path):
# Assert
assert len(jsonl_data) == 2
# Ensure entry compiled strings include the markdown files they originate from
assert all([markdownfile.stem in entry.compiled for entry in entries])
assert all([tmp_path.stem in entry.compiled for entry in entries])
def test_get_markdown_files(tmp_path):
@ -99,18 +108,27 @@ def test_get_markdown_files(tmp_path):
create_file(tmp_path, filename="not-included-markdown.md")
create_file(tmp_path, filename="not-included-text.txt")
expected_files = sorted(map(str, [group1_file1, group1_file2, group2_file1, group2_file2, file1]))
expected_files = set(
[os.path.join(tmp_path, file.name) for file in [group1_file1, group1_file2, group2_file1, group2_file2, file1]]
)
# Setup input-files, input-filters
input_files = [tmp_path / "notes.md"]
input_filter = [tmp_path / "group1*.md", tmp_path / "group2*.markdown"]
markdown_config = TextContentConfig(
input_files=input_files,
input_filter=[str(filter) for filter in input_filter],
compressed_jsonl=tmp_path / "test.jsonl",
embeddings_file=tmp_path / "test_embeddings.jsonl",
)
# Act
extracted_org_files = MarkdownToJsonl.get_markdown_files(input_files, input_filter)
extracted_org_files = get_markdown_files(markdown_config)
# Assert
assert len(extracted_org_files) == 5
assert extracted_org_files == expected_files
assert set(extracted_org_files.keys()) == expected_files
def test_extract_entries_with_different_level_headings(tmp_path):
@ -120,11 +138,13 @@ def test_extract_entries_with_different_level_headings(tmp_path):
# Heading 1
## Heading 2
"""
markdownfile = create_file(tmp_path, entry)
data = {
f"{tmp_path}": entry,
}
# Act
# Extract Entries from specified Markdown files
entries, _ = MarkdownToJsonl.extract_markdown_entries(markdown_files=[markdownfile])
entries, _ = MarkdownToJsonl.extract_markdown_entries(markdown_files=data)
# Assert
assert len(entries) == 2

View file

@ -1,11 +1,14 @@
# Standard Packages
import json
import os
# Internal Packages
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
from khoj.processor.text_to_jsonl import TextToJsonl
from khoj.utils.helpers import is_none_or_empty
from khoj.utils.rawconfig import Entry
from khoj.utils.fs_syncer import get_org_files
from khoj.utils.rawconfig import TextContentConfig
def test_configure_heading_entry_to_jsonl(tmp_path):
@ -18,14 +21,17 @@ def test_configure_heading_entry_to_jsonl(tmp_path):
:END:
\t \r
"""
orgfile = create_file(tmp_path, entry)
data = {
f"{tmp_path}": entry,
}
for index_heading_entries in [True, False]:
# Act
# Extract entries into jsonl from specified Org files
jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(
OrgToJsonl.convert_org_nodes_to_entries(
*OrgToJsonl.extract_org_entries(org_files=[orgfile]), index_heading_entries=index_heading_entries
*OrgToJsonl.extract_org_entries(org_files=data), index_heading_entries=index_heading_entries
)
)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
@ -46,12 +52,14 @@ def test_entry_split_when_exceeds_max_words(tmp_path):
\t\r
Body Line
"""
orgfile = create_file(tmp_path, entry)
expected_heading = f"* {orgfile.stem}\n** Heading"
data = {
f"{tmp_path}": entry,
}
expected_heading = f"* {tmp_path.stem}\n** Heading"
# Act
# Extract Entries from specified Org files
entries, entry_to_file_map = OrgToJsonl.extract_org_entries(org_files=[orgfile])
entries, entry_to_file_map = OrgToJsonl.extract_org_entries(org_files=data)
# Split each entry from specified Org files by max words
jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(
@ -95,11 +103,13 @@ def test_entry_with_body_to_jsonl(tmp_path):
\t\r
Body Line 1
"""
orgfile = create_file(tmp_path, entry)
data = {
f"{tmp_path}": entry,
}
# Act
# Extract Entries from specified Org files
entries, entry_to_file_map = OrgToJsonl.extract_org_entries(org_files=[orgfile])
entries, entry_to_file_map = OrgToJsonl.extract_org_entries(org_files=data)
# Process Each Entry from All Notes Files
jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(
@ -120,11 +130,13 @@ Intro text
* Entry Heading
entry body
"""
orgfile = create_file(tmp_path, entry)
data = {
f"{tmp_path}": entry,
}
# Act
# Extract Entries from specified Org files
entry_nodes, file_to_entries = OrgToJsonl.extract_org_entries(org_files=[orgfile])
entry_nodes, file_to_entries = OrgToJsonl.extract_org_entries(org_files=data)
# Process Each Entry from All Notes Files
entries = OrgToJsonl.convert_org_nodes_to_entries(entry_nodes, file_to_entries)
@ -142,11 +154,13 @@ def test_file_with_no_headings_to_jsonl(tmp_path):
- Bullet point 1
- Bullet point 2
"""
orgfile = create_file(tmp_path, entry)
data = {
f"{tmp_path}": entry,
}
# Act
# Extract Entries from specified Org files
entry_nodes, file_to_entries = OrgToJsonl.extract_org_entries(org_files=[orgfile])
entry_nodes, file_to_entries = OrgToJsonl.extract_org_entries(org_files=data)
# Process Each Entry from All Notes Files
entries = OrgToJsonl.convert_org_nodes_to_entries(entry_nodes, file_to_entries)
@ -171,18 +185,30 @@ def test_get_org_files(tmp_path):
create_file(tmp_path, filename="orgfile2.org")
create_file(tmp_path, filename="text1.txt")
expected_files = sorted(map(str, [group1_file1, group1_file2, group2_file1, group2_file2, orgfile1]))
expected_files = set(
[
os.path.join(tmp_path, file.name)
for file in [group1_file1, group1_file2, group2_file1, group2_file2, orgfile1]
]
)
# Setup input-files, input-filters
input_files = [tmp_path / "orgfile1.org"]
input_filter = [tmp_path / "group1*.org", tmp_path / "group2*.org"]
org_config = TextContentConfig(
input_files=input_files,
input_filter=[str(filter) for filter in input_filter],
compressed_jsonl=tmp_path / "test.jsonl",
embeddings_file=tmp_path / "test_embeddings.jsonl",
)
# Act
extracted_org_files = OrgToJsonl.get_org_files(input_files, input_filter)
extracted_org_files = get_org_files(org_config)
# Assert
assert len(extracted_org_files) == 5
assert extracted_org_files == expected_files
assert set(extracted_org_files.keys()) == expected_files
def test_extract_entries_with_different_level_headings(tmp_path):
@ -192,11 +218,13 @@ def test_extract_entries_with_different_level_headings(tmp_path):
* Heading 1
** Heading 2
"""
orgfile = create_file(tmp_path, entry)
data = {
f"{tmp_path}": entry,
}
# Act
# Extract Entries from specified Org files
entries, _ = OrgToJsonl.extract_org_entries(org_files=[orgfile])
entries, _ = OrgToJsonl.extract_org_entries(org_files=data)
# Assert
assert len(entries) == 2
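The same pattern runs through the markdown, org, PDF and plaintext processor tests: instead of a list of file paths, the extract_*_entries methods now receive a dict mapping each file path to its already-read content (text for org, markdown and plaintext; raw bytes for PDF). Below is a small sketch of assembling such a dict from local files, assuming they exist on disk; the helper name read_files_as_dict is illustrative, while the real collection logic lives in khoj.utils.fs_syncer.
from pathlib import Path
def read_files_as_dict(paths, as_bytes=False):
    # Build the {file_path: file_content} mapping the processors now expect,
    # reading text for org/markdown/plaintext files and bytes for PDFs
    return {
        str(Path(path)): (Path(path).read_bytes() if as_bytes else Path(path).read_text())
        for path in paths
    }
# e.g. OrgToJsonl.extract_org_entries(org_files=read_files_as_dict(["notes/todo.org"]))
# e.g. PdfToJsonl.extract_pdf_entries(pdf_files=read_files_as_dict(["docs/paper.pdf"], as_bytes=True))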

View file

@ -44,7 +44,7 @@ Body Line 1"""
assert len(entries) == 1
assert entries[0].heading == "Heading"
assert entries[0].tags == list()
assert entries[0].body == "Body Line 1"
assert entries[0].body == "Body Line 1\n\n"
assert entries[0].priority == ""
assert entries[0].Property("ID") == ""
assert entries[0].closed == ""
@ -78,7 +78,7 @@ Body Line 2"""
assert entries[0].heading == "Heading"
assert entries[0].todo == "DONE"
assert entries[0].tags == ["Tag1", "TAG2", "tag3"]
assert entries[0].body == "- Clocked Log 1\nBody Line 1\nBody Line 2"
assert entries[0].body == "- Clocked Log 1\n\nBody Line 1\n\nBody Line 2\n\n"
assert entries[0].priority == "A"
assert entries[0].Property("ID") == "id:123-456-789-4234-1231"
assert entries[0].closed == datetime.date(1984, 4, 1)
@ -205,7 +205,7 @@ Body 2
assert entry.heading == f"Heading{index+1}"
assert entry.todo == "FAILED" if index == 0 else "CANCELLED"
assert entry.tags == [f"tag{index+1}"]
assert entry.body == f"- Clocked Log {index+1}\nBody {index+1}\n\n"
assert entry.body == f"- Clocked Log {index+1}\n\nBody {index+1}\n\n"
assert entry.priority == "A"
assert entry.Property("ID") == f"id:123-456-789-4234-000{index+1}"
assert entry.closed == datetime.date(1984, 4, index + 1)
@ -305,7 +305,7 @@ entry body
assert entries[0].heading == "Title"
assert entries[0].body == "intro body\n"
assert entries[1].heading == "Entry Heading"
assert entries[1].body == "entry body\n"
assert entries[1].body == "entry body\n\n"
# ----------------------------------------------------------------------------------------------------
@ -327,7 +327,7 @@ entry body
assert entries[0].heading == "Title1 Title2"
assert entries[0].body == "intro body\n"
assert entries[1].heading == "Entry Heading"
assert entries[1].body == "entry body\n"
assert entries[1].body == "entry body\n\n"
# Helper Functions

View file

@ -1,15 +1,24 @@
# Standard Packages
import json
import os
# Internal Packages
from khoj.processor.pdf.pdf_to_jsonl import PdfToJsonl
from khoj.utils.fs_syncer import get_pdf_files
from khoj.utils.rawconfig import TextContentConfig
def test_single_page_pdf_to_jsonl():
"Convert single page PDF file to jsonl."
# Act
# Extract Entries from specified Pdf files
entries, entry_to_file_map = PdfToJsonl.extract_pdf_entries(pdf_files=["tests/data/pdf/singlepage.pdf"])
# Read singlepage.pdf into memory as bytes
with open("tests/data/pdf/singlepage.pdf", "rb") as f:
pdf_bytes = f.read()
data = {"tests/data/pdf/singlepage.pdf": pdf_bytes}
entries, entry_to_file_map = PdfToJsonl.extract_pdf_entries(pdf_files=data)
# Process Each Entry from All Pdf Files
jsonl_string = PdfToJsonl.convert_pdf_maps_to_jsonl(
@ -25,7 +34,11 @@ def test_multi_page_pdf_to_jsonl():
"Convert multiple pages from single PDF file to jsonl."
# Act
# Extract Entries from specified Pdf files
entries, entry_to_file_map = PdfToJsonl.extract_pdf_entries(pdf_files=["tests/data/pdf/multipage.pdf"])
with open("tests/data/pdf/multipage.pdf", "rb") as f:
pdf_bytes = f.read()
data = {"tests/data/pdf/multipage.pdf": pdf_bytes}
entries, entry_to_file_map = PdfToJsonl.extract_pdf_entries(pdf_files=data)
# Process Each Entry from All Pdf Files
jsonl_string = PdfToJsonl.convert_pdf_maps_to_jsonl(
@ -51,18 +64,27 @@ def test_get_pdf_files(tmp_path):
create_file(tmp_path, filename="not-included-document.pdf")
create_file(tmp_path, filename="not-included-text.txt")
expected_files = sorted(map(str, [group1_file1, group1_file2, group2_file1, group2_file2, file1]))
expected_files = set(
[os.path.join(tmp_path, file.name) for file in [group1_file1, group1_file2, group2_file1, group2_file2, file1]]
)
# Setup input-files, input-filters
input_files = [tmp_path / "document.pdf"]
input_filter = [tmp_path / "group1*.pdf", tmp_path / "group2*.pdf"]
pdf_config = TextContentConfig(
input_files=input_files,
input_filter=[str(path) for path in input_filter],
compressed_jsonl=tmp_path / "test.jsonl",
embeddings_file=tmp_path / "test_embeddings.jsonl",
)
# Act
extracted_pdf_files = PdfToJsonl.get_pdf_files(input_files, input_filter)
extracted_pdf_files = get_pdf_files(pdf_config)
# Assert
assert len(extracted_pdf_files) == 5
assert extracted_pdf_files == expected_files
assert set(extracted_pdf_files.keys()) == expected_files
# Helper Functions

View file

@ -1,8 +1,11 @@
# Standard Packages
import json
import os
from pathlib import Path
# Internal Packages
from khoj.utils.fs_syncer import get_plaintext_files
from khoj.utils.rawconfig import TextContentConfig
from khoj.processor.plaintext.plaintext_to_jsonl import PlaintextToJsonl
@ -18,9 +21,12 @@ def test_plaintext_file(tmp_path):
# Act
# Extract Entries from specified plaintext files
file_to_entries = PlaintextToJsonl.extract_plaintext_entries(plaintext_files=[str(plaintextfile)])
maps = PlaintextToJsonl.convert_plaintext_entries_to_maps(file_to_entries)
data = {
f"{plaintextfile}": entry,
}
maps = PlaintextToJsonl.convert_plaintext_entries_to_maps(entry_to_file_map=data)
# Convert each entry.file to absolute path to make them JSON serializable
for map in maps:
@ -59,33 +65,40 @@ def test_get_plaintext_files(tmp_path):
create_file(tmp_path, filename="not-included-markdown.md")
create_file(tmp_path, filename="not-included-text.txt")
expected_files = sorted(
map(str, [group1_file1, group1_file2, group2_file1, group2_file2, file1, group2_file3, group2_file4])
expected_files = set(
[
os.path.join(tmp_path, file.name)
for file in [group1_file1, group1_file2, group2_file1, group2_file2, group2_file3, group2_file4, file1]
]
)
# Setup input-files, input-filters
input_files = [tmp_path / "notes.txt"]
input_filter = [tmp_path / "group1*.md", tmp_path / "group2*.*"]
plaintext_config = TextContentConfig(
input_files=input_files,
input_filter=[str(filter) for filter in input_filter],
compressed_jsonl=tmp_path / "test.jsonl",
embeddings_file=tmp_path / "test_embeddings.jsonl",
)
# Act
extracted_plaintext_files = PlaintextToJsonl.get_plaintext_files(input_files, input_filter)
extracted_plaintext_files = get_plaintext_files(plaintext_config)
# Assert
assert len(extracted_plaintext_files) == 7
assert set(extracted_plaintext_files) == set(expected_files)
assert set(extracted_plaintext_files.keys()) == set(expected_files)
def test_parse_html_plaintext_file(content_config):
"Ensure HTML files are parsed correctly"
# Arrange
# Setup input-files, input-filters
input_files = content_config.plaintext.input_files
input_filter = content_config.plaintext.input_filter
extracted_plaintext_files = get_plaintext_files(content_config.plaintext)
# Act
extracted_plaintext_files = PlaintextToJsonl.get_plaintext_files(input_files, input_filter)
file_to_entries = PlaintextToJsonl.extract_plaintext_entries(extracted_plaintext_files)
maps = PlaintextToJsonl.convert_plaintext_entries_to_maps(file_to_entries)
maps = PlaintextToJsonl.convert_plaintext_entries_to_maps(extracted_plaintext_files)
# Assert
assert len(maps) == 1

View file

@ -13,6 +13,7 @@ from khoj.search_type import text_search
from khoj.utils.rawconfig import ContentConfig, SearchConfig, TextContentConfig
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
from khoj.processor.github.github_to_jsonl import GithubToJsonl
from khoj.utils.fs_syncer import get_org_files
# Test
@ -27,26 +28,30 @@ def test_text_search_setup_with_missing_file_raises_error(
# Act
# Generate notes embeddings during asymmetric setup
with pytest.raises(ValueError, match=r"^No valid entries found in specified files:*"):
text_search.setup(OrgToJsonl, org_config_with_only_new_file, search_config.asymmetric, regenerate=True)
with pytest.raises(FileNotFoundError):
data = get_org_files(org_config_with_only_new_file)
# ----------------------------------------------------------------------------------------------------
def test_text_search_setup_with_empty_file_raises_error(
org_config_with_only_new_file: TextContentConfig, search_config: SearchConfig
):
# Arrange
data = get_org_files(org_config_with_only_new_file)
# Act
# Generate notes embeddings during asymmetric setup
with pytest.raises(ValueError, match=r"^No valid entries found*"):
text_search.setup(OrgToJsonl, org_config_with_only_new_file, search_config.asymmetric, regenerate=True)
text_search.setup(OrgToJsonl, data, org_config_with_only_new_file, search_config.asymmetric, regenerate=True)
# ----------------------------------------------------------------------------------------------------
def test_text_search_setup(content_config: ContentConfig, search_models: SearchModels):
# Arrange
data = get_org_files(content_config.org)
# Act
# Regenerate notes embeddings during asymmetric setup
notes_model = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True
)
# Assert
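As the updated text_search tests show, indexing now happens in two steps: collect file contents with the fs_syncer helpers, then pass the resulting dict into text_search.setup together with the content config and bi-encoder. A condensed sketch of that flow follows; content_config and search_models are assumed to come from configuration or fixtures, as they do in these tests.
# Condensed sketch of the new two-step indexing flow exercised by the tests.
# content_config and search_models are assumed to be provided (fixtures/app config).
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
from khoj.search_type import text_search
from khoj.utils.fs_syncer import get_org_files
data = get_org_files(content_config.org)  # {file_path: file_content} for configured org files
notes_model = text_search.setup(
    OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True
)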
@ -59,14 +64,16 @@ def test_text_index_same_if_content_unchanged(content_config: ContentConfig, sea
# Arrange
caplog.set_level(logging.INFO, logger="khoj")
data = get_org_files(content_config.org)
# Act
# Generate initial notes embeddings during asymmetric setup
text_search.setup(OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True)
text_search.setup(OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True)
initial_logs = caplog.text
caplog.clear() # Clear logs
# Run asymmetric setup again with no changes to data source. Ensure index is not updated
text_search.setup(OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False)
text_search.setup(OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=False)
final_logs = caplog.text
# Assert
@ -78,9 +85,11 @@ def test_text_index_same_if_content_unchanged(content_config: ContentConfig, sea
@pytest.mark.anyio
async def test_text_search(content_config: ContentConfig, search_config: SearchConfig):
# Arrange
data = get_org_files(content_config.org)
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
content_index.org = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True
)
query = "How to git install application?"
@ -108,10 +117,12 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContent
for index in range(max_tokens + 1):
f.write(f"{index} ")
data = get_org_files(org_config_with_only_new_file)
# Act
# reload embeddings, entries, notes model after adding new org-mode file
initial_notes_model = text_search.setup(
OrgToJsonl, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
)
# Assert
@ -125,8 +136,9 @@ def test_regenerate_index_with_new_entry(
content_config: ContentConfig, search_models: SearchModels, new_org_file: Path
):
# Arrange
data = get_org_files(content_config.org)
initial_notes_model = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True
)
assert len(initial_notes_model.entries) == 10
@ -137,10 +149,12 @@ def test_regenerate_index_with_new_entry(
with open(new_org_file, "w") as f:
f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n")
data = get_org_files(content_config.org)
# Act
# regenerate notes jsonl, model embeddings and model to include entry from new file
regenerated_notes_model = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True
)
# Assert
@ -169,15 +183,19 @@ def test_update_index_with_duplicate_entries_in_stable_order(
with open(new_file_to_index, "w") as f:
f.write(f"{new_entry}{new_entry}")
data = get_org_files(org_config_with_only_new_file)
# Act
# load embeddings, entries, notes model after adding new org-mode file
initial_index = text_search.setup(
OrgToJsonl, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=True
OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=True
)
data = get_org_files(org_config_with_only_new_file)
# update embeddings, entries, notes model after adding new org-mode file
updated_index = text_search.setup(
OrgToJsonl, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
)
# Assert
@ -200,19 +218,22 @@ def test_update_index_with_deleted_entry(org_config_with_only_new_file: TextCont
new_entry = "* TODO A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n"
with open(new_file_to_index, "w") as f:
f.write(f"{new_entry}{new_entry} -- Tatooine")
data = get_org_files(org_config_with_only_new_file)
# load embeddings, entries, notes model after adding new org file with 2 entries
initial_index = text_search.setup(
OrgToJsonl, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=True
OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=True
)
# update embeddings, entries, notes model after removing an entry from the org file
with open(new_file_to_index, "w") as f:
f.write(f"{new_entry}")
data = get_org_files(org_config_with_only_new_file)
# Act
updated_index = text_search.setup(
OrgToJsonl, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
)
# Assert
@ -229,8 +250,9 @@ def test_update_index_with_deleted_entry(org_config_with_only_new_file: TextCont
# ----------------------------------------------------------------------------------------------------
def test_update_index_with_new_entry(content_config: ContentConfig, search_models: SearchModels, new_org_file: Path):
# Arrange
data = get_org_files(content_config.org)
initial_notes_model = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True, normalize=False
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True, normalize=False
)
# append org-mode entry to first org input file in config
@ -238,11 +260,13 @@ def test_update_index_with_new_entry(content_config: ContentConfig, search_model
new_entry = "\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n"
f.write(new_entry)
data = get_org_files(content_config.org)
# Act
# update embeddings, entries with the newly added note
content_config.org.input_files = [f"{new_org_file}"]
final_notes_model = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False, normalize=False
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=False, normalize=False
)
# Assert