mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 15:38:55 +01:00
Move to a push-first model for retrieving embeddings from local files (#457)
* Initial version - setup a file-push architecture for generating embeddings with Khoj * Update unit tests to fix with new application design * Allow configure server to be called without regenerating the index; this no longer works because the API for indexing files is not up in time for the server to send a request * Use state.host and state.port for configuring the URL for the indexer * On application startup, load in embeddings from configurations files, rather than regenerating the corpus based on file system
This commit is contained in:
parent
92cbfef7ab
commit
4854258047
23 changed files with 990 additions and 508 deletions
|
@ -11,87 +11,88 @@ import schedule
|
|||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
# Internal Packages
|
||||
from khoj.processor.jsonl.jsonl_to_jsonl import JsonlToJsonl
|
||||
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
|
||||
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
||||
from khoj.processor.pdf.pdf_to_jsonl import PdfToJsonl
|
||||
from khoj.processor.github.github_to_jsonl import GithubToJsonl
|
||||
from khoj.processor.notion.notion_to_jsonl import NotionToJsonl
|
||||
from khoj.processor.plaintext.plaintext_to_jsonl import PlaintextToJsonl
|
||||
from khoj.search_type import image_search, text_search
|
||||
from khoj.utils import constants, state
|
||||
from khoj.utils.config import (
|
||||
ContentIndex,
|
||||
SearchType,
|
||||
SearchModels,
|
||||
ProcessorConfigModel,
|
||||
ConversationProcessorConfigModel,
|
||||
)
|
||||
from khoj.utils.helpers import LRU, resolve_absolute_path, merge_dicts
|
||||
from khoj.utils.rawconfig import FullConfig, ProcessorConfig, SearchConfig, ContentConfig, ConversationProcessorConfig
|
||||
from khoj.search_filter.date_filter import DateFilter
|
||||
from khoj.search_filter.word_filter import WordFilter
|
||||
from khoj.search_filter.file_filter import FileFilter
|
||||
from khoj.utils.helpers import resolve_absolute_path, merge_dicts
|
||||
from khoj.utils.fs_syncer import collect_files
|
||||
from khoj.utils.rawconfig import FullConfig, ProcessorConfig, SearchConfig, ConversationProcessorConfig
|
||||
from khoj.routers.indexer import configure_content, load_content
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def initialize_server(config: Optional[FullConfig], regenerate: bool, required=False):
|
||||
def initialize_server(config: Optional[FullConfig], required=False):
|
||||
if config is None and required:
|
||||
logger.error(
|
||||
f"🚨 Exiting as Khoj is not configured.\nConfigure it via http://localhost:42110/config or by editing {state.config_file}."
|
||||
f"🚨 Exiting as Khoj is not configured.\nConfigure it via http://{state.host}:{state.port}/config or by editing {state.config_file}."
|
||||
)
|
||||
sys.exit(1)
|
||||
elif config is None:
|
||||
logger.warning(
|
||||
f"🚨 Khoj is not configured.\nConfigure it via http://localhost:42110/config, plugins or by editing {state.config_file}."
|
||||
f"🚨 Khoj is not configured.\nConfigure it via http://{state.host}:{state.port}/config, plugins or by editing {state.config_file}."
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
configure_server(config, regenerate)
|
||||
configure_server(config, init=True)
|
||||
except Exception as e:
|
||||
logger.error(f"🚨 Failed to configure server on app load: {e}", exc_info=True)
|
||||
|
||||
|
||||
def configure_server(config: FullConfig, regenerate: bool, search_type: Optional[SearchType] = None):
|
||||
def configure_server(
|
||||
config: FullConfig, regenerate: bool = False, search_type: Optional[SearchType] = None, init=False
|
||||
):
|
||||
# Update Config
|
||||
state.config = config
|
||||
|
||||
# Initialize Processor from Config
|
||||
try:
|
||||
state.config_lock.acquire()
|
||||
state.processor_config = configure_processor(state.config.processor)
|
||||
except Exception as e:
|
||||
logger.error(f"🚨 Failed to configure processor", exc_info=True)
|
||||
raise e
|
||||
finally:
|
||||
state.config_lock.release()
|
||||
|
||||
# Initialize Search Models from Config
|
||||
# Initialize Search Models from Config and initialize content
|
||||
try:
|
||||
state.config_lock.acquire()
|
||||
state.SearchType = configure_search_types(state.config)
|
||||
state.search_models = configure_search(state.search_models, state.config.search_type)
|
||||
initialize_content(regenerate, search_type, init)
|
||||
except Exception as e:
|
||||
logger.error(f"🚨 Failed to configure search models", exc_info=True)
|
||||
raise e
|
||||
finally:
|
||||
state.config_lock.release()
|
||||
|
||||
|
||||
def initialize_content(regenerate: bool, search_type: Optional[SearchType] = None, init=False):
|
||||
# Initialize Content from Config
|
||||
if state.search_models:
|
||||
try:
|
||||
state.config_lock.acquire()
|
||||
state.content_index = configure_content(
|
||||
state.content_index, state.config.content_type, state.search_models, regenerate, search_type
|
||||
)
|
||||
if init:
|
||||
logger.info("📬 Initializing content index...")
|
||||
state.content_index = load_content(state.config.content_type, state.content_index, state.search_models)
|
||||
else:
|
||||
logger.info("📬 Updating content index...")
|
||||
all_files = collect_files(state.config.content_type)
|
||||
state.content_index = configure_content(
|
||||
state.content_index,
|
||||
state.config.content_type,
|
||||
all_files,
|
||||
state.search_models,
|
||||
regenerate,
|
||||
search_type,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"🚨 Failed to index content", exc_info=True)
|
||||
raise e
|
||||
finally:
|
||||
state.config_lock.release()
|
||||
|
||||
|
||||
def configure_routes(app):
|
||||
|
@ -99,10 +100,12 @@ def configure_routes(app):
|
|||
from khoj.routers.api import api
|
||||
from khoj.routers.api_beta import api_beta
|
||||
from khoj.routers.web_client import web_client
|
||||
from khoj.routers.indexer import indexer
|
||||
|
||||
app.mount("/static", StaticFiles(directory=constants.web_directory), name="static")
|
||||
app.include_router(api, prefix="/api")
|
||||
app.include_router(api_beta, prefix="/api/beta")
|
||||
app.include_router(indexer, prefix="/indexer")
|
||||
app.include_router(web_client)
|
||||
|
||||
|
||||
|
@ -111,15 +114,14 @@ if not state.demo:
|
|||
@schedule.repeat(schedule.every(61).minutes)
|
||||
def update_search_index():
|
||||
try:
|
||||
state.config_lock.acquire()
|
||||
logger.info("📬 Updating content index via Scheduler")
|
||||
all_files = collect_files(state.config.content_type)
|
||||
state.content_index = configure_content(
|
||||
state.content_index, state.config.content_type, state.search_models, regenerate=False
|
||||
state.content_index, state.config.content_type, all_files, state.search_models
|
||||
)
|
||||
logger.info("📬 Content index updated via Scheduler")
|
||||
except Exception as e:
|
||||
logger.error(f"🚨 Error updating content index via Scheduler: {e}", exc_info=True)
|
||||
finally:
|
||||
state.config_lock.release()
|
||||
|
||||
|
||||
def configure_search_types(config: FullConfig):
|
||||
|
@ -154,142 +156,6 @@ def configure_search(search_models: SearchModels, search_config: Optional[Search
|
|||
return search_models
|
||||
|
||||
|
||||
def configure_content(
|
||||
content_index: Optional[ContentIndex],
|
||||
content_config: Optional[ContentConfig],
|
||||
search_models: SearchModels,
|
||||
regenerate: bool,
|
||||
t: Optional[state.SearchType] = None,
|
||||
) -> Optional[ContentIndex]:
|
||||
# Run Validation Checks
|
||||
if content_config is None:
|
||||
logger.warning("🚨 No Content configuration available.")
|
||||
return None
|
||||
if content_index is None:
|
||||
content_index = ContentIndex()
|
||||
|
||||
try:
|
||||
# Initialize Org Notes Search
|
||||
if (t == None or t.value == state.SearchType.Org.value) and content_config.org and search_models.text_search:
|
||||
logger.info("🦄 Setting up search for orgmode notes")
|
||||
# Extract Entries, Generate Notes Embeddings
|
||||
content_index.org = text_search.setup(
|
||||
OrgToJsonl,
|
||||
content_config.org,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=regenerate,
|
||||
filters=[DateFilter(), WordFilter(), FileFilter()],
|
||||
)
|
||||
|
||||
# Initialize Markdown Search
|
||||
if (
|
||||
(t == None or t.value == state.SearchType.Markdown.value)
|
||||
and content_config.markdown
|
||||
and search_models.text_search
|
||||
):
|
||||
logger.info("💎 Setting up search for markdown notes")
|
||||
# Extract Entries, Generate Markdown Embeddings
|
||||
content_index.markdown = text_search.setup(
|
||||
MarkdownToJsonl,
|
||||
content_config.markdown,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=regenerate,
|
||||
filters=[DateFilter(), WordFilter(), FileFilter()],
|
||||
)
|
||||
|
||||
# Initialize PDF Search
|
||||
if (t == None or t.value == state.SearchType.Pdf.value) and content_config.pdf and search_models.text_search:
|
||||
logger.info("🖨️ Setting up search for pdf")
|
||||
# Extract Entries, Generate PDF Embeddings
|
||||
content_index.pdf = text_search.setup(
|
||||
PdfToJsonl,
|
||||
content_config.pdf,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=regenerate,
|
||||
filters=[DateFilter(), WordFilter(), FileFilter()],
|
||||
)
|
||||
|
||||
# Initialize Plaintext Search
|
||||
if (
|
||||
(t == None or t.value == state.SearchType.Plaintext.value)
|
||||
and content_config.plaintext
|
||||
and search_models.text_search
|
||||
):
|
||||
logger.info("📄 Setting up search for plaintext")
|
||||
# Extract Entries, Generate Plaintext Embeddings
|
||||
content_index.plaintext = text_search.setup(
|
||||
PlaintextToJsonl,
|
||||
content_config.plaintext,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=regenerate,
|
||||
filters=[DateFilter(), WordFilter(), FileFilter()],
|
||||
)
|
||||
|
||||
# Initialize Image Search
|
||||
if (
|
||||
(t == None or t.value == state.SearchType.Image.value)
|
||||
and content_config.image
|
||||
and search_models.image_search
|
||||
):
|
||||
logger.info("🌄 Setting up search for images")
|
||||
# Extract Entries, Generate Image Embeddings
|
||||
content_index.image = image_search.setup(
|
||||
content_config.image, search_models.image_search.image_encoder, regenerate=regenerate
|
||||
)
|
||||
|
||||
if (
|
||||
(t == None or t.value == state.SearchType.Github.value)
|
||||
and content_config.github
|
||||
and search_models.text_search
|
||||
):
|
||||
logger.info("🐙 Setting up search for github")
|
||||
# Extract Entries, Generate Github Embeddings
|
||||
content_index.github = text_search.setup(
|
||||
GithubToJsonl,
|
||||
content_config.github,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=regenerate,
|
||||
filters=[DateFilter(), WordFilter(), FileFilter()],
|
||||
)
|
||||
|
||||
# Initialize Notion Search
|
||||
if (
|
||||
(t == None or t.value in state.SearchType.Notion.value)
|
||||
and content_config.notion
|
||||
and search_models.text_search
|
||||
):
|
||||
logger.info("🔌 Setting up search for notion")
|
||||
content_index.notion = text_search.setup(
|
||||
NotionToJsonl,
|
||||
content_config.notion,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=regenerate,
|
||||
filters=[DateFilter(), WordFilter(), FileFilter()],
|
||||
)
|
||||
|
||||
# Initialize External Plugin Search
|
||||
if (t == None or t in state.SearchType) and content_config.plugins and search_models.text_search:
|
||||
logger.info("🔌 Setting up search for plugins")
|
||||
content_index.plugins = {}
|
||||
for plugin_type, plugin_config in content_config.plugins.items():
|
||||
content_index.plugins[plugin_type] = text_search.setup(
|
||||
JsonlToJsonl,
|
||||
plugin_config,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=regenerate,
|
||||
filters=[DateFilter(), WordFilter(), FileFilter()],
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"🚨 Failed to setup search: {e}", exc_info=True)
|
||||
raise e
|
||||
|
||||
# Invalidate Query Cache
|
||||
state.query_cache = LRU()
|
||||
|
||||
return content_index
|
||||
|
||||
|
||||
def configure_processor(
|
||||
processor_config: Optional[ProcessorConfig], state_processor_config: Optional[ProcessorConfigModel] = None
|
||||
):
|
||||
|
|
|
@ -75,8 +75,8 @@ def run():
|
|||
poll_task_scheduler()
|
||||
|
||||
# Start Server
|
||||
initialize_server(args.config, args.regenerate, required=False)
|
||||
configure_routes(app)
|
||||
initialize_server(args.config, required=False)
|
||||
start_server(app, host=args.host, port=args.port, socket=args.socket)
|
||||
else:
|
||||
from PySide6 import QtWidgets
|
||||
|
@ -99,7 +99,7 @@ def run():
|
|||
tray.show()
|
||||
|
||||
# Setup Server
|
||||
initialize_server(args.config, args.regenerate, required=False)
|
||||
initialize_server(args.config, required=False)
|
||||
configure_routes(app)
|
||||
server = ServerThread(start_server_func=lambda: start_server(app, host=args.host, port=args.port), parent=gui)
|
||||
|
||||
|
|
|
@ -37,7 +37,7 @@ class GithubToJsonl(TextToJsonl):
|
|||
else:
|
||||
return
|
||||
|
||||
def process(self, previous_entries=[]):
|
||||
def process(self, previous_entries=[], files=None):
|
||||
if self.config.pat_token is None or self.config.pat_token == "":
|
||||
logger.error(f"Github PAT token is not set. Skipping github content")
|
||||
raise ValueError("Github PAT token is not set. Skipping github content")
|
||||
|
|
|
@ -16,7 +16,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
class JsonlToJsonl(TextToJsonl):
|
||||
# Define Functions
|
||||
def process(self, previous_entries=[]):
|
||||
def process(self, previous_entries=[], files: dict[str, str] = {}):
|
||||
# Extract required fields from config
|
||||
input_jsonl_files, input_jsonl_filter, output_file = (
|
||||
self.config.input_files,
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
# Standard Packages
|
||||
import glob
|
||||
import logging
|
||||
import re
|
||||
import urllib3
|
||||
|
@ -8,7 +7,7 @@ from typing import List
|
|||
|
||||
# Internal Packages
|
||||
from khoj.processor.text_to_jsonl import TextToJsonl
|
||||
from khoj.utils.helpers import get_absolute_path, is_none_or_empty, timer
|
||||
from khoj.utils.helpers import timer
|
||||
from khoj.utils.constants import empty_escape_sequences
|
||||
from khoj.utils.jsonl import compress_jsonl_data
|
||||
from khoj.utils.rawconfig import Entry, TextContentConfig
|
||||
|
@ -23,26 +22,14 @@ class MarkdownToJsonl(TextToJsonl):
|
|||
self.config = config
|
||||
|
||||
# Define Functions
|
||||
def process(self, previous_entries=[]):
|
||||
def process(self, previous_entries=[], files=None):
|
||||
# Extract required fields from config
|
||||
markdown_files, markdown_file_filter, output_file = (
|
||||
self.config.input_files,
|
||||
self.config.input_filter,
|
||||
self.config.compressed_jsonl,
|
||||
)
|
||||
|
||||
# Input Validation
|
||||
if is_none_or_empty(markdown_files) and is_none_or_empty(markdown_file_filter):
|
||||
print("At least one of markdown-files or markdown-file-filter is required to be specified")
|
||||
exit(1)
|
||||
|
||||
# Get Markdown Files to Process
|
||||
markdown_files = MarkdownToJsonl.get_markdown_files(markdown_files, markdown_file_filter)
|
||||
output_file = self.config.compressed_jsonl
|
||||
|
||||
# Extract Entries from specified Markdown files
|
||||
with timer("Parse entries from Markdown files into dictionaries", logger):
|
||||
current_entries = MarkdownToJsonl.convert_markdown_entries_to_maps(
|
||||
*MarkdownToJsonl.extract_markdown_entries(markdown_files)
|
||||
*MarkdownToJsonl.extract_markdown_entries(files)
|
||||
)
|
||||
|
||||
# Split entries by max tokens supported by model
|
||||
|
@ -65,36 +52,6 @@ class MarkdownToJsonl(TextToJsonl):
|
|||
|
||||
return entries_with_ids
|
||||
|
||||
@staticmethod
|
||||
def get_markdown_files(markdown_files=None, markdown_file_filters=None):
|
||||
"Get Markdown files to process"
|
||||
absolute_markdown_files, filtered_markdown_files = set(), set()
|
||||
if markdown_files:
|
||||
absolute_markdown_files = {get_absolute_path(markdown_file) for markdown_file in markdown_files}
|
||||
if markdown_file_filters:
|
||||
filtered_markdown_files = {
|
||||
filtered_file
|
||||
for markdown_file_filter in markdown_file_filters
|
||||
for filtered_file in glob.glob(get_absolute_path(markdown_file_filter), recursive=True)
|
||||
}
|
||||
|
||||
all_markdown_files = sorted(absolute_markdown_files | filtered_markdown_files)
|
||||
|
||||
files_with_non_markdown_extensions = {
|
||||
md_file
|
||||
for md_file in all_markdown_files
|
||||
if not md_file.endswith(".md") and not md_file.endswith(".markdown")
|
||||
}
|
||||
|
||||
if any(files_with_non_markdown_extensions):
|
||||
logger.warning(
|
||||
f"[Warning] There maybe non markdown-mode files in the input set: {files_with_non_markdown_extensions}"
|
||||
)
|
||||
|
||||
logger.debug(f"Processing files: {all_markdown_files}")
|
||||
|
||||
return all_markdown_files
|
||||
|
||||
@staticmethod
|
||||
def extract_markdown_entries(markdown_files):
|
||||
"Extract entries by heading from specified Markdown files"
|
||||
|
@ -104,15 +61,14 @@ class MarkdownToJsonl(TextToJsonl):
|
|||
entries = []
|
||||
entry_to_file_map = []
|
||||
for markdown_file in markdown_files:
|
||||
with open(markdown_file, "r", encoding="utf8") as f:
|
||||
try:
|
||||
markdown_content = f.read()
|
||||
entries, entry_to_file_map = MarkdownToJsonl.process_single_markdown_file(
|
||||
markdown_content, markdown_file, entries, entry_to_file_map
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Unable to process file: {markdown_file}. This file will not be indexed.")
|
||||
logger.warning(e, exc_info=True)
|
||||
try:
|
||||
markdown_content = markdown_files[markdown_file]
|
||||
entries, entry_to_file_map = MarkdownToJsonl.process_single_markdown_file(
|
||||
markdown_content, markdown_file, entries, entry_to_file_map
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Unable to process file: {markdown_file}. This file will not be indexed.")
|
||||
logger.warning(e, exc_info=True)
|
||||
|
||||
return entries, dict(entry_to_file_map)
|
||||
|
||||
|
|
|
@ -80,7 +80,7 @@ class NotionToJsonl(TextToJsonl):
|
|||
|
||||
self.body_params = {"page_size": 100}
|
||||
|
||||
def process(self, previous_entries=[]):
|
||||
def process(self, previous_entries=[], files=None):
|
||||
current_entries = []
|
||||
|
||||
# Get all pages
|
||||
|
|
|
@ -1,13 +1,12 @@
|
|||
# Standard Packages
|
||||
import glob
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Iterable, List
|
||||
from typing import Iterable, List, Tuple
|
||||
|
||||
# Internal Packages
|
||||
from khoj.processor.org_mode import orgnode
|
||||
from khoj.processor.text_to_jsonl import TextToJsonl
|
||||
from khoj.utils.helpers import get_absolute_path, is_none_or_empty, timer
|
||||
from khoj.utils.helpers import timer
|
||||
from khoj.utils.jsonl import compress_jsonl_data
|
||||
from khoj.utils.rawconfig import Entry, TextContentConfig
|
||||
from khoj.utils import state
|
||||
|
@ -22,27 +21,14 @@ class OrgToJsonl(TextToJsonl):
|
|||
self.config = config
|
||||
|
||||
# Define Functions
|
||||
def process(self, previous_entries: List[Entry] = []):
|
||||
def process(self, previous_entries: List[Entry] = [], files: dict[str, str] = None) -> List[Tuple[int, Entry]]:
|
||||
# Extract required fields from config
|
||||
org_files, org_file_filter, output_file = (
|
||||
self.config.input_files,
|
||||
self.config.input_filter,
|
||||
self.config.compressed_jsonl,
|
||||
)
|
||||
output_file = self.config.compressed_jsonl
|
||||
index_heading_entries = self.config.index_heading_entries
|
||||
|
||||
# Input Validation
|
||||
if is_none_or_empty(org_files) and is_none_or_empty(org_file_filter):
|
||||
print("At least one of org-files or org-file-filter is required to be specified")
|
||||
exit(1)
|
||||
|
||||
# Get Org Files to Process
|
||||
with timer("Get org files to process", logger):
|
||||
org_files = OrgToJsonl.get_org_files(org_files, org_file_filter)
|
||||
|
||||
# Extract Entries from specified Org files
|
||||
with timer("Parse entries from org files into OrgNode objects", logger):
|
||||
entry_nodes, file_to_entries = self.extract_org_entries(org_files)
|
||||
entry_nodes, file_to_entries = self.extract_org_entries(files)
|
||||
|
||||
with timer("Convert OrgNodes into list of entries", logger):
|
||||
current_entries = self.convert_org_nodes_to_entries(entry_nodes, file_to_entries, index_heading_entries)
|
||||
|
@ -67,36 +53,15 @@ class OrgToJsonl(TextToJsonl):
|
|||
return entries_with_ids
|
||||
|
||||
@staticmethod
|
||||
def get_org_files(org_files=None, org_file_filters=None):
|
||||
"Get Org files to process"
|
||||
absolute_org_files, filtered_org_files = set(), set()
|
||||
if org_files:
|
||||
absolute_org_files = {get_absolute_path(org_file) for org_file in org_files}
|
||||
if org_file_filters:
|
||||
filtered_org_files = {
|
||||
filtered_file
|
||||
for org_file_filter in org_file_filters
|
||||
for filtered_file in glob.glob(get_absolute_path(org_file_filter), recursive=True)
|
||||
}
|
||||
|
||||
all_org_files = sorted(absolute_org_files | filtered_org_files)
|
||||
|
||||
files_with_non_org_extensions = {org_file for org_file in all_org_files if not org_file.endswith(".org")}
|
||||
if any(files_with_non_org_extensions):
|
||||
logger.warning(f"There maybe non org-mode files in the input set: {files_with_non_org_extensions}")
|
||||
|
||||
logger.debug(f"Processing files: {all_org_files}")
|
||||
|
||||
return all_org_files
|
||||
|
||||
@staticmethod
|
||||
def extract_org_entries(org_files):
|
||||
def extract_org_entries(org_files: dict[str, str]):
|
||||
"Extract entries from specified Org files"
|
||||
entries = []
|
||||
entry_to_file_map = []
|
||||
entry_to_file_map: List[Tuple[orgnode.Orgnode, str]] = []
|
||||
for org_file in org_files:
|
||||
filename = org_file
|
||||
file = org_files[org_file]
|
||||
try:
|
||||
org_file_entries = orgnode.makelist_with_filepath(str(org_file))
|
||||
org_file_entries = orgnode.makelist(file, filename)
|
||||
entry_to_file_map += zip(org_file_entries, [org_file] * len(org_file_entries))
|
||||
entries.extend(org_file_entries)
|
||||
except Exception as e:
|
||||
|
@ -109,7 +74,7 @@ class OrgToJsonl(TextToJsonl):
|
|||
def process_single_org_file(org_content: str, org_file: str, entries: List, entry_to_file_map: List):
|
||||
# Process single org file. The org parser assumes that the file is a single org file and reads it from a buffer. We'll split the raw conetnt of this file by new line to mimic the same behavior.
|
||||
try:
|
||||
org_file_entries = orgnode.makelist(org_content.split("\n"), org_file)
|
||||
org_file_entries = orgnode.makelist(org_content, org_file)
|
||||
entry_to_file_map += zip(org_file_entries, [org_file] * len(org_file_entries))
|
||||
entries.extend(org_file_entries)
|
||||
return entries, entry_to_file_map
|
||||
|
|
|
@ -65,7 +65,10 @@ def makelist(file, filename):
|
|||
"""
|
||||
ctr = 0
|
||||
|
||||
f = file
|
||||
if type(file) == str:
|
||||
f = file.split("\n")
|
||||
else:
|
||||
f = file
|
||||
|
||||
todos = {
|
||||
"TODO": "",
|
||||
|
@ -199,7 +202,8 @@ def makelist(file, filename):
|
|||
# if we are in a heading
|
||||
if heading:
|
||||
# add the line to the bodytext
|
||||
bodytext += line
|
||||
bodytext += line.rstrip() + "\n\n" if line.strip() else ""
|
||||
# bodytext += line + "\n" if line.strip() else "\n"
|
||||
# else we are in the pre heading portion of the file
|
||||
elif line.strip():
|
||||
# so add the line to the introtext
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
# Standard Packages
|
||||
import glob
|
||||
import os
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
# External Packages
|
||||
|
@ -9,7 +8,7 @@ from langchain.document_loaders import PyPDFLoader
|
|||
|
||||
# Internal Packages
|
||||
from khoj.processor.text_to_jsonl import TextToJsonl
|
||||
from khoj.utils.helpers import get_absolute_path, is_none_or_empty, timer
|
||||
from khoj.utils.helpers import timer
|
||||
from khoj.utils.jsonl import compress_jsonl_data
|
||||
from khoj.utils.rawconfig import Entry
|
||||
|
||||
|
@ -19,25 +18,13 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
class PdfToJsonl(TextToJsonl):
|
||||
# Define Functions
|
||||
def process(self, previous_entries=[]):
|
||||
def process(self, previous_entries=[], files=dict[str, str]):
|
||||
# Extract required fields from config
|
||||
pdf_files, pdf_file_filter, output_file = (
|
||||
self.config.input_files,
|
||||
self.config.input_filter,
|
||||
self.config.compressed_jsonl,
|
||||
)
|
||||
|
||||
# Input Validation
|
||||
if is_none_or_empty(pdf_files) and is_none_or_empty(pdf_file_filter):
|
||||
print("At least one of pdf-files or pdf-file-filter is required to be specified")
|
||||
exit(1)
|
||||
|
||||
# Get Pdf Files to Process
|
||||
pdf_files = PdfToJsonl.get_pdf_files(pdf_files, pdf_file_filter)
|
||||
output_file = self.config.compressed_jsonl
|
||||
|
||||
# Extract Entries from specified Pdf files
|
||||
with timer("Parse entries from PDF files into dictionaries", logger):
|
||||
current_entries = PdfToJsonl.convert_pdf_entries_to_maps(*PdfToJsonl.extract_pdf_entries(pdf_files))
|
||||
current_entries = PdfToJsonl.convert_pdf_entries_to_maps(*PdfToJsonl.extract_pdf_entries(files))
|
||||
|
||||
# Split entries by max tokens supported by model
|
||||
with timer("Split entries by max token size supported by model", logger):
|
||||
|
@ -59,32 +46,6 @@ class PdfToJsonl(TextToJsonl):
|
|||
|
||||
return entries_with_ids
|
||||
|
||||
@staticmethod
|
||||
def get_pdf_files(pdf_files=None, pdf_file_filters=None):
|
||||
"Get PDF files to process"
|
||||
absolute_pdf_files, filtered_pdf_files = set(), set()
|
||||
if pdf_files:
|
||||
absolute_pdf_files = {get_absolute_path(pdf_file) for pdf_file in pdf_files}
|
||||
if pdf_file_filters:
|
||||
filtered_pdf_files = {
|
||||
filtered_file
|
||||
for pdf_file_filter in pdf_file_filters
|
||||
for filtered_file in glob.glob(get_absolute_path(pdf_file_filter), recursive=True)
|
||||
}
|
||||
|
||||
all_pdf_files = sorted(absolute_pdf_files | filtered_pdf_files)
|
||||
|
||||
files_with_non_pdf_extensions = {pdf_file for pdf_file in all_pdf_files if not pdf_file.endswith(".pdf")}
|
||||
|
||||
if any(files_with_non_pdf_extensions):
|
||||
logger.warning(
|
||||
f"[Warning] There maybe non pdf-mode files in the input set: {files_with_non_pdf_extensions}"
|
||||
)
|
||||
|
||||
logger.debug(f"Processing files: {all_pdf_files}")
|
||||
|
||||
return all_pdf_files
|
||||
|
||||
@staticmethod
|
||||
def extract_pdf_entries(pdf_files):
|
||||
"""Extract entries by page from specified PDF files"""
|
||||
|
@ -93,13 +54,19 @@ class PdfToJsonl(TextToJsonl):
|
|||
entry_to_location_map = []
|
||||
for pdf_file in pdf_files:
|
||||
try:
|
||||
loader = PyPDFLoader(pdf_file)
|
||||
# Write the PDF file to a temporary file, as it is stored in byte format in the pdf_file object and the PyPDFLoader expects a file path
|
||||
with open(f"{pdf_file}.pdf", "wb") as f:
|
||||
f.write(pdf_files[pdf_file])
|
||||
loader = PyPDFLoader(f"{pdf_file}.pdf")
|
||||
pdf_entries_per_file = [page.page_content for page in loader.load()]
|
||||
entry_to_location_map += zip(pdf_entries_per_file, [pdf_file] * len(pdf_entries_per_file))
|
||||
entries.extend(pdf_entries_per_file)
|
||||
except Exception as e:
|
||||
logger.warning(f"Unable to process file: {pdf_file}. This file will not be indexed.")
|
||||
logger.warning(e)
|
||||
finally:
|
||||
if os.path.exists(f"{pdf_file}.pdf"):
|
||||
os.remove(f"{pdf_file}.pdf")
|
||||
|
||||
return entries, dict(entry_to_location_map)
|
||||
|
||||
|
@ -108,9 +75,9 @@ class PdfToJsonl(TextToJsonl):
|
|||
"Convert each PDF entries into a dictionary"
|
||||
entries = []
|
||||
for parsed_entry in parsed_entries:
|
||||
entry_filename = Path(entry_to_file_map[parsed_entry])
|
||||
entry_filename = entry_to_file_map[parsed_entry]
|
||||
# Append base filename to compiled entry for context to model
|
||||
heading = f"{entry_filename.stem}\n"
|
||||
heading = f"{entry_filename}\n"
|
||||
compiled_entry = f"{heading}{parsed_entry}"
|
||||
entries.append(
|
||||
Entry(
|
||||
|
|
|
@ -1,13 +1,12 @@
|
|||
# Standard Packages
|
||||
import glob
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from typing import List, Tuple
|
||||
|
||||
# Internal Packages
|
||||
from khoj.processor.text_to_jsonl import TextToJsonl
|
||||
from khoj.utils.helpers import get_absolute_path, timer
|
||||
from khoj.utils.jsonl import load_jsonl, compress_jsonl_data
|
||||
from khoj.utils.helpers import timer
|
||||
from khoj.utils.jsonl import compress_jsonl_data
|
||||
from khoj.utils.rawconfig import Entry
|
||||
|
||||
|
||||
|
@ -16,22 +15,12 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
class PlaintextToJsonl(TextToJsonl):
|
||||
# Define Functions
|
||||
def process(self, previous_entries=[]):
|
||||
# Extract required fields from config
|
||||
input_files, input_filter, output_file = (
|
||||
self.config.input_files,
|
||||
self.config.input_filter,
|
||||
self.config.compressed_jsonl,
|
||||
)
|
||||
|
||||
# Get Plaintext Input Files to Process
|
||||
all_input_plaintext_files = PlaintextToJsonl.get_plaintext_files(input_files, input_filter)
|
||||
def process(self, previous_entries: List[Entry] = [], files: dict[str, str] = None) -> List[Tuple[int, Entry]]:
|
||||
output_file = self.config.compressed_jsonl
|
||||
|
||||
# Extract Entries from specified plaintext files
|
||||
with timer("Parse entries from plaintext files", logger):
|
||||
current_entries = PlaintextToJsonl.convert_plaintext_entries_to_maps(
|
||||
PlaintextToJsonl.extract_plaintext_entries(all_input_plaintext_files)
|
||||
)
|
||||
current_entries = PlaintextToJsonl.convert_plaintext_entries_to_maps(files)
|
||||
|
||||
# Split entries by max tokens supported by model
|
||||
with timer("Split entries by max token size supported by model", logger):
|
||||
|
@ -53,67 +42,11 @@ class PlaintextToJsonl(TextToJsonl):
|
|||
|
||||
return entries_with_ids
|
||||
|
||||
@staticmethod
|
||||
def get_plaintext_files(plaintext_files=None, plaintext_file_filters=None):
|
||||
"Get all files to process"
|
||||
absolute_plaintext_files, filtered_plaintext_files = set(), set()
|
||||
if plaintext_files:
|
||||
absolute_plaintext_files = {get_absolute_path(jsonl_file) for jsonl_file in plaintext_files}
|
||||
if plaintext_file_filters:
|
||||
filtered_plaintext_files = {
|
||||
filtered_file
|
||||
for jsonl_file_filter in plaintext_file_filters
|
||||
for filtered_file in glob.glob(get_absolute_path(jsonl_file_filter), recursive=True)
|
||||
}
|
||||
|
||||
all_target_files = sorted(absolute_plaintext_files | filtered_plaintext_files)
|
||||
|
||||
files_with_no_plaintext_extensions = {
|
||||
target_files for target_files in all_target_files if not PlaintextToJsonl.is_plaintextfile(target_files)
|
||||
}
|
||||
if any(files_with_no_plaintext_extensions):
|
||||
logger.warn(f"Skipping unsupported files from plaintext indexing: {files_with_no_plaintext_extensions}")
|
||||
all_target_files = list(set(all_target_files) - files_with_no_plaintext_extensions)
|
||||
|
||||
logger.debug(f"Processing files: {all_target_files}")
|
||||
|
||||
return all_target_files
|
||||
|
||||
@staticmethod
|
||||
def is_plaintextfile(file: str):
|
||||
"Check if file is plaintext file"
|
||||
return file.endswith(("txt", "md", "markdown", "org", "mbox", "rst", "html", "htm", "xml"))
|
||||
|
||||
@staticmethod
|
||||
def extract_plaintext_entries(plaintext_files: List[str]):
|
||||
"Extract entries from specified plaintext files"
|
||||
entry_to_file_map = []
|
||||
|
||||
for plaintext_file in plaintext_files:
|
||||
with open(plaintext_file, "r") as f:
|
||||
try:
|
||||
plaintext_content = f.read()
|
||||
if plaintext_file.endswith(("html", "htm", "xml")):
|
||||
plaintext_content = PlaintextToJsonl.extract_html_content(plaintext_content)
|
||||
entry_to_file_map.append((plaintext_content, plaintext_file))
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing file: {plaintext_file} - {e}", exc_info=True)
|
||||
|
||||
return dict(entry_to_file_map)
|
||||
|
||||
@staticmethod
|
||||
def extract_html_content(html_content: str):
|
||||
"Extract content from HTML"
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
soup = BeautifulSoup(html_content, "html.parser")
|
||||
return soup.get_text(strip=True, separator="\n")
|
||||
|
||||
@staticmethod
|
||||
def convert_plaintext_entries_to_maps(entry_to_file_map: dict) -> List[Entry]:
|
||||
"Convert each plaintext entries into a dictionary"
|
||||
entries = []
|
||||
for entry, file in entry_to_file_map.items():
|
||||
for file, entry in entry_to_file_map.items():
|
||||
entries.append(
|
||||
Entry(
|
||||
raw=entry,
|
||||
|
|
|
@ -17,7 +17,7 @@ class TextToJsonl(ABC):
|
|||
self.config = config
|
||||
|
||||
@abstractmethod
|
||||
def process(self, previous_entries: List[Entry] = []) -> List[Tuple[int, Entry]]:
|
||||
def process(self, previous_entries: List[Entry] = [], files: dict[str, str] = None) -> List[Tuple[int, Entry]]:
|
||||
...
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -608,7 +608,7 @@ def update(
|
|||
logger.warning(error_msg)
|
||||
raise HTTPException(status_code=500, detail=error_msg)
|
||||
try:
|
||||
configure_server(state.config, regenerate=force or False, search_type=t)
|
||||
configure_server(state.config)
|
||||
except Exception as e:
|
||||
error_msg = f"🚨 Failed to update server via API: {e}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
|
@ -765,7 +765,7 @@ async def extract_references_and_questions(
|
|||
inferred_queries: List[str] = []
|
||||
|
||||
if state.content_index is None:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
"No content index loaded, so cannot extract references from knowledge base. Please configure your data sources and update the index to chat with your notes."
|
||||
)
|
||||
return compiled_references, inferred_queries
|
||||
|
|
285
src/khoj/routers/indexer.py
Normal file
285
src/khoj/routers/indexer.py
Normal file
|
@ -0,0 +1,285 @@
|
|||
# Standard Packages
|
||||
import logging
|
||||
from typing import Optional, Union
|
||||
|
||||
# External Packages
|
||||
from fastapi import APIRouter, HTTPException, Header, Request, Body, Response
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Internal Packages
|
||||
from khoj.utils import state
|
||||
from khoj.processor.jsonl.jsonl_to_jsonl import JsonlToJsonl
|
||||
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
|
||||
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
||||
from khoj.processor.pdf.pdf_to_jsonl import PdfToJsonl
|
||||
from khoj.processor.github.github_to_jsonl import GithubToJsonl
|
||||
from khoj.processor.notion.notion_to_jsonl import NotionToJsonl
|
||||
from khoj.processor.plaintext.plaintext_to_jsonl import PlaintextToJsonl
|
||||
from khoj.utils.rawconfig import ContentConfig
|
||||
from khoj.search_type import text_search, image_search
|
||||
from khoj.utils.config import SearchModels
|
||||
from khoj.utils.helpers import LRU
|
||||
from khoj.utils.rawconfig import (
|
||||
ContentConfig,
|
||||
)
|
||||
from khoj.search_filter.date_filter import DateFilter
|
||||
from khoj.search_filter.word_filter import WordFilter
|
||||
from khoj.search_filter.file_filter import FileFilter
|
||||
from khoj.utils.config import (
|
||||
ContentIndex,
|
||||
SearchModels,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
indexer = APIRouter()
|
||||
|
||||
|
||||
class IndexBatchRequest(BaseModel):
|
||||
org: Optional[dict[str, str]]
|
||||
pdf: Optional[dict[str, str]]
|
||||
plaintext: Optional[dict[str, str]]
|
||||
markdown: Optional[dict[str, str]]
|
||||
|
||||
|
||||
@indexer.post("/batch")
|
||||
async def index_batch(
|
||||
request: Request,
|
||||
x_api_key: str = Header(None),
|
||||
regenerate: bool = False,
|
||||
search_type: Optional[Union[state.SearchType, str]] = None,
|
||||
):
|
||||
if x_api_key != "secret":
|
||||
raise HTTPException(status_code=401, detail="Invalid API Key")
|
||||
state.config_lock.acquire()
|
||||
try:
|
||||
logger.info(f"Received batch indexing request")
|
||||
index_batch_request_acc = ""
|
||||
async for chunk in request.stream():
|
||||
index_batch_request_acc += chunk.decode()
|
||||
index_batch_request = IndexBatchRequest.parse_raw(index_batch_request_acc)
|
||||
logger.info(f"Received batch indexing request size: {len(index_batch_request.dict())}")
|
||||
|
||||
# Extract required fields from config
|
||||
state.content_index = configure_content(
|
||||
state.content_index,
|
||||
state.config.content_type,
|
||||
index_batch_request.dict(),
|
||||
state.search_models,
|
||||
regenerate=regenerate,
|
||||
t=search_type,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process batch indexing request: {e}")
|
||||
finally:
|
||||
state.config_lock.release()
|
||||
return Response(content="OK", status_code=200)
|
||||
|
||||
|
||||
def configure_content(
|
||||
content_index: Optional[ContentIndex],
|
||||
content_config: Optional[ContentConfig],
|
||||
files: Optional[dict[str, dict[str, str]]],
|
||||
search_models: SearchModels,
|
||||
regenerate: bool = False,
|
||||
t: Optional[Union[state.SearchType, str]] = None,
|
||||
) -> Optional[ContentIndex]:
|
||||
# Run Validation Checks
|
||||
if content_config is None:
|
||||
logger.warning("🚨 No Content configuration available.")
|
||||
return None
|
||||
if content_index is None:
|
||||
content_index = ContentIndex()
|
||||
|
||||
if t in [type.value for type in state.SearchType]:
|
||||
t = state.SearchType(t).value
|
||||
|
||||
assert type(t) == str or t == None, f"Invalid search type: {t}"
|
||||
|
||||
if files is None:
|
||||
logger.warning(f"🚨 No files to process for {t} search.")
|
||||
return None
|
||||
|
||||
try:
|
||||
# Initialize Org Notes Search
|
||||
if (
|
||||
(t == None or t == state.SearchType.Org.value)
|
||||
and content_config.org
|
||||
and search_models.text_search
|
||||
and files["org"]
|
||||
):
|
||||
logger.info("🦄 Setting up search for orgmode notes")
|
||||
# Extract Entries, Generate Notes Embeddings
|
||||
content_index.org = text_search.setup(
|
||||
OrgToJsonl,
|
||||
files.get("org"),
|
||||
content_config.org,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=regenerate,
|
||||
filters=[DateFilter(), WordFilter(), FileFilter()],
|
||||
)
|
||||
|
||||
# Initialize Markdown Search
|
||||
if (
|
||||
(t == None or t == state.SearchType.Markdown.value)
|
||||
and content_config.markdown
|
||||
and search_models.text_search
|
||||
and files["markdown"]
|
||||
):
|
||||
logger.info("💎 Setting up search for markdown notes")
|
||||
# Extract Entries, Generate Markdown Embeddings
|
||||
content_index.markdown = text_search.setup(
|
||||
MarkdownToJsonl,
|
||||
files.get("markdown"),
|
||||
content_config.markdown,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=regenerate,
|
||||
filters=[DateFilter(), WordFilter(), FileFilter()],
|
||||
)
|
||||
|
||||
# Initialize PDF Search
|
||||
if (
|
||||
(t == None or t == state.SearchType.Pdf.value)
|
||||
and content_config.pdf
|
||||
and search_models.text_search
|
||||
and files["pdf"]
|
||||
):
|
||||
logger.info("🖨️ Setting up search for pdf")
|
||||
# Extract Entries, Generate PDF Embeddings
|
||||
content_index.pdf = text_search.setup(
|
||||
PdfToJsonl,
|
||||
files.get("pdf"),
|
||||
content_config.pdf,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=regenerate,
|
||||
filters=[DateFilter(), WordFilter(), FileFilter()],
|
||||
)
|
||||
|
||||
# Initialize Plaintext Search
|
||||
if (
|
||||
(t == None or t == state.SearchType.Plaintext.value)
|
||||
and content_config.plaintext
|
||||
and search_models.text_search
|
||||
and files["plaintext"]
|
||||
):
|
||||
logger.info("📄 Setting up search for plaintext")
|
||||
# Extract Entries, Generate Plaintext Embeddings
|
||||
content_index.plaintext = text_search.setup(
|
||||
PlaintextToJsonl,
|
||||
files.get("plaintext"),
|
||||
content_config.plaintext,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=regenerate,
|
||||
filters=[DateFilter(), WordFilter(), FileFilter()],
|
||||
)
|
||||
|
||||
# Initialize Image Search
|
||||
if (t == None or t == state.SearchType.Image.value) and content_config.image and search_models.image_search:
|
||||
logger.info("🌄 Setting up search for images")
|
||||
# Extract Entries, Generate Image Embeddings
|
||||
content_index.image = image_search.setup(
|
||||
content_config.image, search_models.image_search.image_encoder, regenerate=regenerate
|
||||
)
|
||||
|
||||
if (t == None or t == state.SearchType.Github.value) and content_config.github and search_models.text_search:
|
||||
logger.info("🐙 Setting up search for github")
|
||||
# Extract Entries, Generate Github Embeddings
|
||||
content_index.github = text_search.setup(
|
||||
GithubToJsonl,
|
||||
None,
|
||||
content_config.github,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=regenerate,
|
||||
filters=[DateFilter(), WordFilter(), FileFilter()],
|
||||
)
|
||||
|
||||
# Initialize Notion Search
|
||||
if (t == None or t in state.SearchType.Notion.value) and content_config.notion and search_models.text_search:
|
||||
logger.info("🔌 Setting up search for notion")
|
||||
content_index.notion = text_search.setup(
|
||||
NotionToJsonl,
|
||||
None,
|
||||
content_config.notion,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=regenerate,
|
||||
filters=[DateFilter(), WordFilter(), FileFilter()],
|
||||
)
|
||||
|
||||
# Initialize External Plugin Search
|
||||
if (t == None or t in state.SearchType) and content_config.plugins and search_models.text_search:
|
||||
logger.info("🔌 Setting up search for plugins")
|
||||
content_index.plugins = {}
|
||||
for plugin_type, plugin_config in content_config.plugins.items():
|
||||
content_index.plugins[plugin_type] = text_search.setup(
|
||||
JsonlToJsonl,
|
||||
None,
|
||||
plugin_config,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=regenerate,
|
||||
filters=[DateFilter(), WordFilter(), FileFilter()],
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"🚨 Failed to setup search: {e}", exc_info=True)
|
||||
raise e
|
||||
|
||||
# Invalidate Query Cache
|
||||
state.query_cache = LRU()
|
||||
|
||||
return content_index
|
||||
|
||||
|
||||
def load_content(
|
||||
content_config: Optional[ContentConfig],
|
||||
content_index: Optional[ContentIndex],
|
||||
search_models: SearchModels,
|
||||
):
|
||||
logger.info(f"Loading content from existing embeddings...")
|
||||
if content_config is None:
|
||||
logger.warning("🚨 No Content configuration available.")
|
||||
return None
|
||||
if content_index is None:
|
||||
content_index = ContentIndex()
|
||||
|
||||
if content_config.org:
|
||||
logger.info("🦄 Loading orgmode notes")
|
||||
content_index.org = text_search.load(content_config.org, filters=[DateFilter(), WordFilter(), FileFilter()])
|
||||
if content_config.markdown:
|
||||
logger.info("💎 Loading markdown notes")
|
||||
content_index.markdown = text_search.load(
|
||||
content_config.markdown, filters=[DateFilter(), WordFilter(), FileFilter()]
|
||||
)
|
||||
if content_config.pdf:
|
||||
logger.info("🖨️ Loading pdf")
|
||||
content_index.pdf = text_search.load(content_config.pdf, filters=[DateFilter(), WordFilter(), FileFilter()])
|
||||
if content_config.plaintext:
|
||||
logger.info("📄 Loading plaintext")
|
||||
content_index.plaintext = text_search.load(
|
||||
content_config.plaintext, filters=[DateFilter(), WordFilter(), FileFilter()]
|
||||
)
|
||||
if content_config.image:
|
||||
logger.info("🌄 Loading images")
|
||||
content_index.image = image_search.setup(
|
||||
content_config.image, search_models.image_search.image_encoder, regenerate=False
|
||||
)
|
||||
if content_config.github:
|
||||
logger.info("🐙 Loading github")
|
||||
content_index.github = text_search.load(
|
||||
content_config.github, filters=[DateFilter(), WordFilter(), FileFilter()]
|
||||
)
|
||||
if content_config.notion:
|
||||
logger.info("🔌 Loading notion")
|
||||
content_index.notion = text_search.load(
|
||||
content_config.notion, filters=[DateFilter(), WordFilter(), FileFilter()]
|
||||
)
|
||||
if content_config.plugins:
|
||||
logger.info("🔌 Loading plugins")
|
||||
content_index.plugins = {}
|
||||
for plugin_type, plugin_config in content_config.plugins.items():
|
||||
content_index.plugins[plugin_type] = text_search.load(
|
||||
plugin_config, filters=[DateFilter(), WordFilter(), FileFilter()]
|
||||
)
|
||||
|
||||
state.query_cache = LRU()
|
||||
return content_index
|
|
@ -104,6 +104,18 @@ def compute_embeddings(
|
|||
return corpus_embeddings
|
||||
|
||||
|
||||
def load_embeddings(
|
||||
embeddings_file: Path,
|
||||
):
|
||||
"Load pre-computed embeddings from file if exists and update them if required"
|
||||
if embeddings_file.exists():
|
||||
corpus_embeddings: torch.Tensor = torch.load(get_absolute_path(embeddings_file), map_location=state.device)
|
||||
logger.debug(f"Loaded {len(corpus_embeddings)} text embeddings from {embeddings_file}")
|
||||
return util.normalize_embeddings(corpus_embeddings)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def query(
|
||||
raw_query: str,
|
||||
search_model: TextSearchModel,
|
||||
|
@ -174,6 +186,7 @@ def collate_results(hits, entries: List[Entry], count=5) -> List[SearchResponse]
|
|||
|
||||
def setup(
|
||||
text_to_jsonl: Type[TextToJsonl],
|
||||
files: dict[str, str],
|
||||
config: TextConfigBase,
|
||||
bi_encoder: BaseEncoder,
|
||||
regenerate: bool,
|
||||
|
@ -185,7 +198,7 @@ def setup(
|
|||
previous_entries = []
|
||||
if config.compressed_jsonl.exists() and not regenerate:
|
||||
previous_entries = extract_entries(config.compressed_jsonl)
|
||||
entries_with_indices = text_to_jsonl(config).process(previous_entries)
|
||||
entries_with_indices = text_to_jsonl(config).process(previous_entries=previous_entries, files=files)
|
||||
|
||||
# Extract Updated Entries
|
||||
entries = extract_entries(config.compressed_jsonl)
|
||||
|
@ -205,6 +218,24 @@ def setup(
|
|||
return TextContent(entries, corpus_embeddings, filters)
|
||||
|
||||
|
||||
def load(
|
||||
config: TextConfigBase,
|
||||
filters: List[BaseFilter] = [],
|
||||
) -> TextContent:
|
||||
# Map notes in text files to (compressed) JSONL formatted file
|
||||
config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl)
|
||||
entries = extract_entries(config.compressed_jsonl)
|
||||
|
||||
# Compute or Load Embeddings
|
||||
config.embeddings_file = resolve_absolute_path(config.embeddings_file)
|
||||
corpus_embeddings = load_embeddings(config.embeddings_file)
|
||||
|
||||
for filter in filters:
|
||||
filter.load(entries, regenerate=False)
|
||||
|
||||
return TextContent(entries, corpus_embeddings, filters)
|
||||
|
||||
|
||||
def apply_filters(
|
||||
query: str, entries: List[Entry], corpus_embeddings: torch.Tensor, filters: List[BaseFilter]
|
||||
) -> Tuple[str, List[Entry], torch.Tensor]:
|
||||
|
|
217
src/khoj/utils/fs_syncer.py
Normal file
217
src/khoj/utils/fs_syncer.py
Normal file
|
@ -0,0 +1,217 @@
|
|||
import logging
|
||||
import glob
|
||||
from typing import Optional
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
from khoj.utils.helpers import get_absolute_path, is_none_or_empty
|
||||
from khoj.utils.rawconfig import TextContentConfig, ContentConfig
|
||||
from khoj.utils.config import SearchType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def collect_files(config: ContentConfig, search_type: Optional[SearchType] = SearchType.All):
|
||||
files = {}
|
||||
if search_type == SearchType.All or search_type == SearchType.Org:
|
||||
files["org"] = get_org_files(config.org) if config.org else {}
|
||||
if search_type == SearchType.All or search_type == SearchType.Markdown:
|
||||
files["markdown"] = get_markdown_files(config.markdown) if config.markdown else {}
|
||||
if search_type == SearchType.All or search_type == SearchType.Plaintext:
|
||||
files["plaintext"] = get_plaintext_files(config.plaintext) if config.plaintext else {}
|
||||
if search_type == SearchType.All or search_type == SearchType.Pdf:
|
||||
files["pdf"] = get_pdf_files(config.pdf) if config.pdf else {}
|
||||
return files
|
||||
|
||||
|
||||
def get_plaintext_files(config: TextContentConfig) -> dict[str, str]:
|
||||
def is_plaintextfile(file: str):
|
||||
"Check if file is plaintext file"
|
||||
return file.endswith(("txt", "md", "markdown", "org", "mbox", "rst", "html", "htm", "xml"))
|
||||
|
||||
def extract_html_content(html_content: str):
|
||||
"Extract content from HTML"
|
||||
soup = BeautifulSoup(html_content, "html.parser")
|
||||
return soup.get_text(strip=True, separator="\n")
|
||||
|
||||
# Extract required fields from config
|
||||
input_files, input_filter = (
|
||||
config.input_files,
|
||||
config.input_filter,
|
||||
)
|
||||
|
||||
# Input Validation
|
||||
if is_none_or_empty(input_files) and is_none_or_empty(input_filter):
|
||||
logger.debug("At least one of input-files or input-file-filter is required to be specified")
|
||||
return {}
|
||||
|
||||
"Get all files to process"
|
||||
absolute_plaintext_files, filtered_plaintext_files = set(), set()
|
||||
if input_files:
|
||||
absolute_plaintext_files = {get_absolute_path(jsonl_file) for jsonl_file in input_files}
|
||||
if input_filter:
|
||||
filtered_plaintext_files = {
|
||||
filtered_file
|
||||
for jsonl_file_filter in input_filter
|
||||
for filtered_file in glob.glob(get_absolute_path(jsonl_file_filter), recursive=True)
|
||||
}
|
||||
|
||||
all_target_files = sorted(absolute_plaintext_files | filtered_plaintext_files)
|
||||
|
||||
files_with_no_plaintext_extensions = {
|
||||
target_files for target_files in all_target_files if not is_plaintextfile(target_files)
|
||||
}
|
||||
if any(files_with_no_plaintext_extensions):
|
||||
logger.warning(f"Skipping unsupported files from plaintext indexing: {files_with_no_plaintext_extensions}")
|
||||
all_target_files = list(set(all_target_files) - files_with_no_plaintext_extensions)
|
||||
|
||||
logger.debug(f"Processing files: {all_target_files}")
|
||||
|
||||
filename_to_content_map = {}
|
||||
for file in all_target_files:
|
||||
with open(file, "r") as f:
|
||||
try:
|
||||
plaintext_content = f.read()
|
||||
if file.endswith(("html", "htm", "xml")):
|
||||
plaintext_content = extract_html_content(plaintext_content)
|
||||
filename_to_content_map[file] = f.read()
|
||||
except Exception as e:
|
||||
logger.warning(f"Unable to read file: {file} as plaintext. Skipping file.")
|
||||
logger.warning(e, exc_info=True)
|
||||
|
||||
return filename_to_content_map
|
||||
|
||||
|
||||
def get_org_files(config: TextContentConfig):
|
||||
# Extract required fields from config
|
||||
org_files, org_file_filter = (
|
||||
config.input_files,
|
||||
config.input_filter,
|
||||
)
|
||||
|
||||
# Input Validation
|
||||
if is_none_or_empty(org_files) and is_none_or_empty(org_file_filter):
|
||||
logger.debug("At least one of org-files or org-file-filter is required to be specified")
|
||||
return {}
|
||||
|
||||
"Get Org files to process"
|
||||
absolute_org_files, filtered_org_files = set(), set()
|
||||
if org_files:
|
||||
absolute_org_files = {get_absolute_path(org_file) for org_file in org_files}
|
||||
if org_file_filter:
|
||||
filtered_org_files = {
|
||||
filtered_file
|
||||
for org_file_filter in org_file_filter
|
||||
for filtered_file in glob.glob(get_absolute_path(org_file_filter), recursive=True)
|
||||
}
|
||||
|
||||
all_org_files = sorted(absolute_org_files | filtered_org_files)
|
||||
|
||||
files_with_non_org_extensions = {org_file for org_file in all_org_files if not org_file.endswith(".org")}
|
||||
if any(files_with_non_org_extensions):
|
||||
logger.warning(f"There maybe non org-mode files in the input set: {files_with_non_org_extensions}")
|
||||
|
||||
logger.debug(f"Processing files: {all_org_files}")
|
||||
|
||||
filename_to_content_map = {}
|
||||
for file in all_org_files:
|
||||
with open(file, "r") as f:
|
||||
try:
|
||||
filename_to_content_map[file] = f.read()
|
||||
except Exception as e:
|
||||
logger.warning(f"Unable to read file: {file} as org. Skipping file.")
|
||||
logger.warning(e, exc_info=True)
|
||||
|
||||
return filename_to_content_map
|
||||
|
||||
|
||||
def get_markdown_files(config: TextContentConfig):
|
||||
# Extract required fields from config
|
||||
markdown_files, markdown_file_filter = (
|
||||
config.input_files,
|
||||
config.input_filter,
|
||||
)
|
||||
|
||||
# Input Validation
|
||||
if is_none_or_empty(markdown_files) and is_none_or_empty(markdown_file_filter):
|
||||
logger.debug("At least one of markdown-files or markdown-file-filter is required to be specified")
|
||||
return {}
|
||||
|
||||
"Get Markdown files to process"
|
||||
absolute_markdown_files, filtered_markdown_files = set(), set()
|
||||
if markdown_files:
|
||||
absolute_markdown_files = {get_absolute_path(markdown_file) for markdown_file in markdown_files}
|
||||
|
||||
if markdown_file_filter:
|
||||
filtered_markdown_files = {
|
||||
filtered_file
|
||||
for markdown_file_filter in markdown_file_filter
|
||||
for filtered_file in glob.glob(get_absolute_path(markdown_file_filter), recursive=True)
|
||||
}
|
||||
|
||||
all_markdown_files = sorted(absolute_markdown_files | filtered_markdown_files)
|
||||
|
||||
files_with_non_markdown_extensions = {
|
||||
md_file for md_file in all_markdown_files if not md_file.endswith(".md") and not md_file.endswith(".markdown")
|
||||
}
|
||||
|
||||
if any(files_with_non_markdown_extensions):
|
||||
logger.warning(
|
||||
f"[Warning] There maybe non markdown-mode files in the input set: {files_with_non_markdown_extensions}"
|
||||
)
|
||||
|
||||
logger.debug(f"Processing files: {all_markdown_files}")
|
||||
|
||||
filename_to_content_map = {}
|
||||
for file in all_markdown_files:
|
||||
with open(file, "r") as f:
|
||||
try:
|
||||
filename_to_content_map[file] = f.read()
|
||||
except Exception as e:
|
||||
logger.warning(f"Unable to read file: {file} as markdown. Skipping file.")
|
||||
logger.warning(e, exc_info=True)
|
||||
|
||||
return filename_to_content_map
|
||||
|
||||
|
||||
def get_pdf_files(config: TextContentConfig):
|
||||
# Extract required fields from config
|
||||
pdf_files, pdf_file_filter = (
|
||||
config.input_files,
|
||||
config.input_filter,
|
||||
)
|
||||
|
||||
# Input Validation
|
||||
if is_none_or_empty(pdf_files) and is_none_or_empty(pdf_file_filter):
|
||||
logger.debug("At least one of pdf-files or pdf-file-filter is required to be specified")
|
||||
return {}
|
||||
|
||||
"Get PDF files to process"
|
||||
absolute_pdf_files, filtered_pdf_files = set(), set()
|
||||
if pdf_files:
|
||||
absolute_pdf_files = {get_absolute_path(pdf_file) for pdf_file in pdf_files}
|
||||
if pdf_file_filter:
|
||||
filtered_pdf_files = {
|
||||
filtered_file
|
||||
for pdf_file_filter in pdf_file_filter
|
||||
for filtered_file in glob.glob(get_absolute_path(pdf_file_filter), recursive=True)
|
||||
}
|
||||
|
||||
all_pdf_files = sorted(absolute_pdf_files | filtered_pdf_files)
|
||||
|
||||
files_with_non_pdf_extensions = {pdf_file for pdf_file in all_pdf_files if not pdf_file.endswith(".pdf")}
|
||||
|
||||
if any(files_with_non_pdf_extensions):
|
||||
logger.warning(f"[Warning] There maybe non pdf-mode files in the input set: {files_with_non_pdf_extensions}")
|
||||
|
||||
logger.debug(f"Processing files: {all_pdf_files}")
|
||||
|
||||
filename_to_content_map = {}
|
||||
for file in all_pdf_files:
|
||||
with open(file, "rb") as f:
|
||||
try:
|
||||
filename_to_content_map[file] = f.read()
|
||||
except Exception as e:
|
||||
logger.warning(f"Unable to read file: {file} as PDF. Skipping file.")
|
||||
logger.warning(e, exc_info=True)
|
||||
|
||||
return filename_to_content_map
|
|
@ -9,6 +9,7 @@ import pytest
|
|||
from khoj.main import app
|
||||
from khoj.configure import configure_processor, configure_routes, configure_search_types
|
||||
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
|
||||
from khoj.processor.plaintext.plaintext_to_jsonl import PlaintextToJsonl
|
||||
from khoj.search_type import image_search, text_search
|
||||
from khoj.utils.config import SearchModels
|
||||
from khoj.utils.helpers import resolve_absolute_path
|
||||
|
@ -97,7 +98,12 @@ def content_config(tmp_path_factory, search_models: SearchModels, search_config:
|
|||
|
||||
filters = [DateFilter(), WordFilter(), FileFilter()]
|
||||
text_search.setup(
|
||||
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False, filters=filters
|
||||
OrgToJsonl,
|
||||
get_sample_data("org"),
|
||||
content_config.org,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=False,
|
||||
filters=filters,
|
||||
)
|
||||
|
||||
content_config.plugins = {
|
||||
|
@ -109,6 +115,20 @@ def content_config(tmp_path_factory, search_models: SearchModels, search_config:
|
|||
)
|
||||
}
|
||||
|
||||
if os.getenv("GITHUB_PAT_TOKEN"):
|
||||
content_config.github = GithubContentConfig(
|
||||
pat_token=os.getenv("GITHUB_PAT_TOKEN", ""),
|
||||
repos=[
|
||||
GithubRepoConfig(
|
||||
owner="khoj-ai",
|
||||
name="lantern",
|
||||
branch="master",
|
||||
)
|
||||
],
|
||||
compressed_jsonl=content_dir.joinpath("github.jsonl.gz"),
|
||||
embeddings_file=content_dir.joinpath("github_embeddings.pt"),
|
||||
)
|
||||
|
||||
content_config.plaintext = TextContentConfig(
|
||||
input_files=None,
|
||||
input_filter=["tests/data/plaintext/*.txt", "tests/data/plaintext/*.md", "tests/data/plaintext/*.html"],
|
||||
|
@ -132,6 +152,7 @@ def content_config(tmp_path_factory, search_models: SearchModels, search_config:
|
|||
filters = [DateFilter(), WordFilter(), FileFilter()]
|
||||
text_search.setup(
|
||||
JsonlToJsonl,
|
||||
None,
|
||||
content_config.plugins["plugin1"],
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=False,
|
||||
|
@ -203,6 +224,7 @@ def chat_client(md_content_config: ContentConfig, search_config: SearchConfig, p
|
|||
state.search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||
state.content_index.markdown = text_search.setup(
|
||||
MarkdownToJsonl,
|
||||
get_sample_data("markdown"),
|
||||
md_content_config.markdown,
|
||||
state.search_models.text_search.bi_encoder,
|
||||
regenerate=False,
|
||||
|
@ -226,11 +248,22 @@ def client(content_config: ContentConfig, search_config: SearchConfig, processor
|
|||
state.search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||
state.search_models.image_search = image_search.initialize_model(search_config.image)
|
||||
state.content_index.org = text_search.setup(
|
||||
OrgToJsonl, content_config.org, state.search_models.text_search.bi_encoder, regenerate=False
|
||||
OrgToJsonl,
|
||||
get_sample_data("org"),
|
||||
content_config.org,
|
||||
state.search_models.text_search.bi_encoder,
|
||||
regenerate=False,
|
||||
)
|
||||
state.content_index.image = image_search.setup(
|
||||
content_config.image, state.search_models.image_search, regenerate=False
|
||||
)
|
||||
state.content_index.plaintext = text_search.setup(
|
||||
PlaintextToJsonl,
|
||||
get_sample_data("plaintext"),
|
||||
content_config.plaintext,
|
||||
state.search_models.text_search.bi_encoder,
|
||||
regenerate=False,
|
||||
)
|
||||
|
||||
state.processor_config = configure_processor(processor_config)
|
||||
|
||||
|
@ -250,8 +283,21 @@ def client_offline_chat(
|
|||
# Index Markdown Content for Search
|
||||
filters = [DateFilter(), WordFilter(), FileFilter()]
|
||||
state.search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||
state.search_models.image_search = image_search.initialize_model(search_config.image)
|
||||
state.content_index.org = text_search.setup(
|
||||
OrgToJsonl,
|
||||
get_sample_data("org"),
|
||||
content_config.org,
|
||||
state.search_models.text_search.bi_encoder,
|
||||
regenerate=False,
|
||||
)
|
||||
state.content_index.image = image_search.setup(
|
||||
content_config.image, state.search_models.image_search, regenerate=False
|
||||
)
|
||||
|
||||
state.content_index.markdown = text_search.setup(
|
||||
MarkdownToJsonl,
|
||||
get_sample_data("markdown"),
|
||||
md_content_config.markdown,
|
||||
state.search_models.text_search.bi_encoder,
|
||||
regenerate=False,
|
||||
|
@ -284,3 +330,69 @@ def org_config_with_only_new_file(content_config: ContentConfig, new_org_file: P
|
|||
new_org_config.input_files = [f"{new_org_file}"]
|
||||
new_org_config.input_filter = None
|
||||
return new_org_config
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def sample_org_data():
|
||||
return get_sample_data("org")
|
||||
|
||||
|
||||
def get_sample_data(type):
|
||||
sample_data = {
|
||||
"org": {
|
||||
"readme.org": """
|
||||
* Khoj
|
||||
/Allow natural language search on user content like notes, images using transformer based models/
|
||||
|
||||
All data is processed locally. User can interface with khoj app via [[./interface/emacs/khoj.el][Emacs]], API or Commandline
|
||||
|
||||
** Dependencies
|
||||
- Python3
|
||||
- [[https://docs.conda.io/en/latest/miniconda.html#latest-miniconda-installer-links][Miniconda]]
|
||||
|
||||
** Install
|
||||
#+begin_src shell
|
||||
git clone https://github.com/khoj-ai/khoj && cd khoj
|
||||
conda env create -f environment.yml
|
||||
conda activate khoj
|
||||
#+end_src"""
|
||||
},
|
||||
"markdown": {
|
||||
"readme.markdown": """
|
||||
# Khoj
|
||||
Allow natural language search on user content like notes, images using transformer based models
|
||||
|
||||
All data is processed locally. User can interface with khoj app via [Emacs](./interface/emacs/khoj.el), API or Commandline
|
||||
|
||||
## Dependencies
|
||||
- Python3
|
||||
- [Miniconda](https://docs.conda.io/en/latest/miniconda.html#latest-miniconda-installer-links)
|
||||
|
||||
## Install
|
||||
```shell
|
||||
git clone
|
||||
conda env create -f environment.yml
|
||||
conda activate khoj
|
||||
```
|
||||
"""
|
||||
},
|
||||
"plaintext": {
|
||||
"readme.txt": """
|
||||
Khoj
|
||||
Allow natural language search on user content like notes, images using transformer based models
|
||||
|
||||
All data is processed locally. User can interface with khoj app via Emacs, API or Commandline
|
||||
|
||||
Dependencies
|
||||
- Python3
|
||||
- Miniconda
|
||||
|
||||
Install
|
||||
git clone
|
||||
conda env create -f environment.yml
|
||||
conda activate khoj
|
||||
"""
|
||||
},
|
||||
}
|
||||
|
||||
return sample_data[type]
|
||||
|
|
|
@ -11,7 +11,6 @@ from fastapi.testclient import TestClient
|
|||
from khoj.main import app
|
||||
from khoj.configure import configure_routes, configure_search_types
|
||||
from khoj.utils import state
|
||||
from khoj.utils.config import SearchModels
|
||||
from khoj.utils.state import search_models, content_index, config
|
||||
from khoj.search_type import text_search, image_search
|
||||
from khoj.utils.rawconfig import ContentConfig, SearchConfig
|
||||
|
@ -51,28 +50,6 @@ def test_update_with_invalid_content_type(client):
|
|||
assert response.status_code == 422
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_update_with_valid_content_type(client):
|
||||
for content_type in ["all", "org", "markdown", "image", "pdf", "notion", "plugin1"]:
|
||||
# Act
|
||||
response = client.get(f"/api/update?t={content_type}")
|
||||
# Assert
|
||||
assert response.status_code == 200, f"Returned status: {response.status_code} for content type: {content_type}"
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_update_with_github_fails_without_pat(client):
|
||||
# Act
|
||||
response = client.get(f"/api/update?t=github")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 500, f"Returned status: {response.status_code} for content type: github"
|
||||
assert (
|
||||
response.json()["detail"]
|
||||
== "🚨 Failed to update server via API: Github PAT token is not set. Skipping github content"
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_regenerate_with_invalid_content_type(client):
|
||||
# Act
|
||||
|
@ -82,11 +59,29 @@ def test_regenerate_with_invalid_content_type(client):
|
|||
assert response.status_code == 422
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_index_batch(client):
|
||||
# Arrange
|
||||
request_body = get_sample_files_data()
|
||||
headers = {"x-api-key": "secret"}
|
||||
|
||||
# Act
|
||||
response = client.post("/indexer/batch", json=request_body, headers=headers)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_regenerate_with_valid_content_type(client):
|
||||
for content_type in ["all", "org", "markdown", "image", "pdf", "notion", "plugin1"]:
|
||||
# Arrange
|
||||
request_body = get_sample_files_data()
|
||||
|
||||
headers = {"x-api-key": "secret"}
|
||||
|
||||
# Act
|
||||
response = client.get(f"/api/update?force=true&t={content_type}")
|
||||
response = client.post(f"/indexer/batch?search_type={content_type}", json=request_body, headers=headers)
|
||||
# Assert
|
||||
assert response.status_code == 200, f"Returned status: {response.status_code} for content type: {content_type}"
|
||||
|
||||
|
@ -96,12 +91,15 @@ def test_regenerate_with_github_fails_without_pat(client):
|
|||
# Act
|
||||
response = client.get(f"/api/update?force=true&t=github")
|
||||
|
||||
# Arrange
|
||||
request_body = get_sample_files_data()
|
||||
|
||||
headers = {"x-api-key": "secret"}
|
||||
|
||||
# Act
|
||||
response = client.post(f"/indexer/batch?search_type=github", json=request_body, headers=headers)
|
||||
# Assert
|
||||
assert response.status_code == 500, f"Returned status: {response.status_code} for content type: github"
|
||||
assert (
|
||||
response.json()["detail"]
|
||||
== "🚨 Failed to update server via API: Github PAT token is not set. Skipping github content"
|
||||
)
|
||||
assert response.status_code == 200, f"Returned status: {response.status_code} for content type: github"
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
|
@ -111,7 +109,7 @@ def test_get_configured_types_via_api(client):
|
|||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
assert response.json() == ["all", "org", "image", "plugin1"]
|
||||
assert response.json() == ["all", "org", "image", "plaintext", "plugin1"]
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
|
@ -194,11 +192,11 @@ def test_image_search(client, content_config: ContentConfig, search_config: Sear
|
|||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_notes_search(client, content_config: ContentConfig, search_config: SearchConfig):
|
||||
def test_notes_search(client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data):
|
||||
# Arrange
|
||||
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||
content_index.org = text_search.setup(
|
||||
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False
|
||||
OrgToJsonl, sample_org_data, content_config.org, search_models.text_search.bi_encoder, regenerate=False
|
||||
)
|
||||
user_query = quote("How to git install application?")
|
||||
|
||||
|
@ -213,12 +211,19 @@ def test_notes_search(client, content_config: ContentConfig, search_config: Sear
|
|||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_notes_search_with_only_filters(client, content_config: ContentConfig, search_config: SearchConfig):
|
||||
def test_notes_search_with_only_filters(
|
||||
client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data
|
||||
):
|
||||
# Arrange
|
||||
filters = [WordFilter(), FileFilter()]
|
||||
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||
content_index.org = text_search.setup(
|
||||
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False, filters=filters
|
||||
OrgToJsonl,
|
||||
sample_org_data,
|
||||
content_config.org,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=False,
|
||||
filters=filters,
|
||||
)
|
||||
user_query = quote('+"Emacs" file:"*.org"')
|
||||
|
||||
|
@ -233,12 +238,14 @@ def test_notes_search_with_only_filters(client, content_config: ContentConfig, s
|
|||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_notes_search_with_include_filter(client, content_config: ContentConfig, search_config: SearchConfig):
|
||||
def test_notes_search_with_include_filter(
|
||||
client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data
|
||||
):
|
||||
# Arrange
|
||||
filters = [WordFilter()]
|
||||
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||
content_index.org = text_search.setup(
|
||||
OrgToJsonl, content_config.org, search_models.text_search, regenerate=False, filters=filters
|
||||
OrgToJsonl, sample_org_data, content_config.org, search_models.text_search, regenerate=False, filters=filters
|
||||
)
|
||||
user_query = quote('How to git install application? +"Emacs"')
|
||||
|
||||
|
@ -253,12 +260,19 @@ def test_notes_search_with_include_filter(client, content_config: ContentConfig,
|
|||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_notes_search_with_exclude_filter(client, content_config: ContentConfig, search_config: SearchConfig):
|
||||
def test_notes_search_with_exclude_filter(
|
||||
client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data
|
||||
):
|
||||
# Arrange
|
||||
filters = [WordFilter()]
|
||||
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||
content_index.org = text_search.setup(
|
||||
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False, filters=filters
|
||||
OrgToJsonl,
|
||||
sample_org_data,
|
||||
content_config.org,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=False,
|
||||
filters=filters,
|
||||
)
|
||||
user_query = quote('How to git install application? -"clone"')
|
||||
|
||||
|
@ -270,3 +284,28 @@ def test_notes_search_with_exclude_filter(client, content_config: ContentConfig,
|
|||
# assert actual_data does not contains word "clone"
|
||||
search_result = response.json()[0]["entry"]
|
||||
assert "clone" not in search_result
|
||||
|
||||
|
||||
def get_sample_files_data():
|
||||
return {
|
||||
"org": {
|
||||
"path/to/filename.org": "* practicing piano",
|
||||
"path/to/filename1.org": "** top 3 reasons why I moved to SF",
|
||||
"path/to/filename2.org": "* how to build a search engine",
|
||||
},
|
||||
"pdf": {
|
||||
"path/to/filename.pdf": "Moore's law does not apply to consumer hardware",
|
||||
"path/to/filename1.pdf": "The sun is a ball of helium",
|
||||
"path/to/filename2.pdf": "Effect of sunshine on baseline human happiness",
|
||||
},
|
||||
"plaintext": {
|
||||
"path/to/filename.txt": "data,column,value",
|
||||
"path/to/filename1.txt": "<html>my first web page</html>",
|
||||
"path/to/filename2.txt": "2021-02-02 Journal Entry",
|
||||
},
|
||||
"markdown": {
|
||||
"path/to/filename.md": "# Notes from client call",
|
||||
"path/to/filename1.md": "## Studying anthropological records from the Fatimid caliphate",
|
||||
"path/to/filename2.md": "**Understanding science through the lens of art**",
|
||||
},
|
||||
}
|
||||
|
|
|
@ -1,9 +1,12 @@
|
|||
# Standard Packages
|
||||
import json
|
||||
from pathlib import Path
|
||||
import os
|
||||
|
||||
# Internal Packages
|
||||
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
|
||||
from khoj.utils.fs_syncer import get_markdown_files
|
||||
from khoj.utils.rawconfig import TextContentConfig
|
||||
|
||||
|
||||
def test_markdown_file_with_no_headings_to_jsonl(tmp_path):
|
||||
|
@ -13,12 +16,14 @@ def test_markdown_file_with_no_headings_to_jsonl(tmp_path):
|
|||
- Bullet point 1
|
||||
- Bullet point 2
|
||||
"""
|
||||
markdownfile = create_file(tmp_path, entry)
|
||||
expected_heading = "# " + markdownfile.stem
|
||||
data = {
|
||||
f"{tmp_path}": entry,
|
||||
}
|
||||
expected_heading = f"# {tmp_path.stem}"
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Markdown files
|
||||
entry_nodes, file_to_entries = MarkdownToJsonl.extract_markdown_entries(markdown_files=[markdownfile])
|
||||
entry_nodes, file_to_entries = MarkdownToJsonl.extract_markdown_entries(markdown_files=data)
|
||||
|
||||
# Process Each Entry from All Notes Files
|
||||
jsonl_string = MarkdownToJsonl.convert_markdown_maps_to_jsonl(
|
||||
|
@ -41,11 +46,13 @@ def test_single_markdown_entry_to_jsonl(tmp_path):
|
|||
\t\r
|
||||
Body Line 1
|
||||
"""
|
||||
markdownfile = create_file(tmp_path, entry)
|
||||
data = {
|
||||
f"{tmp_path}": entry,
|
||||
}
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Markdown files
|
||||
entries, entry_to_file_map = MarkdownToJsonl.extract_markdown_entries(markdown_files=[markdownfile])
|
||||
entries, entry_to_file_map = MarkdownToJsonl.extract_markdown_entries(markdown_files=data)
|
||||
|
||||
# Process Each Entry from All Notes Files
|
||||
jsonl_string = MarkdownToJsonl.convert_markdown_maps_to_jsonl(
|
||||
|
@ -68,11 +75,13 @@ def test_multiple_markdown_entries_to_jsonl(tmp_path):
|
|||
\t\r
|
||||
Heading 2 Body Line 2
|
||||
"""
|
||||
markdownfile = create_file(tmp_path, entry)
|
||||
data = {
|
||||
f"{tmp_path}": entry,
|
||||
}
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Markdown files
|
||||
entry_strings, entry_to_file_map = MarkdownToJsonl.extract_markdown_entries(markdown_files=[markdownfile])
|
||||
entry_strings, entry_to_file_map = MarkdownToJsonl.extract_markdown_entries(markdown_files=data)
|
||||
entries = MarkdownToJsonl.convert_markdown_entries_to_maps(entry_strings, entry_to_file_map)
|
||||
|
||||
# Process Each Entry from All Notes Files
|
||||
|
@ -82,7 +91,7 @@ def test_multiple_markdown_entries_to_jsonl(tmp_path):
|
|||
# Assert
|
||||
assert len(jsonl_data) == 2
|
||||
# Ensure entry compiled strings include the markdown files they originate from
|
||||
assert all([markdownfile.stem in entry.compiled for entry in entries])
|
||||
assert all([tmp_path.stem in entry.compiled for entry in entries])
|
||||
|
||||
|
||||
def test_get_markdown_files(tmp_path):
|
||||
|
@ -99,18 +108,27 @@ def test_get_markdown_files(tmp_path):
|
|||
create_file(tmp_path, filename="not-included-markdown.md")
|
||||
create_file(tmp_path, filename="not-included-text.txt")
|
||||
|
||||
expected_files = sorted(map(str, [group1_file1, group1_file2, group2_file1, group2_file2, file1]))
|
||||
expected_files = set(
|
||||
[os.path.join(tmp_path, file.name) for file in [group1_file1, group1_file2, group2_file1, group2_file2, file1]]
|
||||
)
|
||||
|
||||
# Setup input-files, input-filters
|
||||
input_files = [tmp_path / "notes.md"]
|
||||
input_filter = [tmp_path / "group1*.md", tmp_path / "group2*.markdown"]
|
||||
|
||||
markdown_config = TextContentConfig(
|
||||
input_files=input_files,
|
||||
input_filter=[str(filter) for filter in input_filter],
|
||||
compressed_jsonl=tmp_path / "test.jsonl",
|
||||
embeddings_file=tmp_path / "test_embeddings.jsonl",
|
||||
)
|
||||
|
||||
# Act
|
||||
extracted_org_files = MarkdownToJsonl.get_markdown_files(input_files, input_filter)
|
||||
extracted_org_files = get_markdown_files(markdown_config)
|
||||
|
||||
# Assert
|
||||
assert len(extracted_org_files) == 5
|
||||
assert extracted_org_files == expected_files
|
||||
assert set(extracted_org_files.keys()) == expected_files
|
||||
|
||||
|
||||
def test_extract_entries_with_different_level_headings(tmp_path):
|
||||
|
@ -120,11 +138,13 @@ def test_extract_entries_with_different_level_headings(tmp_path):
|
|||
# Heading 1
|
||||
## Heading 2
|
||||
"""
|
||||
markdownfile = create_file(tmp_path, entry)
|
||||
data = {
|
||||
f"{tmp_path}": entry,
|
||||
}
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Markdown files
|
||||
entries, _ = MarkdownToJsonl.extract_markdown_entries(markdown_files=[markdownfile])
|
||||
entries, _ = MarkdownToJsonl.extract_markdown_entries(markdown_files=data)
|
||||
|
||||
# Assert
|
||||
assert len(entries) == 2
|
||||
|
|
|
@ -1,11 +1,14 @@
|
|||
# Standard Packages
|
||||
import json
|
||||
import os
|
||||
|
||||
# Internal Packages
|
||||
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
||||
from khoj.processor.text_to_jsonl import TextToJsonl
|
||||
from khoj.utils.helpers import is_none_or_empty
|
||||
from khoj.utils.rawconfig import Entry
|
||||
from khoj.utils.fs_syncer import get_org_files
|
||||
from khoj.utils.rawconfig import TextContentConfig
|
||||
|
||||
|
||||
def test_configure_heading_entry_to_jsonl(tmp_path):
|
||||
|
@ -18,14 +21,17 @@ def test_configure_heading_entry_to_jsonl(tmp_path):
|
|||
:END:
|
||||
\t \r
|
||||
"""
|
||||
orgfile = create_file(tmp_path, entry)
|
||||
|
||||
data = {
|
||||
f"{tmp_path}": entry,
|
||||
}
|
||||
|
||||
for index_heading_entries in [True, False]:
|
||||
# Act
|
||||
# Extract entries into jsonl from specified Org files
|
||||
jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(
|
||||
OrgToJsonl.convert_org_nodes_to_entries(
|
||||
*OrgToJsonl.extract_org_entries(org_files=[orgfile]), index_heading_entries=index_heading_entries
|
||||
*OrgToJsonl.extract_org_entries(org_files=data), index_heading_entries=index_heading_entries
|
||||
)
|
||||
)
|
||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||
|
@ -46,12 +52,14 @@ def test_entry_split_when_exceeds_max_words(tmp_path):
|
|||
\t\r
|
||||
Body Line
|
||||
"""
|
||||
orgfile = create_file(tmp_path, entry)
|
||||
expected_heading = f"* {orgfile.stem}\n** Heading"
|
||||
data = {
|
||||
f"{tmp_path}": entry,
|
||||
}
|
||||
expected_heading = f"* {tmp_path.stem}\n** Heading"
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Org files
|
||||
entries, entry_to_file_map = OrgToJsonl.extract_org_entries(org_files=[orgfile])
|
||||
entries, entry_to_file_map = OrgToJsonl.extract_org_entries(org_files=data)
|
||||
|
||||
# Split each entry from specified Org files by max words
|
||||
jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(
|
||||
|
@ -95,11 +103,13 @@ def test_entry_with_body_to_jsonl(tmp_path):
|
|||
\t\r
|
||||
Body Line 1
|
||||
"""
|
||||
orgfile = create_file(tmp_path, entry)
|
||||
data = {
|
||||
f"{tmp_path}": entry,
|
||||
}
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Org files
|
||||
entries, entry_to_file_map = OrgToJsonl.extract_org_entries(org_files=[orgfile])
|
||||
entries, entry_to_file_map = OrgToJsonl.extract_org_entries(org_files=data)
|
||||
|
||||
# Process Each Entry from All Notes Files
|
||||
jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(
|
||||
|
@ -120,11 +130,13 @@ Intro text
|
|||
* Entry Heading
|
||||
entry body
|
||||
"""
|
||||
orgfile = create_file(tmp_path, entry)
|
||||
data = {
|
||||
f"{tmp_path}": entry,
|
||||
}
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Org files
|
||||
entry_nodes, file_to_entries = OrgToJsonl.extract_org_entries(org_files=[orgfile])
|
||||
entry_nodes, file_to_entries = OrgToJsonl.extract_org_entries(org_files=data)
|
||||
|
||||
# Process Each Entry from All Notes Files
|
||||
entries = OrgToJsonl.convert_org_nodes_to_entries(entry_nodes, file_to_entries)
|
||||
|
@ -142,11 +154,13 @@ def test_file_with_no_headings_to_jsonl(tmp_path):
|
|||
- Bullet point 1
|
||||
- Bullet point 2
|
||||
"""
|
||||
orgfile = create_file(tmp_path, entry)
|
||||
data = {
|
||||
f"{tmp_path}": entry,
|
||||
}
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Org files
|
||||
entry_nodes, file_to_entries = OrgToJsonl.extract_org_entries(org_files=[orgfile])
|
||||
entry_nodes, file_to_entries = OrgToJsonl.extract_org_entries(org_files=data)
|
||||
|
||||
# Process Each Entry from All Notes Files
|
||||
entries = OrgToJsonl.convert_org_nodes_to_entries(entry_nodes, file_to_entries)
|
||||
|
@ -171,18 +185,30 @@ def test_get_org_files(tmp_path):
|
|||
create_file(tmp_path, filename="orgfile2.org")
|
||||
create_file(tmp_path, filename="text1.txt")
|
||||
|
||||
expected_files = sorted(map(str, [group1_file1, group1_file2, group2_file1, group2_file2, orgfile1]))
|
||||
expected_files = set(
|
||||
[
|
||||
os.path.join(tmp_path, file.name)
|
||||
for file in [group1_file1, group1_file2, group2_file1, group2_file2, orgfile1]
|
||||
]
|
||||
)
|
||||
|
||||
# Setup input-files, input-filters
|
||||
input_files = [tmp_path / "orgfile1.org"]
|
||||
input_filter = [tmp_path / "group1*.org", tmp_path / "group2*.org"]
|
||||
|
||||
org_config = TextContentConfig(
|
||||
input_files=input_files,
|
||||
input_filter=[str(filter) for filter in input_filter],
|
||||
compressed_jsonl=tmp_path / "test.jsonl",
|
||||
embeddings_file=tmp_path / "test_embeddings.jsonl",
|
||||
)
|
||||
|
||||
# Act
|
||||
extracted_org_files = OrgToJsonl.get_org_files(input_files, input_filter)
|
||||
extracted_org_files = get_org_files(org_config)
|
||||
|
||||
# Assert
|
||||
assert len(extracted_org_files) == 5
|
||||
assert extracted_org_files == expected_files
|
||||
assert set(extracted_org_files.keys()) == expected_files
|
||||
|
||||
|
||||
def test_extract_entries_with_different_level_headings(tmp_path):
|
||||
|
@ -192,11 +218,13 @@ def test_extract_entries_with_different_level_headings(tmp_path):
|
|||
* Heading 1
|
||||
** Heading 2
|
||||
"""
|
||||
orgfile = create_file(tmp_path, entry)
|
||||
data = {
|
||||
f"{tmp_path}": entry,
|
||||
}
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Org files
|
||||
entries, _ = OrgToJsonl.extract_org_entries(org_files=[orgfile])
|
||||
entries, _ = OrgToJsonl.extract_org_entries(org_files=data)
|
||||
|
||||
# Assert
|
||||
assert len(entries) == 2
|
||||
|
|
|
@ -44,7 +44,7 @@ Body Line 1"""
|
|||
assert len(entries) == 1
|
||||
assert entries[0].heading == "Heading"
|
||||
assert entries[0].tags == list()
|
||||
assert entries[0].body == "Body Line 1"
|
||||
assert entries[0].body == "Body Line 1\n\n"
|
||||
assert entries[0].priority == ""
|
||||
assert entries[0].Property("ID") == ""
|
||||
assert entries[0].closed == ""
|
||||
|
@ -78,7 +78,7 @@ Body Line 2"""
|
|||
assert entries[0].heading == "Heading"
|
||||
assert entries[0].todo == "DONE"
|
||||
assert entries[0].tags == ["Tag1", "TAG2", "tag3"]
|
||||
assert entries[0].body == "- Clocked Log 1\nBody Line 1\nBody Line 2"
|
||||
assert entries[0].body == "- Clocked Log 1\n\nBody Line 1\n\nBody Line 2\n\n"
|
||||
assert entries[0].priority == "A"
|
||||
assert entries[0].Property("ID") == "id:123-456-789-4234-1231"
|
||||
assert entries[0].closed == datetime.date(1984, 4, 1)
|
||||
|
@ -205,7 +205,7 @@ Body 2
|
|||
assert entry.heading == f"Heading{index+1}"
|
||||
assert entry.todo == "FAILED" if index == 0 else "CANCELLED"
|
||||
assert entry.tags == [f"tag{index+1}"]
|
||||
assert entry.body == f"- Clocked Log {index+1}\nBody {index+1}\n\n"
|
||||
assert entry.body == f"- Clocked Log {index+1}\n\nBody {index+1}\n\n"
|
||||
assert entry.priority == "A"
|
||||
assert entry.Property("ID") == f"id:123-456-789-4234-000{index+1}"
|
||||
assert entry.closed == datetime.date(1984, 4, index + 1)
|
||||
|
@ -305,7 +305,7 @@ entry body
|
|||
assert entries[0].heading == "Title"
|
||||
assert entries[0].body == "intro body\n"
|
||||
assert entries[1].heading == "Entry Heading"
|
||||
assert entries[1].body == "entry body\n"
|
||||
assert entries[1].body == "entry body\n\n"
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
|
@ -327,7 +327,7 @@ entry body
|
|||
assert entries[0].heading == "Title1 Title2"
|
||||
assert entries[0].body == "intro body\n"
|
||||
assert entries[1].heading == "Entry Heading"
|
||||
assert entries[1].body == "entry body\n"
|
||||
assert entries[1].body == "entry body\n\n"
|
||||
|
||||
|
||||
# Helper Functions
|
||||
|
|
|
@ -1,15 +1,24 @@
|
|||
# Standard Packages
|
||||
import json
|
||||
import os
|
||||
|
||||
# Internal Packages
|
||||
from khoj.processor.pdf.pdf_to_jsonl import PdfToJsonl
|
||||
|
||||
from khoj.utils.fs_syncer import get_pdf_files
|
||||
from khoj.utils.rawconfig import TextContentConfig
|
||||
|
||||
|
||||
def test_single_page_pdf_to_jsonl():
|
||||
"Convert single page PDF file to jsonl."
|
||||
# Act
|
||||
# Extract Entries from specified Pdf files
|
||||
entries, entry_to_file_map = PdfToJsonl.extract_pdf_entries(pdf_files=["tests/data/pdf/singlepage.pdf"])
|
||||
# Read singlepage.pdf into memory as bytes
|
||||
with open("tests/data/pdf/singlepage.pdf", "rb") as f:
|
||||
pdf_bytes = f.read()
|
||||
|
||||
data = {"tests/data/pdf/singlepage.pdf": pdf_bytes}
|
||||
entries, entry_to_file_map = PdfToJsonl.extract_pdf_entries(pdf_files=data)
|
||||
|
||||
# Process Each Entry from All Pdf Files
|
||||
jsonl_string = PdfToJsonl.convert_pdf_maps_to_jsonl(
|
||||
|
@ -25,7 +34,11 @@ def test_multi_page_pdf_to_jsonl():
|
|||
"Convert multiple pages from single PDF file to jsonl."
|
||||
# Act
|
||||
# Extract Entries from specified Pdf files
|
||||
entries, entry_to_file_map = PdfToJsonl.extract_pdf_entries(pdf_files=["tests/data/pdf/multipage.pdf"])
|
||||
with open("tests/data/pdf/multipage.pdf", "rb") as f:
|
||||
pdf_bytes = f.read()
|
||||
|
||||
data = {"tests/data/pdf/multipage.pdf": pdf_bytes}
|
||||
entries, entry_to_file_map = PdfToJsonl.extract_pdf_entries(pdf_files=data)
|
||||
|
||||
# Process Each Entry from All Pdf Files
|
||||
jsonl_string = PdfToJsonl.convert_pdf_maps_to_jsonl(
|
||||
|
@ -51,18 +64,27 @@ def test_get_pdf_files(tmp_path):
|
|||
create_file(tmp_path, filename="not-included-document.pdf")
|
||||
create_file(tmp_path, filename="not-included-text.txt")
|
||||
|
||||
expected_files = sorted(map(str, [group1_file1, group1_file2, group2_file1, group2_file2, file1]))
|
||||
expected_files = set(
|
||||
[os.path.join(tmp_path, file.name) for file in [group1_file1, group1_file2, group2_file1, group2_file2, file1]]
|
||||
)
|
||||
|
||||
# Setup input-files, input-filters
|
||||
input_files = [tmp_path / "document.pdf"]
|
||||
input_filter = [tmp_path / "group1*.pdf", tmp_path / "group2*.pdf"]
|
||||
|
||||
pdf_config = TextContentConfig(
|
||||
input_files=input_files,
|
||||
input_filter=[str(path) for path in input_filter],
|
||||
compressed_jsonl=tmp_path / "test.jsonl",
|
||||
embeddings_file=tmp_path / "test_embeddings.jsonl",
|
||||
)
|
||||
|
||||
# Act
|
||||
extracted_pdf_files = PdfToJsonl.get_pdf_files(input_files, input_filter)
|
||||
extracted_pdf_files = get_pdf_files(pdf_config)
|
||||
|
||||
# Assert
|
||||
assert len(extracted_pdf_files) == 5
|
||||
assert extracted_pdf_files == expected_files
|
||||
assert set(extracted_pdf_files.keys()) == expected_files
|
||||
|
||||
|
||||
# Helper Functions
|
||||
|
|
|
@ -1,8 +1,11 @@
|
|||
# Standard Packages
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# Internal Packages
|
||||
from khoj.utils.fs_syncer import get_plaintext_files
|
||||
from khoj.utils.rawconfig import TextContentConfig
|
||||
from khoj.processor.plaintext.plaintext_to_jsonl import PlaintextToJsonl
|
||||
|
||||
|
||||
|
@ -18,9 +21,12 @@ def test_plaintext_file(tmp_path):
|
|||
|
||||
# Act
|
||||
# Extract Entries from specified plaintext files
|
||||
file_to_entries = PlaintextToJsonl.extract_plaintext_entries(plaintext_files=[str(plaintextfile)])
|
||||
|
||||
maps = PlaintextToJsonl.convert_plaintext_entries_to_maps(file_to_entries)
|
||||
data = {
|
||||
f"{plaintextfile}": entry,
|
||||
}
|
||||
|
||||
maps = PlaintextToJsonl.convert_plaintext_entries_to_maps(entry_to_file_map=data)
|
||||
|
||||
# Convert each entry.file to absolute path to make them JSON serializable
|
||||
for map in maps:
|
||||
|
@ -59,33 +65,40 @@ def test_get_plaintext_files(tmp_path):
|
|||
create_file(tmp_path, filename="not-included-markdown.md")
|
||||
create_file(tmp_path, filename="not-included-text.txt")
|
||||
|
||||
expected_files = sorted(
|
||||
map(str, [group1_file1, group1_file2, group2_file1, group2_file2, file1, group2_file3, group2_file4])
|
||||
expected_files = set(
|
||||
[
|
||||
os.path.join(tmp_path, file.name)
|
||||
for file in [group1_file1, group1_file2, group2_file1, group2_file2, group2_file3, group2_file4, file1]
|
||||
]
|
||||
)
|
||||
|
||||
# Setup input-files, input-filters
|
||||
input_files = [tmp_path / "notes.txt"]
|
||||
input_filter = [tmp_path / "group1*.md", tmp_path / "group2*.*"]
|
||||
|
||||
plaintext_config = TextContentConfig(
|
||||
input_files=input_files,
|
||||
input_filter=[str(filter) for filter in input_filter],
|
||||
compressed_jsonl=tmp_path / "test.jsonl",
|
||||
embeddings_file=tmp_path / "test_embeddings.jsonl",
|
||||
)
|
||||
|
||||
# Act
|
||||
extracted_plaintext_files = PlaintextToJsonl.get_plaintext_files(input_files, input_filter)
|
||||
extracted_plaintext_files = get_plaintext_files(plaintext_config)
|
||||
|
||||
# Assert
|
||||
assert len(extracted_plaintext_files) == 7
|
||||
assert set(extracted_plaintext_files) == set(expected_files)
|
||||
assert set(extracted_plaintext_files.keys()) == set(expected_files)
|
||||
|
||||
|
||||
def test_parse_html_plaintext_file(content_config):
|
||||
"Ensure HTML files are parsed correctly"
|
||||
# Arrange
|
||||
# Setup input-files, input-filters
|
||||
input_files = content_config.plaintext.input_files
|
||||
input_filter = content_config.plaintext.input_filter
|
||||
extracted_plaintext_files = get_plaintext_files(content_config.plaintext)
|
||||
|
||||
# Act
|
||||
extracted_plaintext_files = PlaintextToJsonl.get_plaintext_files(input_files, input_filter)
|
||||
file_to_entries = PlaintextToJsonl.extract_plaintext_entries(extracted_plaintext_files)
|
||||
maps = PlaintextToJsonl.convert_plaintext_entries_to_maps(file_to_entries)
|
||||
maps = PlaintextToJsonl.convert_plaintext_entries_to_maps(extracted_plaintext_files)
|
||||
|
||||
# Assert
|
||||
assert len(maps) == 1
|
||||
|
|
|
@ -13,6 +13,7 @@ from khoj.search_type import text_search
|
|||
from khoj.utils.rawconfig import ContentConfig, SearchConfig, TextContentConfig
|
||||
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
||||
from khoj.processor.github.github_to_jsonl import GithubToJsonl
|
||||
from khoj.utils.fs_syncer import get_org_files
|
||||
|
||||
|
||||
# Test
|
||||
|
@ -27,26 +28,30 @@ def test_text_search_setup_with_missing_file_raises_error(
|
|||
|
||||
# Act
|
||||
# Generate notes embeddings during asymmetric setup
|
||||
with pytest.raises(ValueError, match=r"^No valid entries found in specified files:*"):
|
||||
text_search.setup(OrgToJsonl, org_config_with_only_new_file, search_config.asymmetric, regenerate=True)
|
||||
with pytest.raises(FileNotFoundError):
|
||||
data = get_org_files(org_config_with_only_new_file)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_text_search_setup_with_empty_file_raises_error(
|
||||
org_config_with_only_new_file: TextContentConfig, search_config: SearchConfig
|
||||
):
|
||||
# Arrange
|
||||
data = get_org_files(org_config_with_only_new_file)
|
||||
# Act
|
||||
# Generate notes embeddings during asymmetric setup
|
||||
with pytest.raises(ValueError, match=r"^No valid entries found*"):
|
||||
text_search.setup(OrgToJsonl, org_config_with_only_new_file, search_config.asymmetric, regenerate=True)
|
||||
text_search.setup(OrgToJsonl, data, org_config_with_only_new_file, search_config.asymmetric, regenerate=True)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_text_search_setup(content_config: ContentConfig, search_models: SearchModels):
|
||||
# Arrange
|
||||
data = get_org_files(content_config.org)
|
||||
# Act
|
||||
# Regenerate notes embeddings during asymmetric setup
|
||||
notes_model = text_search.setup(
|
||||
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
|
||||
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True
|
||||
)
|
||||
|
||||
# Assert
|
||||
|
@ -59,14 +64,16 @@ def test_text_index_same_if_content_unchanged(content_config: ContentConfig, sea
|
|||
# Arrange
|
||||
caplog.set_level(logging.INFO, logger="khoj")
|
||||
|
||||
data = get_org_files(content_config.org)
|
||||
|
||||
# Act
|
||||
# Generate initial notes embeddings during asymmetric setup
|
||||
text_search.setup(OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True)
|
||||
text_search.setup(OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True)
|
||||
initial_logs = caplog.text
|
||||
caplog.clear() # Clear logs
|
||||
|
||||
# Run asymmetric setup again with no changes to data source. Ensure index is not updated
|
||||
text_search.setup(OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False)
|
||||
text_search.setup(OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=False)
|
||||
final_logs = caplog.text
|
||||
|
||||
# Assert
|
||||
|
@ -78,9 +85,11 @@ def test_text_index_same_if_content_unchanged(content_config: ContentConfig, sea
|
|||
@pytest.mark.anyio
|
||||
async def test_text_search(content_config: ContentConfig, search_config: SearchConfig):
|
||||
# Arrange
|
||||
data = get_org_files(content_config.org)
|
||||
|
||||
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||
content_index.org = text_search.setup(
|
||||
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
|
||||
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True
|
||||
)
|
||||
query = "How to git install application?"
|
||||
|
||||
|
@ -108,10 +117,12 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContent
|
|||
for index in range(max_tokens + 1):
|
||||
f.write(f"{index} ")
|
||||
|
||||
data = get_org_files(org_config_with_only_new_file)
|
||||
|
||||
# Act
|
||||
# reload embeddings, entries, notes model after adding new org-mode file
|
||||
initial_notes_model = text_search.setup(
|
||||
OrgToJsonl, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
|
||||
OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
|
||||
)
|
||||
|
||||
# Assert
|
||||
|
@ -125,8 +136,9 @@ def test_regenerate_index_with_new_entry(
|
|||
content_config: ContentConfig, search_models: SearchModels, new_org_file: Path
|
||||
):
|
||||
# Arrange
|
||||
data = get_org_files(content_config.org)
|
||||
initial_notes_model = text_search.setup(
|
||||
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
|
||||
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True
|
||||
)
|
||||
|
||||
assert len(initial_notes_model.entries) == 10
|
||||
|
@ -137,10 +149,12 @@ def test_regenerate_index_with_new_entry(
|
|||
with open(new_org_file, "w") as f:
|
||||
f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n")
|
||||
|
||||
data = get_org_files(content_config.org)
|
||||
|
||||
# Act
|
||||
# regenerate notes jsonl, model embeddings and model to include entry from new file
|
||||
regenerated_notes_model = text_search.setup(
|
||||
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
|
||||
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True
|
||||
)
|
||||
|
||||
# Assert
|
||||
|
@ -169,15 +183,19 @@ def test_update_index_with_duplicate_entries_in_stable_order(
|
|||
with open(new_file_to_index, "w") as f:
|
||||
f.write(f"{new_entry}{new_entry}")
|
||||
|
||||
data = get_org_files(org_config_with_only_new_file)
|
||||
|
||||
# Act
|
||||
# load embeddings, entries, notes model after adding new org-mode file
|
||||
initial_index = text_search.setup(
|
||||
OrgToJsonl, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=True
|
||||
OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=True
|
||||
)
|
||||
|
||||
data = get_org_files(org_config_with_only_new_file)
|
||||
|
||||
# update embeddings, entries, notes model after adding new org-mode file
|
||||
updated_index = text_search.setup(
|
||||
OrgToJsonl, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
|
||||
OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
|
||||
)
|
||||
|
||||
# Assert
|
||||
|
@ -200,19 +218,22 @@ def test_update_index_with_deleted_entry(org_config_with_only_new_file: TextCont
|
|||
new_entry = "* TODO A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n"
|
||||
with open(new_file_to_index, "w") as f:
|
||||
f.write(f"{new_entry}{new_entry} -- Tatooine")
|
||||
data = get_org_files(org_config_with_only_new_file)
|
||||
|
||||
# load embeddings, entries, notes model after adding new org file with 2 entries
|
||||
initial_index = text_search.setup(
|
||||
OrgToJsonl, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=True
|
||||
OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=True
|
||||
)
|
||||
|
||||
# update embeddings, entries, notes model after removing an entry from the org file
|
||||
with open(new_file_to_index, "w") as f:
|
||||
f.write(f"{new_entry}")
|
||||
|
||||
data = get_org_files(org_config_with_only_new_file)
|
||||
|
||||
# Act
|
||||
updated_index = text_search.setup(
|
||||
OrgToJsonl, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
|
||||
OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
|
||||
)
|
||||
|
||||
# Assert
|
||||
|
@ -229,8 +250,9 @@ def test_update_index_with_deleted_entry(org_config_with_only_new_file: TextCont
|
|||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_update_index_with_new_entry(content_config: ContentConfig, search_models: SearchModels, new_org_file: Path):
|
||||
# Arrange
|
||||
data = get_org_files(content_config.org)
|
||||
initial_notes_model = text_search.setup(
|
||||
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True, normalize=False
|
||||
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True, normalize=False
|
||||
)
|
||||
|
||||
# append org-mode entry to first org input file in config
|
||||
|
@ -238,11 +260,13 @@ def test_update_index_with_new_entry(content_config: ContentConfig, search_model
|
|||
new_entry = "\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n"
|
||||
f.write(new_entry)
|
||||
|
||||
data = get_org_files(content_config.org)
|
||||
|
||||
# Act
|
||||
# update embeddings, entries with the newly added note
|
||||
content_config.org.input_files = [f"{new_org_file}"]
|
||||
final_notes_model = text_search.setup(
|
||||
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False, normalize=False
|
||||
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=False, normalize=False
|
||||
)
|
||||
|
||||
# Assert
|
||||
|
|
Loading…
Reference in a new issue