Merge intermediate changes

Commit f12ca56e93 by Saba, 2022-09-14 21:09:30 +03:00
47 changed files with 1874 additions and 744 deletions

View file

@ -11,6 +11,10 @@ on:
- Dockerfile
- docker-compose.yml
- .github/workflows/build.yml
workflow_dispatch:
env:
DOCKER_IMAGE_TAG: ${{ github.ref == 'refs/heads/master' && 'latest' || github.ref }}
jobs:
build:
@ -36,6 +40,6 @@ jobs:
context: .
file: Dockerfile
push: true
tags: ghcr.io/${{ github.repository }}:latest
tags: ghcr.io/${{ github.repository }}:${{ env.DOCKER_IMAGE_TAG }}
build-args: |
PORT=8000
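Note on the tag expression above: GitHub Actions expressions have no ternary operator, so `github.ref == 'refs/heads/master' && 'latest' || github.ref` emulates one with short-circuiting `&&`/`||`; the trick is safe here because `'latest'` is truthy. An illustrative Python equivalent of that tag selection (not part of this commit):

# Illustrative Python equivalent of the workflow expression above; not part of
# this commit.
def docker_image_tag(ref: str) -> str:
    # Builds from master are tagged 'latest'; every other ref is tagged by its git ref
    return 'latest' if ref == 'refs/heads/master' else ref

assert docker_image_tag('refs/heads/master') == 'latest'
assert docker_image_tag('refs/heads/my-branch') == 'refs/heads/my-branch'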

View file

@ -1,15 +1,15 @@
name: release
on:
push:
tags:
- v*
workflow_dispatch:
inputs:
version:
description: 'Version Number'
required: true
type: string
push:
tags:
- v*
jobs:
publish:

.gitignore vendored
View file

@ -13,3 +13,4 @@ src/.data
/dist/
/khoj_assistant.egg-info/
/config/khoj*.yml
.pytest_cache

View file

@ -4,7 +4,7 @@ LABEL org.opencontainers.image.source https://github.com/debanjum/khoj
# Install System Dependencies
RUN apt-get update -y && \
apt-get -y install libimage-exiftool-perl
apt-get -y install libimage-exiftool-perl python3-pyqt5
# Copy Application to Container
COPY . /app

View file

@ -2,7 +2,6 @@
[![build](https://github.com/debanjum/khoj/actions/workflows/build.yml/badge.svg)](https://github.com/debanjum/khoj/actions/workflows/build.yml)
[![test](https://github.com/debanjum/khoj/actions/workflows/test.yml/badge.svg)](https://github.com/debanjum/khoj/actions/workflows/test.yml)
[![publish](https://github.com/debanjum/khoj/actions/workflows/publish.yml/badge.svg)](https://github.com/debanjum/khoj/actions/workflows/publish.yml)
[![release](https://github.com/debanjum/khoj/actions/workflows/release.yml/badge.svg)](https://github.com/debanjum/khoj/actions/workflows/release.yml)
*A natural language search engine for your personal notes, transactions and images*
@ -107,7 +106,7 @@ pip install --upgrade khoj-assistant
## Troubleshoot
- Symptom: Errors out complaining about Tensors mismatch, null etc
- Mitigation: Disable `image` search on the desktop GUI
- Mitigation: Disable `image` search using the desktop GUI
- Symptom: Errors out with \"Killed\" in error message in Docker
- Fix: Increase RAM available to Docker Containers in Docker Settings
- Refer: [StackOverflow Solution](https://stackoverflow.com/a/50770267), [Configure Resources on Docker for Mac](https://docs.docker.com/desktop/mac/#resources)
@ -125,14 +124,14 @@ pip install --upgrade khoj-assistant
- Semantic search using the bi-encoder is fairly fast at \<50 ms
- Reranking using the cross-encoder is slower at \<2s on 15 results. Tweak `top_k` to tradeoff speed for accuracy of results
- Applying explicit filters is very slow currently at \~6s. This is because the filters are rudimentary. Considerable speed-ups can be achieved using indexes etc
- Filters in query (e.g. by file, word or date) usually add \<20ms to query latency
### Indexing performance
- Indexing is more strongly impacted by the size of the source data
- Indexing 100K+ line corpus of notes takes 6 minutes
- Indexing 100K+ line corpus of notes takes about 10 minutes
- Indexing 4000+ images takes about 15 minutes and more than 8Gb of RAM
- Once <https://github.com/debanjum/khoj/issues/36> is implemented, it should only take this long on first run
- Note: *It should only take this long on the first run* as the index is incrementally updated
### Miscellaneous

View file

@ -28,4 +28,4 @@ services:
- ./tests/data/embeddings/:/data/embeddings/
- ./tests/data/models/:/data/models/
# Use 0.0.0.0 to explicitly set the host ip for the service on the container. https://pythonspeed.com/articles/docker-connection-refused/
command: config/khoj_docker.yml --host="0.0.0.0" --port=8000 -vv
command: --no-gui -c=config/khoj_docker.yml --host="0.0.0.0" --port=8000 -vv

View file

@ -7,7 +7,7 @@ this_directory = Path(__file__).parent
setup(
name='khoj-assistant',
version='0.1.6',
version='0.1.10',
description="A natural language search engine for your personal notes, transactions and images",
long_description=(this_directory / "Readme.md").read_text(encoding="utf-8"),
long_description_content_type="text/markdown",

View file

@ -1,8 +1,8 @@
# System Packages
import sys
import logging
# External Packages
import torch
import json
# Internal Packages
@ -13,8 +13,14 @@ from src.processor.org_mode.org_to_jsonl import org_to_jsonl
from src.search_type import image_search, text_search
from src.utils.config import SearchType, SearchModels, ProcessorConfigModel, ConversationProcessorConfigModel
from src.utils import state
from src.utils.helpers import get_absolute_path
from src.utils.helpers import LRU, resolve_absolute_path
from src.utils.rawconfig import FullConfig, ProcessorConfig
from src.search_filter.date_filter import DateFilter
from src.search_filter.word_filter import WordFilter
from src.search_filter.file_filter import FileFilter
logger = logging.getLogger(__name__)
def configure_server(args, required=False):
@ -28,27 +34,42 @@ def configure_server(args, required=False):
state.config = args.config
# Initialize the search model from Config
state.model = configure_search(state.model, state.config, args.regenerate, verbose=state.verbose)
state.model = configure_search(state.model, state.config, args.regenerate)
# Initialize Processor from Config
state.processor_config = configure_processor(args.config.processor, verbose=state.verbose)
state.processor_config = configure_processor(args.config.processor)
def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, t: SearchType = None, verbose: int = 0):
def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, t: SearchType = None):
# Initialize Org Notes Search
if (t == SearchType.Org or t == None) and config.content_type.org:
# Extract Entries, Generate Notes Embeddings
model.orgmode_search = text_search.setup(org_to_jsonl, config.content_type.org, search_config=config.search_type.asymmetric, regenerate=regenerate, verbose=verbose)
model.orgmode_search = text_search.setup(
org_to_jsonl,
config.content_type.org,
search_config=config.search_type.asymmetric,
regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()])
# Initialize Org Music Search
if (t == SearchType.Music or t == None) and config.content_type.music:
# Extract Entries, Generate Music Embeddings
model.music_search = text_search.setup(org_to_jsonl, config.content_type.music, search_config=config.search_type.asymmetric, regenerate=regenerate, verbose=verbose)
model.music_search = text_search.setup(
org_to_jsonl,
config.content_type.music,
search_config=config.search_type.asymmetric,
regenerate=regenerate,
filters=[DateFilter(), WordFilter()])
# Initialize Markdown Search
if (t == SearchType.Markdown or t == None) and config.content_type.markdown:
# Extract Entries, Generate Markdown Embeddings
model.markdown_search = text_search.setup(markdown_to_jsonl, config.content_type.markdown, search_config=config.search_type.asymmetric, regenerate=regenerate, verbose=verbose)
model.markdown_search = text_search.setup(
markdown_to_jsonl,
config.content_type.markdown,
search_config=config.search_type.asymmetric,
regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()])
# Initialize Panchayat Search
if (t == SearchType.Panchayat or t == None) and config.content_type.panchayat:
@ -58,17 +79,28 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
# Initialize Ledger Search
if (t == SearchType.Ledger or t == None) and config.content_type.ledger:
# Extract Entries, Generate Ledger Embeddings
model.ledger_search = text_search.setup(beancount_to_jsonl, config.content_type.ledger, search_config=config.search_type.symmetric, regenerate=regenerate, verbose=verbose)
model.ledger_search = text_search.setup(
beancount_to_jsonl,
config.content_type.ledger,
search_config=config.search_type.symmetric,
regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()])
# Initialize Image Search
if (t == SearchType.Image or t == None) and config.content_type.image:
# Extract Entries, Generate Image Embeddings
model.image_search = image_search.setup(config.content_type.image, search_config=config.search_type.image, regenerate=regenerate, verbose=verbose)
model.image_search = image_search.setup(
config.content_type.image,
search_config=config.search_type.image,
regenerate=regenerate)
# Invalidate Query Cache
state.query_cache = LRU()
return model
def configure_processor(processor_config: ProcessorConfig, verbose: int):
def configure_processor(processor_config: ProcessorConfig):
if not processor_config:
return
@ -76,24 +108,20 @@ def configure_processor(processor_config: ProcessorConfig, verbose: int):
# Initialize Conversation Processor
if processor_config.conversation:
processor.conversation = configure_conversation_processor(processor_config.conversation, verbose)
processor.conversation = configure_conversation_processor(processor_config.conversation)
return processor
def configure_conversation_processor(conversation_processor_config, verbose: int):
conversation_processor = ConversationProcessorConfigModel(conversation_processor_config, verbose)
def configure_conversation_processor(conversation_processor_config):
conversation_processor = ConversationProcessorConfigModel(conversation_processor_config)
conversation_logfile = resolve_absolute_path(conversation_processor.conversation_logfile)
conversation_logfile = conversation_processor.conversation_logfile
if conversation_processor.verbose:
print('INFO:\tLoading conversation logs from disk...')
if conversation_logfile.expanduser().absolute().is_file():
if conversation_logfile.is_file():
# Load Metadata Logs from Conversation Logfile
with open(get_absolute_path(conversation_logfile), 'r') as f:
with conversation_logfile.open('r') as f:
conversation_processor.meta_log = json.load(f)
print('INFO:\tConversation logs loaded from disk.')
logger.info('Conversation logs loaded from disk.')
else:
# Initialize Conversation Logs
conversation_processor.meta_log = {}
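The search setup for each content type now receives a list of filter objects, and the query cache is reset whenever the index is reconfigured. The filters are consumed inside `text_search`, which is not part of this excerpt; the sketch below shows one plausible way such a filter list gets applied at query time, using only the `can_filter`/`apply` interface from `src/search_filter/base_filter.py` later in this diff (`apply_filters` is a hypothetical helper, not project code).

# Illustrative sketch of applying a filter list at query time, using only the
# BaseFilter interface (can_filter/apply) shown later in this diff; not the
# project's actual query pipeline.
def apply_filters(query, entries, filters):
    included = set(range(len(entries)))           # start with every entry included
    for search_filter in filters:
        if search_filter.can_filter(query):       # only run filters whose syntax appears in the query
            query, ids = search_filter.apply(query, entries)  # each filter strips its syntax from the query
            included &= ids                       # keep entries accepted by every matching filter
    return query, included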

View file

@ -92,7 +92,7 @@ class MainWindow(QtWidgets.QMainWindow):
search_type_layout = QtWidgets.QVBoxLayout(search_type_settings)
enable_search_type = SearchCheckBox(f"Search {search_type.name}", search_type)
# Add file browser to set input files for given search type
input_files = FileBrowser(file_input_text, search_type, current_content_files)
input_files = FileBrowser(file_input_text, search_type, current_content_files or [])
# Set enabled/disabled based on checkbox state
enable_search_type.setChecked(current_content_files is not None and len(current_content_files) > 0)

View file

@ -6,9 +6,10 @@ from PyQt6 import QtGui, QtWidgets
# Internal Packages
from src.utils import constants, state
from src.interface.desktop.main_window import MainWindow
def create_system_tray(gui: QtWidgets.QApplication, main_window: QtWidgets.QMainWindow):
def create_system_tray(gui: QtWidgets.QApplication, main_window: MainWindow):
"""Create System Tray with Menu. Menu contain options to
1. Open Search Page on the Web Interface
2. Open App Configuration Screen

View file

@ -5,7 +5,7 @@
;; Author: Debanjum Singh Solanky <debanjum@gmail.com>
;; Description: Natural, Incremental Search for your Second Brain
;; Keywords: search, org-mode, outlines, markdown, beancount, ledger, image
;; Version: 0.1.6
;; Version: 0.1.9
;; Package-Requires: ((emacs "27.1"))
;; URL: http://github.com/debanjum/khoj/interface/emacs

View file

@ -61,12 +61,26 @@
}
function search(rerank=false) {
query = document.getElementById("query").value;
// Extract required fields for search from form
query = document.getElementById("query").value.trim();
type = document.getElementById("type").value;
console.log(query, type);
results_count = document.getElementById("results-count").value || 6;
console.log(`Query: ${query}, Type: ${type}`);
// Short circuit on empty query
if (query.length === 0)
return;
// If set query field in url query param on rerank
if (rerank)
setQueryFieldInUrl(query);
// Generate Backend API URL to execute Search
url = type === "image"
? `/search?q=${query}&t=${type}&n=6`
: `/search?q=${query}&t=${type}&n=6&r=${rerank}`;
? `/search?q=${encodeURIComponent(query)}&t=${type}&n=${results_count}`
: `/search?q=${encodeURIComponent(query)}&t=${type}&n=${results_count}&r=${rerank}`;
// Execute Search and Render Results
fetch(url)
.then(response => response.json())
.then(data => {
@ -78,9 +92,9 @@
});
}
function regenerate() {
function updateIndex() {
type = document.getElementById("type").value;
fetch(`/regenerate?t=${type}`)
fetch(`/reload?t=${type}`)
.then(response => response.json())
.then(data => {
console.log(data);
@ -89,7 +103,7 @@
});
}
function incremental_search(event) {
function incrementalSearch(event) {
type = document.getElementById("type").value;
// Search with reranking on 'Enter'
if (event.key === 'Enter') {
@ -121,10 +135,33 @@
});
}
function setTypeFieldInUrl(type) {
var url = new URL(window.location.href);
url.searchParams.set("t", type.value);
window.history.pushState({}, "", url.href);
}
function setCountFieldInUrl(results_count) {
var url = new URL(window.location.href);
url.searchParams.set("n", results_count.value);
window.history.pushState({}, "", url.href);
}
function setQueryFieldInUrl(query) {
var url = new URL(window.location.href);
url.searchParams.set("q", query);
window.history.pushState({}, "", url.href);
}
window.onload = function () {
// Dynamically populate type dropdown based on enabled search types and type passed as URL query parameter
populate_type_dropdown();
// Set results count field with value passed in URL query parameters, if any.
var results_count = new URLSearchParams(window.location.search).get("n");
if (results_count)
document.getElementById("results-count").value = results_count;
// Fill query field with value passed in URL query parameters, if any.
var query_via_url = new URLSearchParams(window.location.search).get("q");
if (query_via_url)
@ -136,14 +173,17 @@
<h1>Khoj</h1>
<!--Add Text Box To Enter Query, Trigger Incremental Search OnChange -->
<input type="text" id="query" onkeyup=incremental_search(event) autofocus="autofocus" placeholder="What is the meaning of life?">
<input type="text" id="query" onkeyup=incrementalSearch(event) autofocus="autofocus" placeholder="What is the meaning of life?">
<div id="options">
<!--Add Dropdown to Select Query Type -->
<select id="type"></select>
<select id="type" onchange="setTypeFieldInUrl(this)"></select>
<!--Add Button To Regenerate -->
<button id="regenerate" onclick="regenerate()">Regenerate</button>
<button id="update" onclick="updateIndex()">Update</button>
<!--Add Results Count Input To Set Results Count -->
<input type="number" id="results-count" min="1" max="100" value="6" placeholder="results count" onchange="setCountFieldInUrl(this)">
</div>
<!-- Section to Render Results -->
@ -194,7 +234,7 @@
#options {
padding: 0;
display: grid;
grid-template-columns: 1fr 1fr;
grid-template-columns: 1fr 1fr minmax(70px, 0.5fr);
}
#options > * {
padding: 15px;
@ -202,10 +242,10 @@
border: 1px solid #ccc;
}
#options > select {
margin-right: 5px;
margin-right: 10px;
}
#options > button {
margin-left: 5px;
margin-right: 10px;
}
#query {

View file

@ -2,8 +2,13 @@
import os
import signal
import sys
import logging
import warnings
from platform import system
# Ignore non-actionable warnings
warnings.filterwarnings("ignore", message=r'snapshot_download.py has been made private', category=FutureWarning)
# External Packages
import uvicorn
from fastapi import FastAPI
@ -25,6 +30,34 @@ app = FastAPI()
app.mount("/static", StaticFiles(directory=constants.web_directory), name="static")
app.include_router(router)
logger = logging.getLogger('src')
class CustomFormatter(logging.Formatter):
blue = "\x1b[1;34m"
green = "\x1b[1;32m"
grey = "\x1b[38;20m"
yellow = "\x1b[33;20m"
red = "\x1b[31;20m"
bold_red = "\x1b[31;1m"
reset = "\x1b[0m"
format_str = "%(levelname)s: %(asctime)s: %(name)s | %(message)s"
FORMATS = {
logging.DEBUG: blue + format_str + reset,
logging.INFO: green + format_str + reset,
logging.WARNING: yellow + format_str + reset,
logging.ERROR: red + format_str + reset,
logging.CRITICAL: bold_red + format_str + reset
}
def format(self, record):
log_fmt = self.FORMATS.get(record.levelno)
formatter = logging.Formatter(log_fmt)
return formatter.format(record)
def run():
# Turn Tokenizers Parallelism Off. App does not support it.
os.environ["TOKENIZERS_PARALLELISM"] = 'false'
@ -34,6 +67,29 @@ def run():
args = cli(state.cli_args)
set_state(args)
# Create app directory, if it doesn't exist
state.config_file.parent.mkdir(parents=True, exist_ok=True)
# Setup Logger
if args.verbose == 0:
logger.setLevel(logging.WARN)
elif args.verbose == 1:
logger.setLevel(logging.INFO)
elif args.verbose >= 2:
logger.setLevel(logging.DEBUG)
# Set Log Format
ch = logging.StreamHandler()
ch.setFormatter(CustomFormatter())
logger.addHandler(ch)
# Set Log File
fh = logging.FileHandler(state.config_file.parent / 'khoj.log')
fh.setLevel(logging.DEBUG)
logger.addHandler(fh)
logger.info("Starting Khoj...")
if args.no_gui:
# Start Server
configure_server(args, required=True)
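The application configures the parent `src` logger here, while the modules in this diff each create `logging.getLogger(__name__)` (for example `src.configure`). Because Python loggers form a dotted-name hierarchy, records from those child loggers propagate up to the handlers attached to `src`; the snippet below demonstrates that standard-library behaviour (not project code).

import logging

# Standard-library behaviour this setup relies on: handlers attached to the
# parent 'src' logger also receive records from child loggers such as
# 'src.configure', because records propagate up the dotted-name hierarchy.
parent = logging.getLogger('src')
parent.setLevel(logging.DEBUG)
parent.addHandler(logging.StreamHandler())

child = logging.getLogger('src.configure')
child.debug("emitted by the handler attached to 'src'")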

View file

@ -2,108 +2,127 @@
# Standard Packages
import json
import argparse
import pathlib
import glob
import re
import logging
import time
# Internal Packages
from src.utils.helpers import get_absolute_path, is_none_or_empty
from src.utils.helpers import get_absolute_path, is_none_or_empty, mark_entries_for_update
from src.utils.constants import empty_escape_sequences
from src.utils.jsonl import dump_jsonl, compress_jsonl_data
from src.utils.rawconfig import TextContentConfig
logger = logging.getLogger(__name__)
# Define Functions
def beancount_to_jsonl(beancount_files, beancount_file_filter, output_file, verbose=0):
def beancount_to_jsonl(config: TextContentConfig, previous_entries=None):
# Extract required fields from config
beancount_files, beancount_file_filter, output_file = config.input_files, config.input_filter, config.compressed_jsonl
# Input Validation
if is_none_or_empty(beancount_files) and is_none_or_empty(beancount_file_filter):
print("At least one of beancount-files or beancount-file-filter is required to be specified")
exit(1)
# Get Beancount Files to Process
beancount_files = get_beancount_files(beancount_files, beancount_file_filter, verbose)
beancount_files = get_beancount_files(beancount_files, beancount_file_filter)
# Extract Entries from specified Beancount files
entries = extract_beancount_entries(beancount_files)
start = time.time()
current_entries = convert_transactions_to_maps(*extract_beancount_transactions(beancount_files))
end = time.time()
logger.debug(f"Parse transactions from Beancount files into dictionaries: {end - start} seconds")
# Identify, mark and merge any new entries with previous entries
start = time.time()
if not previous_entries:
entries_with_ids = list(enumerate(current_entries))
else:
entries_with_ids = mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
end = time.time()
logger.debug(f"Identify new or updated transaction: {end - start} seconds")
# Process Each Entry from All Notes Files
jsonl_data = convert_beancount_entries_to_jsonl(entries, verbose=verbose)
start = time.time()
entries = list(map(lambda entry: entry[1], entries_with_ids))
jsonl_data = convert_transaction_maps_to_jsonl(entries)
# Compress JSONL formatted Data
if output_file.suffix == ".gz":
compress_jsonl_data(jsonl_data, output_file, verbose=verbose)
compress_jsonl_data(jsonl_data, output_file)
elif output_file.suffix == ".jsonl":
dump_jsonl(jsonl_data, output_file, verbose=verbose)
dump_jsonl(jsonl_data, output_file)
end = time.time()
logger.debug(f"Write transactions to JSONL file: {end - start} seconds")
return entries
return entries_with_ids
def get_beancount_files(beancount_files=None, beancount_file_filter=None, verbose=0):
def get_beancount_files(beancount_files=None, beancount_file_filters=None):
"Get Beancount files to process"
absolute_beancount_files, filtered_beancount_files = set(), set()
if beancount_files:
absolute_beancount_files = {get_absolute_path(beancount_file)
for beancount_file
in beancount_files}
if beancount_file_filter:
filtered_beancount_files = set(glob.glob(get_absolute_path(beancount_file_filter)))
if beancount_file_filters:
filtered_beancount_files = {
filtered_file
for beancount_file_filter in beancount_file_filters
for filtered_file in glob.glob(get_absolute_path(beancount_file_filter))
}
all_beancount_files = absolute_beancount_files | filtered_beancount_files
all_beancount_files = sorted(absolute_beancount_files | filtered_beancount_files)
files_with_non_beancount_extensions = {beancount_file
files_with_non_beancount_extensions = {
beancount_file
for beancount_file
in all_beancount_files
if not beancount_file.endswith(".bean") and not beancount_file.endswith(".beancount")}
if not beancount_file.endswith(".bean") and not beancount_file.endswith(".beancount")
}
if any(files_with_non_beancount_extensions):
print(f"[Warning] There maybe non beancount files in the input set: {files_with_non_beancount_extensions}")
if verbose > 0:
print(f'Processing files: {all_beancount_files}')
logger.info(f'Processing files: {all_beancount_files}')
return all_beancount_files
def extract_beancount_entries(beancount_files):
def extract_beancount_transactions(beancount_files):
"Extract entries from specified Beancount files"
# Initialize Regex for extracting Beancount Entries
transaction_regex = r'^\n?\d{4}-\d{2}-\d{2} [\*|\!] '
empty_newline = f'^[{empty_escape_sequences}]*$'
empty_newline = f'^[\n\r\t\ ]*$'
entries = []
transaction_to_file_map = []
for beancount_file in beancount_files:
with open(beancount_file) as f:
ledger_content = f.read()
entries.extend([entry.strip(empty_escape_sequences)
transactions_per_file = [entry.strip(empty_escape_sequences)
for entry
in re.split(empty_newline, ledger_content, flags=re.MULTILINE)
if re.match(transaction_regex, entry)])
return entries
if re.match(transaction_regex, entry)]
transaction_to_file_map += zip(transactions_per_file, [beancount_file]*len(transactions_per_file))
entries.extend(transactions_per_file)
return entries, dict(transaction_to_file_map)
def convert_beancount_entries_to_jsonl(entries, verbose=0):
"Convert each Beancount transaction to JSON and collate as JSONL"
jsonl = ''
def convert_transactions_to_maps(entries: list[str], transaction_to_file_map) -> list[dict]:
"Convert each Beancount transaction into a dictionary"
entry_maps = []
for entry in entries:
entry_dict = {'compiled': entry, 'raw': entry}
# Convert Dictionary to JSON and Append to JSONL string
jsonl += f'{json.dumps(entry_dict, ensure_ascii=False)}\n'
entry_maps.append({'compiled': entry, 'raw': entry, 'file': f'{transaction_to_file_map[entry]}'})
if verbose > 0:
print(f"Converted {len(entries)} to jsonl format")
logger.info(f"Converted {len(entries)} transactions to dictionaries")
return jsonl
return entry_maps
if __name__ == '__main__':
# Setup Argument Parser
parser = argparse.ArgumentParser(description="Map Beancount transactions into (compressed) JSONL format")
parser.add_argument('--output-file', '-o', type=pathlib.Path, required=True, help="Output file for (compressed) JSONL formatted transactions. Expected file extensions: jsonl or jsonl.gz")
parser.add_argument('--input-files', '-i', nargs='*', help="List of beancount files to process")
parser.add_argument('--input-filter', type=str, default=None, help="Regex filter for beancount files to process")
parser.add_argument('--verbose', '-v', action='count', default=0, help="Show verbose conversion logs, Default: 0")
args = parser.parse_args()
# Map transactions in beancount files to (compressed) JSONL formatted file
beancount_to_jsonl(args.input_files, args.input_filter, args.output_file, args.verbose)
def convert_transaction_maps_to_jsonl(entries: list[dict]) -> str:
"Convert each Beancount transaction dictionary to JSON and collate as JSONL"
return ''.join([f'{json.dumps(entry_dict, ensure_ascii=False)}\n' for entry_dict in entries])
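`mark_entries_for_update` comes from `src/utils/helpers.py`, which is not included in this excerpt. From its call sites here, it takes the freshly parsed entries plus the previous run's entries and returns `(id, entry)` tuples keyed on the 'compiled' text. A minimal sketch of that assumed behaviour (the real helper may differ):

# Assumed behaviour of mark_entries_for_update, inferred from its call sites in
# this diff; the real helper lives in src/utils/helpers.py and may differ.
def mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=None):
    # Index the previous run's entries by their compiled text
    previous_ids = {entry[key]: id for id, entry in enumerate(previous_entries)}
    # Reuse the previous id for unchanged entries; mark new or updated entries
    # with -1 so the caller knows they still need processing
    entries_with_ids = [(previous_ids.get(entry[key], -1), entry) for entry in current_entries]
    if logger:
        new_count = sum(1 for id, _ in entries_with_ids if id == -1)
        logger.debug(f"Found {new_count} new or updated entries")
    return entries_with_ids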

View file

@ -2,51 +2,78 @@
# Standard Packages
import json
import argparse
import pathlib
import glob
import re
import logging
import time
# Internal Packages
from src.utils.helpers import get_absolute_path, is_none_or_empty
from src.utils.helpers import get_absolute_path, is_none_or_empty, mark_entries_for_update
from src.utils.constants import empty_escape_sequences
from src.utils.jsonl import dump_jsonl, compress_jsonl_data
from src.utils.rawconfig import TextContentConfig
logger = logging.getLogger(__name__)
# Define Functions
def markdown_to_jsonl(markdown_files, markdown_file_filter, output_file, verbose=0):
def markdown_to_jsonl(config: TextContentConfig, previous_entries=None):
# Extract required fields from config
markdown_files, markdown_file_filter, output_file = config.input_files, config.input_filter, config.compressed_jsonl
# Input Validation
if is_none_or_empty(markdown_files) and is_none_or_empty(markdown_file_filter):
print("At least one of markdown-files or markdown-file-filter is required to be specified")
exit(1)
# Get Markdown Files to Process
markdown_files = get_markdown_files(markdown_files, markdown_file_filter, verbose)
markdown_files = get_markdown_files(markdown_files, markdown_file_filter)
# Extract Entries from specified Markdown files
entries = extract_markdown_entries(markdown_files)
start = time.time()
current_entries = convert_markdown_entries_to_maps(*extract_markdown_entries(markdown_files))
end = time.time()
logger.debug(f"Parse entries from Markdown files into dictionaries: {end - start} seconds")
# Identify, mark and merge any new entries with previous entries
start = time.time()
if not previous_entries:
entries_with_ids = list(enumerate(current_entries))
else:
entries_with_ids = mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
end = time.time()
logger.debug(f"Identify new or updated entries: {end - start} seconds")
# Process Each Entry from All Notes Files
jsonl_data = convert_markdown_entries_to_jsonl(entries, verbose=verbose)
start = time.time()
entries = list(map(lambda entry: entry[1], entries_with_ids))
jsonl_data = convert_markdown_maps_to_jsonl(entries)
# Compress JSONL formatted Data
if output_file.suffix == ".gz":
compress_jsonl_data(jsonl_data, output_file, verbose=verbose)
compress_jsonl_data(jsonl_data, output_file)
elif output_file.suffix == ".jsonl":
dump_jsonl(jsonl_data, output_file, verbose=verbose)
dump_jsonl(jsonl_data, output_file)
end = time.time()
logger.debug(f"Write markdown entries to JSONL file: {end - start} seconds")
return entries
return entries_with_ids
def get_markdown_files(markdown_files=None, markdown_file_filter=None, verbose=0):
def get_markdown_files(markdown_files=None, markdown_file_filters=None):
"Get Markdown files to process"
absolute_markdown_files, filtered_markdown_files = set(), set()
if markdown_files:
absolute_markdown_files = {get_absolute_path(markdown_file) for markdown_file in markdown_files}
if markdown_file_filter:
filtered_markdown_files = set(glob.glob(get_absolute_path(markdown_file_filter)))
if markdown_file_filters:
filtered_markdown_files = {
filtered_file
for markdown_file_filter in markdown_file_filters
for filtered_file in glob.glob(get_absolute_path(markdown_file_filter))
}
all_markdown_files = absolute_markdown_files | filtered_markdown_files
all_markdown_files = sorted(absolute_markdown_files | filtered_markdown_files)
files_with_non_markdown_extensions = {
md_file
@ -56,10 +83,9 @@ def get_markdown_files(markdown_files=None, markdown_file_filter=None, verbose=0
}
if any(files_with_non_markdown_extensions):
print(f"[Warning] There maybe non markdown-mode files in the input set: {files_with_non_markdown_extensions}")
logger.warn(f"[Warning] There maybe non markdown-mode files in the input set: {files_with_non_markdown_extensions}")
if verbose > 0:
print(f'Processing files: {all_markdown_files}')
logger.info(f'Processing files: {all_markdown_files}')
return all_markdown_files
@ -71,38 +97,31 @@ def extract_markdown_entries(markdown_files):
markdown_heading_regex = r'^#'
entries = []
entry_to_file_map = []
for markdown_file in markdown_files:
with open(markdown_file) as f:
markdown_content = f.read()
entries.extend([f'#{entry.strip(empty_escape_sequences)}'
markdown_entries_per_file = [f'#{entry.strip(empty_escape_sequences)}'
for entry
in re.split(markdown_heading_regex, markdown_content, flags=re.MULTILINE)])
in re.split(markdown_heading_regex, markdown_content, flags=re.MULTILINE)
if entry.strip(empty_escape_sequences) != '']
entry_to_file_map += zip(markdown_entries_per_file, [markdown_file]*len(markdown_entries_per_file))
entries.extend(markdown_entries_per_file)
return entries
return entries, dict(entry_to_file_map)
def convert_markdown_entries_to_jsonl(entries, verbose=0):
"Convert each Markdown entries to JSON and collate as JSONL"
jsonl = ''
def convert_markdown_entries_to_maps(entries: list[str], entry_to_file_map) -> list[dict]:
"Convert each Markdown entries into a dictionary"
entry_maps = []
for entry in entries:
entry_dict = {'compiled': entry, 'raw': entry}
# Convert Dictionary to JSON and Append to JSONL string
jsonl += f'{json.dumps(entry_dict, ensure_ascii=False)}\n'
entry_maps.append({'compiled': entry, 'raw': entry, 'file': f'{entry_to_file_map[entry]}'})
if verbose > 0:
print(f"Converted {len(entries)} to jsonl format")
logger.info(f"Converted {len(entries)} markdown entries to dictionaries")
return jsonl
return entry_maps
if __name__ == '__main__':
# Setup Argument Parser
parser = argparse.ArgumentParser(description="Map Markdown entries into (compressed) JSONL format")
parser.add_argument('--output-file', '-o', type=pathlib.Path, required=True, help="Output file for (compressed) JSONL formatted notes. Expected file extensions: jsonl or jsonl.gz")
parser.add_argument('--input-files', '-i', nargs='*', help="List of markdown files to process")
parser.add_argument('--input-filter', type=str, default=None, help="Regex filter for markdown files to process")
parser.add_argument('--verbose', '-v', action='count', default=0, help="Show verbose conversion logs, Default: 0")
args = parser.parse_args()
# Map notes in Markdown files to (compressed) JSONL formatted file
markdown_to_jsonl(args.input_files, args.input_filter, args.output_file, args.verbose)
def convert_markdown_maps_to_jsonl(entries):
"Convert each Markdown entries to JSON and collate as JSONL"
return ''.join([f'{json.dumps(entry_dict, ensure_ascii=False)}\n' for entry_dict in entries])
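`extract_markdown_entries` now splits each file on heading markers and drops empty fragments. A tiny illustrative run of the same splitting logic on an in-memory string (illustrative only; the real function reads files from disk):

import re

# Mirrors the splitting logic in extract_markdown_entries above, run on an
# in-memory string for illustration; the real function reads files from disk.
markdown_content = "# First note\nbody one\n# Second note\nbody two\n"
entries = [f'#{entry.strip()}'
           for entry in re.split(r'^#', markdown_content, flags=re.MULTILINE)
           if entry.strip() != '']
assert entries == ['#First note\nbody one', '#Second note\nbody two']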

View file

@ -1,62 +1,94 @@
#!/usr/bin/env python3
# Standard Packages
import re
import json
import argparse
import pathlib
import glob
import logging
import time
from typing import Iterable
# Internal Packages
from src.processor.org_mode import orgnode
from src.utils.helpers import get_absolute_path, is_none_or_empty
from src.utils.constants import empty_escape_sequences
from src.utils.helpers import get_absolute_path, is_none_or_empty, mark_entries_for_update
from src.utils.jsonl import dump_jsonl, compress_jsonl_data
from src.utils import state
from src.utils.rawconfig import TextContentConfig
logger = logging.getLogger(__name__)
# Define Functions
def org_to_jsonl(org_files, org_file_filter, output_file, verbose=0):
def org_to_jsonl(config: TextContentConfig, previous_entries=None):
# Extract required fields from config
org_files, org_file_filter, output_file = config.input_files, config.input_filter, config.compressed_jsonl
index_heading_entries = config.index_heading_entries
# Input Validation
if is_none_or_empty(org_files) and is_none_or_empty(org_file_filter):
print("At least one of org-files or org-file-filter is required to be specified")
exit(1)
# Get Org Files to Process
org_files = get_org_files(org_files, org_file_filter, verbose)
start = time.time()
org_files = get_org_files(org_files, org_file_filter)
# Extract Entries from specified Org files
entries = extract_org_entries(org_files)
start = time.time()
entry_nodes, file_to_entries = extract_org_entries(org_files)
end = time.time()
logger.debug(f"Parse entries from org files into OrgNode objects: {end - start} seconds")
start = time.time()
current_entries = convert_org_nodes_to_entries(entry_nodes, file_to_entries, index_heading_entries)
end = time.time()
logger.debug(f"Convert OrgNodes into entry dictionaries: {end - start} seconds")
# Identify, mark and merge any new entries with previous entries
if not previous_entries:
entries_with_ids = list(enumerate(current_entries))
else:
entries_with_ids = mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
# Process Each Entry from All Notes Files
jsonl_data = convert_org_entries_to_jsonl(entries, verbose=verbose)
start = time.time()
entries = map(lambda entry: entry[1], entries_with_ids)
jsonl_data = convert_org_entries_to_jsonl(entries)
# Compress JSONL formatted Data
if output_file.suffix == ".gz":
compress_jsonl_data(jsonl_data, output_file, verbose=verbose)
compress_jsonl_data(jsonl_data, output_file)
elif output_file.suffix == ".jsonl":
dump_jsonl(jsonl_data, output_file, verbose=verbose)
dump_jsonl(jsonl_data, output_file)
end = time.time()
logger.debug(f"Write org entries to JSONL file: {end - start} seconds")
return entries
return entries_with_ids
def get_org_files(org_files=None, org_file_filter=None, verbose=0):
def get_org_files(org_files=None, org_file_filters=None):
"Get Org files to process"
absolute_org_files, filtered_org_files = set(), set()
if org_files:
absolute_org_files = {get_absolute_path(org_file)
absolute_org_files = {
get_absolute_path(org_file)
for org_file
in org_files}
if org_file_filter:
filtered_org_files = set(glob.glob(get_absolute_path(org_file_filter)))
in org_files
}
if org_file_filters:
filtered_org_files = {
filtered_file
for org_file_filter in org_file_filters
for filtered_file in glob.glob(get_absolute_path(org_file_filter))
}
all_org_files = absolute_org_files | filtered_org_files
all_org_files = sorted(absolute_org_files | filtered_org_files)
files_with_non_org_extensions = {org_file for org_file in all_org_files if not org_file.endswith(".org")}
if any(files_with_non_org_extensions):
print(f"[Warning] There maybe non org-mode files in the input set: {files_with_non_org_extensions}")
logger.warn(f"There maybe non org-mode files in the input set: {files_with_non_org_extensions}")
if verbose > 0:
print(f'Processing files: {all_org_files}')
logger.info(f'Processing files: {all_org_files}')
return all_org_files
@ -64,69 +96,60 @@ def get_org_files(org_files=None, org_file_filter=None, verbose=0):
def extract_org_entries(org_files):
"Extract entries from specified Org files"
entries = []
entry_to_file_map = []
for org_file in org_files:
entries.extend(
orgnode.makelist(
str(org_file)))
org_file_entries = orgnode.makelist(str(org_file))
entry_to_file_map += zip(org_file_entries, [org_file]*len(org_file_entries))
entries.extend(org_file_entries)
return entries
return entries, dict(entry_to_file_map)
def convert_org_entries_to_jsonl(entries, verbose=0) -> str:
"Convert each Org-Mode entries to JSON and collate as JSONL"
jsonl = ''
def convert_org_nodes_to_entries(entries: list[orgnode.Orgnode], entry_to_file_map, index_heading_entries=False) -> list[dict]:
"Convert Org-Mode entries into list of dictionary"
entry_maps = []
for entry in entries:
entry_dict = dict()
if not entry.hasBody and not index_heading_entries:
# Ignore title notes i.e notes with just headings and empty body
if not entry.Body() or re.sub(r'\n|\t|\r| ', '', entry.Body()) == "":
continue
entry_dict["compiled"] = f'{entry.Heading()}.'
if verbose > 2:
print(f"Title: {entry.Heading()}")
entry_dict["compiled"] = f'{entry.heading}.'
if state.verbose > 2:
logger.debug(f"Title: {entry.heading}")
if entry.Tags():
tags_str = " ".join(entry.Tags())
if entry.tags:
tags_str = " ".join(entry.tags)
entry_dict["compiled"] += f'\t {tags_str}.'
if verbose > 2:
print(f"Tags: {tags_str}")
if state.verbose > 2:
logger.debug(f"Tags: {tags_str}")
if entry.Closed():
entry_dict["compiled"] += f'\n Closed on {entry.Closed().strftime("%Y-%m-%d")}.'
if verbose > 2:
print(f'Closed: {entry.Closed().strftime("%Y-%m-%d")}')
if entry.closed:
entry_dict["compiled"] += f'\n Closed on {entry.closed.strftime("%Y-%m-%d")}.'
if state.verbose > 2:
logger.debug(f'Closed: {entry.closed.strftime("%Y-%m-%d")}')
if entry.Scheduled():
entry_dict["compiled"] += f'\n Scheduled for {entry.Scheduled().strftime("%Y-%m-%d")}.'
if verbose > 2:
print(f'Scheduled: {entry.Scheduled().strftime("%Y-%m-%d")}')
if entry.scheduled:
entry_dict["compiled"] += f'\n Scheduled for {entry.scheduled.strftime("%Y-%m-%d")}.'
if state.verbose > 2:
logger.debug(f'Scheduled: {entry.scheduled.strftime("%Y-%m-%d")}')
if entry.Body():
entry_dict["compiled"] += f'\n {entry.Body()}'
if verbose > 2:
print(f"Body: {entry.Body()}")
if entry.hasBody:
entry_dict["compiled"] += f'\n {entry.body}'
if state.verbose > 2:
logger.debug(f"Body: {entry.body}")
if entry_dict:
entry_dict["raw"] = f'{entry}'
entry_dict["file"] = f'{entry_to_file_map[entry]}'
# Convert Dictionary to JSON and Append to JSONL string
jsonl += f'{json.dumps(entry_dict, ensure_ascii=False)}\n'
entry_maps.append(entry_dict)
if verbose > 0:
print(f"Converted {len(entries)} to jsonl format")
return jsonl
return entry_maps
if __name__ == '__main__':
# Setup Argument Parser
parser = argparse.ArgumentParser(description="Map Org-Mode notes into (compressed) JSONL format")
parser.add_argument('--output-file', '-o', type=pathlib.Path, required=True, help="Output file for (compressed) JSONL formatted notes. Expected file extensions: jsonl or jsonl.gz")
parser.add_argument('--input-files', '-i', nargs='*', help="List of org-mode files to process")
parser.add_argument('--input-filter', type=str, default=None, help="Regex filter for org-mode files to process")
parser.add_argument('--verbose', '-v', action='count', default=0, help="Show verbose conversion logs, Default: 0")
args = parser.parse_args()
# Map notes in Org-Mode files to (compressed) JSONL formatted file
org_to_jsonl(args.input_files, args.input_filter, args.output_file, args.verbose)
def convert_org_entries_to_jsonl(entries: Iterable[dict]) -> str:
"Convert each Org-Mode entry to JSON and collate as JSONL"
return ''.join([f'{json.dumps(entry_dict, ensure_ascii=False)}\n' for entry_dict in entries])
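The processors now return `(id, entry)` tuples and accept the previous run's entries, so a caller can tell which entries changed between runs. The actual caller is `text_search.setup`, which is not shown here; the sketch below is an assumed usage pattern based only on the signatures above (`reindex_org` is a hypothetical wrapper).

from src.processor.org_mode.org_to_jsonl import org_to_jsonl

# Hypothetical wrapper showing the assumed incremental calling pattern; the
# real caller is text_search.setup, which is not part of this excerpt.
def reindex_org(config, previous_entries_with_ids=None):
    # First run: pass no previous entries, so every entry gets a fresh id.
    # Later runs: hand back the previous run's entry dictionaries, so unchanged
    # entries keep their ids and only new or updated ones are marked.
    previous_entries = None
    if previous_entries_with_ids:
        previous_entries = [entry for _, entry in previous_entries_with_ids]
    return org_to_jsonl(config, previous_entries)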

View file

@ -33,16 +33,21 @@ headline and associated text from an org-mode file, and routines for
constructing data structures of these classes.
"""
import re, sys
import re
import datetime
from pathlib import Path
from os.path import relpath
indent_regex = re.compile(r'^\s*')
indent_regex = re.compile(r'^ *')
def normalize_filename(filename):
file_relative_to_home = f'~/{relpath(filename, start=Path.home())}'
escaped_filename = f'{file_relative_to_home}'.replace("[","\[").replace("]","\]")
"Normalize and escape filename for rendering"
if not Path(filename).is_absolute():
# Normalize relative filename to be relative to current directory
normalized_filename = f'~/{relpath(filename, start=Path.home())}'
else:
normalized_filename = filename
escaped_filename = f'{normalized_filename}'.replace("[","\[").replace("]","\]")
return escaped_filename
def makelist(filename):
@ -52,65 +57,71 @@ def makelist(filename):
"""
ctr = 0
try:
f = open(filename, 'r')
except IOError:
print(f"Unable to open file {filename}")
print("Program terminating.")
sys.exit(1)
todos = { "TODO": "", "WAITING": "", "ACTIVE": "",
"DONE": "", "CANCELLED": "", "FAILED": ""} # populated from #+SEQ_TODO line
level = 0
level = ""
heading = ""
bodytext = ""
tags = set() # set of all tags in headline
tags = list() # set of all tags in headline
closed_date = ''
sched_date = ''
deadline_date = ''
logbook = list()
nodelist = []
propdict = dict()
nodelist: list[Orgnode] = list()
property_map = dict()
in_properties_drawer = False
in_logbook_drawer = False
file_title = f'{filename}'
for line in f:
ctr += 1
hdng = re.search(r'^(\*+)\s(.*?)\s*$', line)
if hdng: # we are processing a heading line
heading_search = re.search(r'^(\*+)\s(.*?)\s*$', line)
if heading_search: # we are processing a heading line
if heading: # if we have are on second heading, append first heading to headings list
thisNode = Orgnode(level, heading, bodytext, tags)
if closed_date:
thisNode.setClosed(closed_date)
thisNode.closed = closed_date
closed_date = ''
if sched_date:
thisNode.setScheduled(sched_date)
thisNode.scheduled = sched_date
sched_date = ""
if deadline_date:
thisNode.setDeadline(deadline_date)
thisNode.deadline = deadline_date
deadline_date = ''
if logbook:
thisNode.setLogbook(logbook)
thisNode.logbook = logbook
logbook = list()
thisNode.setProperties(propdict)
thisNode.properties = property_map
nodelist.append( thisNode )
propdict = {'LINE': f'file:{normalize_filename(filename)}::{ctr}'}
level = hdng.group(1)
heading = hdng.group(2)
property_map = {'LINE': f'file:{normalize_filename(filename)}::{ctr}'}
level = heading_search.group(1)
heading = heading_search.group(2)
bodytext = ""
tags = set() # set of all tags in headline
tagsrch = re.search(r'(.*?)\s*:([a-zA-Z0-9].*?):$',heading)
if tagsrch:
heading = tagsrch.group(1)
parsedtags = tagsrch.group(2)
tags = list() # set of all tags in headline
tag_search = re.search(r'(.*?)\s*:([a-zA-Z0-9].*?):$',heading)
if tag_search:
heading = tag_search.group(1)
parsedtags = tag_search.group(2)
if parsedtags:
for parsedtag in parsedtags.split(':'):
if parsedtag != '': tags.add(parsedtag)
if parsedtag != '': tags.append(parsedtag)
else: # we are processing a non-heading line
if line[:10] == '#+SEQ_TODO':
kwlist = re.findall(r'([A-Z]+)\(', line)
for kw in kwlist: todos[kw] = ""
# Set file title to TITLE property, if it exists
title_search = re.search(r'^#\+TITLE:\s*(.*)$', line)
if title_search and title_search.group(1).strip() != '':
title_text = title_search.group(1).strip()
if file_title == f'{filename}':
file_title = title_text
else:
file_title += f' {title_text}'
continue
# Ignore Properties Drawers Completely
if re.search(':PROPERTIES:', line):
in_properties_drawer=True
@ -137,13 +148,13 @@ def makelist(filename):
logbook += [(clocked_in, clocked_out)]
line = ""
prop_srch = re.search(r'^\s*:([a-zA-Z0-9]+):\s*(.*?)\s*$', line)
if prop_srch:
property_search = re.search(r'^\s*:([a-zA-Z0-9]+):\s*(.*?)\s*$', line)
if property_search:
# Set ID property to an id based org-mode link to the entry
if prop_srch.group(1) == 'ID':
propdict['ID'] = f'id:{prop_srch.group(2)}'
if property_search.group(1) == 'ID':
property_map['ID'] = f'id:{property_search.group(2)}'
else:
propdict[prop_srch.group(1)] = prop_srch.group(2)
property_map[property_search.group(1)] = property_search.group(2)
continue
cd_re = re.search(r'CLOSED:\s*\[([0-9]{4})-([0-9]{2})-([0-9]{2})', line)
@ -167,36 +178,39 @@ def makelist(filename):
bodytext = bodytext + line
# write out last node
thisNode = Orgnode(level, heading, bodytext, tags)
thisNode.setProperties(propdict)
thisNode = Orgnode(level, heading or file_title, bodytext, tags)
thisNode.properties = property_map
if sched_date:
thisNode.setScheduled(sched_date)
thisNode.scheduled = sched_date
if deadline_date:
thisNode.setDeadline(deadline_date)
thisNode.deadline = deadline_date
if closed_date:
thisNode.setClosed(closed_date)
thisNode.closed = closed_date
if logbook:
thisNode.setLogbook(logbook)
thisNode.logbook = logbook
nodelist.append( thisNode )
# using the list of TODO keywords found in the file
# process the headings searching for TODO keywords
for n in nodelist:
h = n.Heading()
todoSrch = re.search(r'([A-Z]+)\s(.*?)$', h)
if todoSrch:
if todoSrch.group(1) in todos:
n.setHeading( todoSrch.group(2) )
n.setTodo ( todoSrch.group(1) )
todo_search = re.search(r'([A-Z]+)\s(.*?)$', n.heading)
if todo_search:
if todo_search.group(1) in todos:
n.heading = todo_search.group(2)
n.todo = todo_search.group(1)
# extract, set priority from heading, update heading if necessary
prtysrch = re.search(r'^\[\#(A|B|C)\] (.*?)$', n.Heading())
if prtysrch:
n.setPriority(prtysrch.group(1))
n.setHeading(prtysrch.group(2))
priority_search = re.search(r'^\[\#(A|B|C)\] (.*?)$', n.heading)
if priority_search:
n.priority = priority_search.group(1)
n.heading = priority_search.group(2)
# Set SOURCE property to a file+heading based org-mode link to the entry
escaped_heading = n.Heading().replace("[","\\[").replace("]","\\]")
if n.level == 0:
n.properties['LINE'] = f'file:{normalize_filename(filename)}::0'
n.properties['SOURCE'] = f'[[file:{normalize_filename(filename)}]]'
else:
escaped_heading = n.heading.replace("[","\\[").replace("]","\\]")
n.properties['SOURCE'] = f'[[file:{normalize_filename(filename)}::*{escaped_heading}]]'
return nodelist
@ -214,199 +228,234 @@ class Orgnode(object):
first tag. The makelist routine postprocesses the list to
identify TODO tags and updates headline and todo fields.
"""
self.level = len(level)
self.headline = headline
self.body = body
self.tags = set(tags) # All tags in the headline
self.todo = ""
self.prty = "" # empty of A, B or C
self.scheduled = "" # Scheduled date
self.deadline = "" # Deadline date
self.closed = "" # Closed date
self.properties = dict()
self.logbook = list() # List of clock-in, clock-out tuples representing logbook entries
self._level = len(level)
self._heading = headline
self._body = body
self._tags = tags # All tags in the headline
self._todo = ""
self._priority = "" # empty of A, B or C
self._scheduled = "" # Scheduled date
self._deadline = "" # Deadline date
self._closed = "" # Closed date
self._properties = dict()
self._logbook = list() # List of clock-in, clock-out tuples representing logbook entries
# Look for priority in headline and transfer to prty field
def Heading(self):
@property
def heading(self):
"""
Return the Heading text of the node without the TODO tag
"""
return self.headline
return self._heading
def setHeading(self, newhdng):
@heading.setter
def heading(self, newhdng):
"""
Change the heading to the supplied string
"""
self.headline = newhdng
self._heading = newhdng
def Body(self):
@property
def body(self):
"""
Returns all lines of text of the body of this node except the
Property Drawer
"""
return self.body
return self._body
def Level(self):
@property
def hasBody(self):
"""
Returns True if node has non empty body, else False
"""
return self._body and re.sub(r'\n|\t|\r| ', '', self._body) != ''
@property
def level(self):
"""
Returns an integer corresponding to the level of the node.
Top level (one asterisk) has a level of 1.
"""
return self.level
return self._level
def Priority(self):
@property
def priority(self):
"""
Returns the priority of this headline: 'A', 'B', 'C' or empty
string if priority has not been set.
"""
return self.prty
return self._priority
def setPriority(self, newprty):
@priority.setter
def priority(self, new_priority):
"""
Change the value of the priority of this headline.
Valid values are '', 'A', 'B', 'C'
"""
self.prty = newprty
self._priority = new_priority
def Tags(self):
@property
def tags(self):
"""
Returns the set of all tags
For example, :HOME:COMPUTER: would return {'HOME', 'COMPUTER'}
Returns the list of all tags
For example, :HOME:COMPUTER: would return ['HOME', 'COMPUTER']
"""
return self.tags
return self._tags
def hasTag(self, srch):
@tags.setter
def tags(self, newtags):
"""
Store all the tags found in the headline.
"""
self._tags = newtags
def hasTag(self, tag):
"""
Returns True if the supplied tag is present in this headline
For example, hasTag('COMPUTER') on headline containing
:HOME:COMPUTER: would return True.
"""
return srch in self.tags
return tag in self._tags
def setTags(self, newtags):
"""
Store all the tags found in the headline.
"""
self.tags = set(newtags)
def Todo(self):
@property
def todo(self):
"""
Return the value of the TODO tag
"""
return self.todo
return self._todo
def setTodo(self, value):
@todo.setter
def todo(self, new_todo):
"""
Set the value of the TODO tag to the supplied string
"""
self.todo = value
self._todo = new_todo
def setProperties(self, dictval):
@property
def properties(self):
"""
Return the dictionary of properties
"""
return self._properties
@properties.setter
def properties(self, new_properties):
"""
Sets all properties using the supplied dictionary of
name/value pairs
"""
self.properties = dictval
self._properties = new_properties
def Property(self, keyval):
def Property(self, property_key):
"""
Returns the value of the requested property or null if the
property does not exist.
"""
return self.properties.get(keyval, "")
return self._properties.get(property_key, "")
def setScheduled(self, dateval):
@property
def scheduled(self):
"""
Set the scheduled date using the supplied date object
Return the scheduled date
"""
self.scheduled = dateval
return self._scheduled
def Scheduled(self):
@scheduled.setter
def scheduled(self, new_scheduled):
"""
Return the scheduled date object or null if nonexistent
Set the scheduled date to the new scheduled date
"""
return self.scheduled
self._scheduled = new_scheduled
def setDeadline(self, dateval):
@property
def deadline(self):
"""
Set the deadline (due) date using the supplied date object
Return the deadline date
"""
self.deadline = dateval
return self._deadline
def Deadline(self):
@deadline.setter
def deadline(self, new_deadline):
"""
Return the deadline date object or null if nonexistent
Set the deadline (due) date to the new deadline date
"""
return self.deadline
self._deadline = new_deadline
def setClosed(self, dateval):
@property
def closed(self):
"""
Set the closed date using the supplied date object
Return the closed date
"""
self.closed = dateval
return self._closed
def Closed(self):
@closed.setter
def closed(self, new_closed):
"""
Return the closed date object or null if nonexistent
Set the closed date to the new closed date
"""
return self.closed
self._closed = new_closed
def setLogbook(self, logbook):
"""
Set the logbook with list of clocked-in, clocked-out tuples for the entry
"""
self.logbook = logbook
def Logbook(self):
@property
def logbook(self):
"""
Return the logbook with all clocked-in, clocked-out date object pairs or empty list if nonexistent
"""
return self.logbook
return self._logbook
@logbook.setter
def logbook(self, new_logbook):
"""
Set the logbook with list of clocked-in, clocked-out tuples for the entry
"""
self._logbook = new_logbook
def __repr__(self):
"""
Print the level, heading text and tag of a node and the body
text as used to construct the node.
"""
# This method is not completed yet.
# Output heading line
n = ''
for _ in range(0, self.level):
for _ in range(0, self._level):
n = n + '*'
n = n + ' '
if self.todo:
n = n + self.todo + ' '
if self.prty:
n = n + '[#' + self.prty + '] '
n = n + self.headline
if self._todo:
n = n + self._todo + ' '
if self._priority:
n = n + '[#' + self._priority + '] '
n = n + self._heading
n = "%-60s " % n # hack - tags will start in column 62
closecolon = ''
for t in self.tags:
for t in self._tags:
n = n + ':' + t
closecolon = ':'
n = n + closecolon
n = n + "\n"
# Get body indentation from first line of body
indent = indent_regex.match(self.body).group()
indent = indent_regex.match(self._body).group()
# Output Closed Date, Scheduled Date, Deadline Date
if self.closed or self.scheduled or self.deadline:
if self._closed or self._scheduled or self._deadline:
n = n + indent
if self.closed:
n = n + f'CLOSED: [{self.closed.strftime("%Y-%m-%d %a")}] '
if self.scheduled:
n = n + f'SCHEDULED: <{self.scheduled.strftime("%Y-%m-%d %a")}> '
if self.deadline:
n = n + f'DEADLINE: <{self.deadline.strftime("%Y-%m-%d %a")}> '
if self.closed or self.scheduled or self.deadline:
if self._closed:
n = n + f'CLOSED: [{self._closed.strftime("%Y-%m-%d %a")}] '
if self._scheduled:
n = n + f'SCHEDULED: <{self._scheduled.strftime("%Y-%m-%d %a")}> '
if self._deadline:
n = n + f'DEADLINE: <{self._deadline.strftime("%Y-%m-%d %a")}> '
if self._closed or self._scheduled or self._deadline:
n = n + '\n'
# Output Property Drawer
n = n + indent + ":PROPERTIES:\n"
for key, value in self.properties.items():
for key, value in self._properties.items():
n = n + indent + f":{key}: {value}\n"
n = n + indent + ":END:\n"
n = n + self.body
# Output Body
if self.hasBody:
n = n + self._body
return n
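Orgnode's accessors move from getter/setter methods (`Heading()`, `setScheduled()`, `setPriority()`, ...) to plain Python properties. A small illustrative usage of the new API, assuming the repository root is on `PYTHONPATH` so `src` is importable:

from datetime import datetime
from src.processor.org_mode.orgnode import Orgnode

# Illustrative use of the new property-based accessors; level is the leading
# asterisks string, as in makelist().
node = Orgnode('**', 'Write docs', 'Draft the setup guide\n', ['writing'])

# Old accessor style removed in this diff:
#   node.setScheduled(date); node.setPriority('A'); print(node.Heading())
# New property style:
node.scheduled = datetime(2022, 9, 15)
node.priority = 'A'
print(node.level, node.heading, node.priority, node.scheduled.date(), node.hasBody)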

View file

@ -2,8 +2,8 @@
import yaml
import json
import time
import logging
from typing import Optional
from functools import lru_cache
# External Packages
from fastapi import APIRouter
@ -15,16 +15,17 @@ from fastapi.templating import Jinja2Templates
from src.configure import configure_search
from src.search_type import image_search, text_search
from src.processor.conversation.gpt import converse, extract_search_type, message_to_log, message_to_prompt, understand, summarize
from src.search_filter.explicit_filter import ExplicitFilter
from src.search_filter.date_filter import DateFilter
from src.utils.rawconfig import FullConfig
from src.utils.config import SearchType
from src.utils.helpers import get_absolute_path, get_from_dict
from src.utils.helpers import LRU, get_absolute_path, get_from_dict
from src.utils import state, constants
router = APIRouter()
router = APIRouter()
templates = Jinja2Templates(directory=constants.web_directory)
logger = logging.getLogger(__name__)
query_cache = LRU()
@router.get("/", response_class=FileResponse)
def index():
@ -47,22 +48,27 @@ async def config_data(updated_config: FullConfig):
return state.config
@router.get('/search')
@lru_cache(maxsize=100)
def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Optional[bool] = False):
if q is None or q == '':
print(f'No query param (q) passed in API call to initiate search')
logger.info(f'No query param (q) passed in API call to initiate search')
return {}
# initialize variables
user_query = q
user_query = q.strip()
results_count = n
results = {}
query_start, query_end, collate_start, collate_end = None, None, None, None
# return cached results, if available
query_cache_key = f'{user_query}-{n}-{t}-{r}'
if query_cache_key in state.query_cache:
logger.info(f'Return response from query cache')
return state.query_cache[query_cache_key]
if (t == SearchType.Org or t == None) and state.model.orgmode_search:
# query org-mode notes
query_start = time.time()
hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r, filters=[DateFilter(), ExplicitFilter()], verbose=state.verbose)
hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r)
query_end = time.time()
# collate and return results
@ -73,7 +79,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
if (t == SearchType.Music or t == None) and state.model.music_search:
# query music library
query_start = time.time()
hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r, filters=[DateFilter(), ExplicitFilter()], verbose=state.verbose)
hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r)
query_end = time.time()
# collate and return results
@ -84,7 +90,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
if (t == SearchType.Markdown or t == None) and state.model.markdown_search:
# query markdown files
query_start = time.time()
hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r, filters=[ExplicitFilter(), DateFilter()], verbose=state.verbose)
hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r)
query_end = time.time()
# collate and return results
@ -95,7 +101,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
if (t == SearchType.Ledger or t == None) and state.model.ledger_search:
# query transactions
query_start = time.time()
hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r, filters=[ExplicitFilter(), DateFilter()], verbose=state.verbose)
hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r)
query_end = time.time()
# collate and return results
@ -131,11 +137,13 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
count=results_count)
collate_end = time.time()
if state.verbose > 1:
# Cache results
state.query_cache[query_cache_key] = results
if query_start and query_end:
print(f"Query took {query_end - query_start:.3f} seconds")
logger.debug(f"Query took {query_end - query_start:.3f} seconds")
if collate_start and collate_end:
print(f"Collating results took {collate_end - collate_start:.3f} seconds")
logger.debug(f"Collating results took {collate_end - collate_start:.3f} seconds")
return results
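Search results are now cached in `state.query_cache`, an instance of the new `LRU` helper keyed on query, result count, search type and rerank flag, and `configure_search` recreates the cache so a reindex invalidates it. The `LRU` class itself lives in `src/utils/helpers.py`, which is not part of this excerpt; below is a minimal `OrderedDict`-based sketch of the dict-like, bounded-size behaviour the router relies on (an assumption, the real class may differ in details such as capacity).

from collections import OrderedDict

# Assumed shape of the LRU helper used above; the real class lives in
# src/utils/helpers.py and may differ in details such as the default capacity.
class LRU(OrderedDict):
    def __init__(self, *args, capacity=128, **kwargs):
        self.capacity = capacity
        super().__init__(*args, **kwargs)

    def __getitem__(self, key):
        value = super().__getitem__(key)
        self.move_to_end(key)            # mark key as most recently used
        return value

    def __setitem__(self, key, value):
        super().__setitem__(key, value)
        if len(self) > self.capacity:    # evict the least recently used entry
            self.popitem(last=False)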

View file

@ -0,0 +1,16 @@
# Standard Packages
from abc import ABC, abstractmethod
class BaseFilter(ABC):
@abstractmethod
def load(self, *args, **kwargs):
pass
@abstractmethod
def can_filter(self, raw_query:str) -> bool:
pass
@abstractmethod
def apply(self, query:str, raw_entries:list[str]) -> tuple[str, set[int]]:
pass
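
To make the new contract concrete, a hypothetical filter implementing this interface could look like the sketch below. TodoFilter and its todo: syntax are invented for illustration; entries are assumed to be dicts with a 'raw' key, as in the concrete filters that follow:

from src.search_filter.base_filter import BaseFilter


class TodoFilter(BaseFilter):
    "Hypothetical filter that only keeps entries whose raw text contains TODO"

    def __init__(self, entry_key='raw'):
        self.entry_key = entry_key
        self.todo_entry_ids = set()

    def load(self, entries, *args, **kwargs):
        # Index the entries containing TODO once, up-front
        self.todo_entry_ids = {id for id, entry in enumerate(entries) if 'TODO' in entry[self.entry_key]}

    def can_filter(self, raw_query: str) -> bool:
        return 'todo:' in raw_query

    def apply(self, query: str, raw_entries: list[str]) -> tuple[str, set[int]]:
        if not self.can_filter(query):
            return query, set(range(len(raw_entries)))
        # Strip the filter term from the query and return the matching entry ids
        return query.replace('todo:', '').strip(), self.todo_entry_ids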

View file

@ -1,15 +1,24 @@
# Standard Packages
import re
import time
import logging
from collections import defaultdict
from datetime import timedelta, datetime
from dateutil.relativedelta import relativedelta, MO
from dateutil.relativedelta import relativedelta
from math import inf
# External Packages
import torch
import dateparser as dtparse
# Internal Packages
from src.search_filter.base_filter import BaseFilter
from src.utils.helpers import LRU
class DateFilter:
logger = logging.getLogger(__name__)
class DateFilter(BaseFilter):
# Date Range Filter Regexes
# Example filter queries:
# - dt>="yesterday" dt<"tomorrow"
@ -17,46 +26,73 @@ class DateFilter:
# - dt:"2 years ago"
date_regex = r"dt([:><=]{1,2})\"(.*?)\""
def can_filter(self, raw_query):
"Check if query contains date filters"
return self.extract_date_range(raw_query) is not None
def __init__(self, entry_key='raw'):
self.entry_key = entry_key
self.date_to_entry_ids = defaultdict(set)
self.cache = LRU()
def filter(self, query, entries, embeddings, entry_key='raw'):
"Find entries containing any dates that fall within date range specified in query"
# extract date range specified in date filter of query
query_daterange = self.extract_date_range(query)
# if no date in query, return all entries
if query_daterange is None:
return query, entries, embeddings
# remove date range filter from query
query = re.sub(rf'\s+{self.date_regex}', ' ', query)
query = re.sub(r'\s{2,}', ' ', query).strip() # remove multiple spaces
# find entries containing any dates that fall with date range specified in query
entries_to_include = set()
def load(self, entries, **_):
start = time.time()
for id, entry in enumerate(entries):
# Extract dates from entry
for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[entry_key]):
for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[self.entry_key]):
# Convert date string in entry to unix timestamp
try:
date_in_entry = datetime.strptime(date_in_entry_string, '%Y-%m-%d').timestamp()
except ValueError:
continue
self.date_to_entry_ids[date_in_entry].add(id)
end = time.time()
logger.debug(f"Created date filter index: {end - start} seconds")
def can_filter(self, raw_query):
"Check if query contains date filters"
return self.extract_date_range(raw_query) is not None
def apply(self, query, raw_entries):
"Find entries containing any dates that fall within date range specified in query"
# extract date range specified in date filter of query
start = time.time()
query_daterange = self.extract_date_range(query)
end = time.time()
logger.debug(f"Extract date range to filter from query: {end - start} seconds")
# if no date in query, return all entries
if query_daterange is None:
return query, set(range(len(raw_entries)))
# remove date range filter from query
query = re.sub(rf'\s+{self.date_regex}', ' ', query)
query = re.sub(r'\s{2,}', ' ', query).strip() # remove multiple spaces
# return results from cache if exists
cache_key = tuple(query_daterange)
if cache_key in self.cache:
logger.info(f"Return date filter results from cache")
entries_to_include = self.cache[cache_key]
return query, entries_to_include
if not self.date_to_entry_ids:
self.load(raw_entries)
# find entries containing any dates that fall within the date range specified in query
start = time.time()
entries_to_include = set()
for date_in_entry in self.date_to_entry_ids.keys():
# Check if date in entry is within date range specified in query
if query_daterange[0] <= date_in_entry < query_daterange[1]:
entries_to_include.add(id)
break
entries_to_include |= self.date_to_entry_ids[date_in_entry]
end = time.time()
logger.debug(f"Mark entries satisfying filter: {end - start} seconds")
# delete entries (and their embeddings) marked for exclusion
entries_to_exclude = set(range(len(entries))) - entries_to_include
for id in sorted(list(entries_to_exclude), reverse=True):
del entries[id]
embeddings = torch.cat((embeddings[:id], embeddings[id+1:]))
# cache results
self.cache[cache_key] = entries_to_include
return query, entries, embeddings
return query, entries_to_include
def extract_date_range(self, query):
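
To show how the reworked date filter is meant to be driven (index once via load, then apply returns the query stripped of filter terms plus the set of matching entry indices), a small usage sketch mirroring the unit tests further below; the entry contents are illustrative:

from src.search_filter.date_filter import DateFilter

entries = [
    {'compiled': '', 'raw': 'Entry with no date'},
    {'compiled': '', 'raw': 'Entry with date:1984-04-01'},
    {'compiled': '', 'raw': 'Entry with date:1984-04-02'},
]

date_filter = DateFilter()
date_filter.load(entries)  # build the date -> entry id index once
query, entry_indices = date_filter.apply('head dt>="1984-04-01" dt<="1984-04-02" tail', entries)
# query == 'head tail' and entry_indices == {1, 2}, matching the unit tests below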

View file

@ -1,57 +0,0 @@
# Standard Packages
import re
# External Packages
import torch
class ExplicitFilter:
def can_filter(self, raw_query):
"Check if query contains explicit filters"
# Extract explicit query portion with required, blocked words to filter from natural query
required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")])
blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])
return len(required_words) != 0 or len(blocked_words) != 0
def filter(self, raw_query, entries, embeddings, entry_key='raw'):
"Find entries containing required and not blocked words specified in query"
# Separate natural query from explicit required, blocked words filters
query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")])
required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")])
blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])
if len(required_words) == 0 and len(blocked_words) == 0:
return query, entries, embeddings
# convert each entry to a set of words
# split on fullstop, comma, colon, tab, newline or any brackets
entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\t|\n|\:'
entries_by_word_set = [set(word.lower()
for word
in re.split(entry_splitter, entry[entry_key])
if word != "")
for entry in entries]
# track id of entries to exclude
entries_to_exclude = set()
# mark entries that do not contain all required_words for exclusion
if len(required_words) > 0:
for id, words_in_entry in enumerate(entries_by_word_set):
if not required_words.issubset(words_in_entry):
entries_to_exclude.add(id)
# mark entries that contain any blocked_words for exclusion
if len(blocked_words) > 0:
for id, words_in_entry in enumerate(entries_by_word_set):
if words_in_entry.intersection(blocked_words):
entries_to_exclude.add(id)
# delete entries (and their embeddings) marked for exclusion
for id in sorted(list(entries_to_exclude), reverse=True):
del entries[id]
embeddings = torch.cat((embeddings[:id], embeddings[id+1:]))
return query, entries, embeddings

View file

@ -0,0 +1,79 @@
# Standard Packages
import re
import fnmatch
import time
import logging
from collections import defaultdict
# Internal Packages
from src.search_filter.base_filter import BaseFilter
from src.utils.helpers import LRU
logger = logging.getLogger(__name__)
class FileFilter(BaseFilter):
file_filter_regex = r'file:"(.+?)" ?'
def __init__(self, entry_key='file'):
self.entry_key = entry_key
self.file_to_entry_map = defaultdict(set)
self.cache = LRU()
def load(self, entries, *args, **kwargs):
start = time.time()
for id, entry in enumerate(entries):
self.file_to_entry_map[entry[self.entry_key]].add(id)
end = time.time()
logger.debug(f"Created file filter index: {end - start} seconds")
def can_filter(self, raw_query):
return re.search(self.file_filter_regex, raw_query) is not None
def apply(self, raw_query, raw_entries):
# Extract file filters from raw query
start = time.time()
raw_files_to_search = re.findall(self.file_filter_regex, raw_query)
if not raw_files_to_search:
return raw_query, set(range(len(raw_entries)))
# Convert simple file filters with no path separator into glob patterns
# e.g. "file:notes.org" -> "file:*notes.org"
files_to_search = []
for file in sorted(raw_files_to_search):
if '/' not in file and '\\' not in file and '*' not in file:
files_to_search += [f'*{file}']
else:
files_to_search += [file]
end = time.time()
logger.debug(f"Extract files_to_search from query: {end - start} seconds")
# Return item from cache if exists
query = re.sub(self.file_filter_regex, '', raw_query).strip()
cache_key = tuple(files_to_search)
if cache_key in self.cache:
logger.info(f"Return file filter results from cache")
included_entry_indices = self.cache[cache_key]
return query, included_entry_indices
if not self.file_to_entry_map:
self.load(raw_entries, regenerate=False)
# Mark entries whose file path matches any of the files to search for inclusion
start = time.time()
included_entry_indices = set.union(*[self.file_to_entry_map[entry_file]
for entry_file in self.file_to_entry_map.keys()
for search_file in files_to_search
if fnmatch.fnmatch(entry_file, search_file)], set())
if not included_entry_indices:
return query, {}
end = time.time()
logger.debug(f"Mark entries satisfying filter: {end - start} seconds")
# Cache results
self.cache[cache_key] = included_entry_indices
return query, included_entry_indices
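
A usage sketch of the new file filter; the file paths and entry contents are invented, and only the 'file' key matters to this filter:

from src.search_filter.file_filter import FileFilter

entries = [
    {'file': 'tests/data/org/file 1.org', 'raw': 'First entry'},
    {'file': 'tests/data/org/file2.org', 'raw': 'Second entry'},
]

file_filter = FileFilter()
file_filter.load(entries)
query, entry_indices = file_filter.apply('head file:"*1.org" tail', entries)
# query == 'head tail'; only entries whose file path glob-matches *1.org are kept, i.e. {0}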

View file

@ -0,0 +1,96 @@
# Standard Packages
import re
import time
import logging
from collections import defaultdict
# Internal Packages
from src.search_filter.base_filter import BaseFilter
from src.utils.helpers import LRU
logger = logging.getLogger(__name__)
class WordFilter(BaseFilter):
# Filter Regex
required_regex = r'\+"([a-zA-Z0-9_-]+)" ?'
blocked_regex = r'\-"([a-zA-Z0-9_-]+)" ?'
def __init__(self, entry_key='raw'):
self.entry_key = entry_key
self.word_to_entry_index = defaultdict(set)
self.cache = LRU()
def load(self, entries, regenerate=False):
start = time.time()
self.cache = {} # Clear cache on filter (re-)load
entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\<|\>|\t|\n|\:|\;|\?|\!|\(|\)|\&|\^|\$|\@|\%|\+|\=|\/|\\|\||\~|\`|\"|\''
# Create map of words to entries they exist in
for entry_index, entry in enumerate(entries):
for word in re.split(entry_splitter, entry[self.entry_key].lower()):
if word == '':
continue
self.word_to_entry_index[word].add(entry_index)
end = time.time()
logger.debug(f"Created word filter index: {end - start} seconds")
return self.word_to_entry_index
def can_filter(self, raw_query):
"Check if query contains word filters"
required_words = re.findall(self.required_regex, raw_query)
blocked_words = re.findall(self.blocked_regex, raw_query)
return len(required_words) != 0 or len(blocked_words) != 0
def apply(self, raw_query, raw_entries):
"Find entries containing required and not blocked words specified in query"
# Separate natural query from required, blocked words filters
start = time.time()
required_words = set([word.lower() for word in re.findall(self.required_regex, raw_query)])
blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, raw_query)])
query = re.sub(self.blocked_regex, '', re.sub(self.required_regex, '', raw_query)).strip()
end = time.time()
logger.debug(f"Extract required, blocked filters from query: {end - start} seconds")
if len(required_words) == 0 and len(blocked_words) == 0:
return query, set(range(len(raw_entries)))
# Return item from cache if exists
cache_key = tuple(sorted(required_words)), tuple(sorted(blocked_words))
if cache_key in self.cache:
logger.info(f"Return word filter results from cache")
included_entry_indices = self.cache[cache_key]
return query, included_entry_indices
if not self.word_to_entry_index:
self.load(raw_entries, regenerate=False)
start = time.time()
# mark entries that contain all required_words for inclusion
entries_with_all_required_words = set(range(len(raw_entries)))
if len(required_words) > 0:
entries_with_all_required_words = set.intersection(*[self.word_to_entry_index.get(word, set()) for word in required_words])
# mark entries that contain any blocked_words for exclusion
entries_with_any_blocked_words = set()
if len(blocked_words) > 0:
entries_with_any_blocked_words = set.union(*[self.word_to_entry_index.get(word, set()) for word in blocked_words])
end = time.time()
logger.debug(f"Mark entries satisfying filter: {end - start} seconds")
# get entries satisfying inclusion and exclusion filters
included_entry_indices = entries_with_all_required_words - entries_with_any_blocked_words
# Cache results
self.cache[cache_key] = included_entry_indices
return query, included_entry_indices
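
And the equivalent sketch for the word filter, whose required and blocked terms must now be quoted as +"word" and -"word"; the entries are invented for illustration:

from src.search_filter.word_filter import WordFilter

entries = [
    {'raw': 'How to install Khoj on Emacs'},
    {'raw': 'How to clone the git repository'},
]

word_filter = WordFilter()
word_filter.load(entries)
query, entry_indices = word_filter.apply('how to install +"emacs" -"clone"', entries)
# query == 'how to install'; only entry 0 contains "emacs" and lacks "clone", so entry_indices == {0}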

View file

@ -1,9 +1,10 @@
# Standard Packages
import argparse
import glob
import pathlib
import copy
import shutil
import time
import logging
# External Packages
from sentence_transformers import SentenceTransformer, util
@ -18,6 +19,10 @@ from src.utils.config import ImageSearchModel
from src.utils.rawconfig import ImageContentConfig, ImageSearchConfig
# Create Logger
logger = logging.getLogger(__name__)
def initialize_model(search_config: ImageSearchConfig):
# Initialize Model
torch.set_num_threads(4)
@ -37,41 +42,43 @@ def initialize_model(search_config: ImageSearchConfig):
return encoder
def extract_entries(image_directories, verbose=0):
def extract_entries(image_directories):
image_names = []
for image_directory in image_directories:
image_directory = resolve_absolute_path(image_directory, strict=True)
image_names.extend(list(image_directory.glob('*.jpg')))
image_names.extend(list(image_directory.glob('*.jpeg')))
if verbose > 0:
if logger.level >= logging.INFO:
image_directory_names = ', '.join([str(image_directory) for image_directory in image_directories])
print(f'Found {len(image_names)} images in {image_directory_names}')
logger.info(f'Found {len(image_names)} images in {image_directory_names}')
return sorted(image_names)
def compute_embeddings(image_names, encoder, embeddings_file, batch_size=50, use_xmp_metadata=False, regenerate=False, verbose=0):
def compute_embeddings(image_names, encoder, embeddings_file, batch_size=50, use_xmp_metadata=False, regenerate=False):
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
image_embeddings = compute_image_embeddings(image_names, encoder, embeddings_file, batch_size, regenerate, verbose)
image_metadata_embeddings = compute_metadata_embeddings(image_names, encoder, embeddings_file, batch_size, use_xmp_metadata, regenerate, verbose)
image_embeddings = compute_image_embeddings(image_names, encoder, embeddings_file, batch_size, regenerate)
image_metadata_embeddings = compute_metadata_embeddings(image_names, encoder, embeddings_file, batch_size, use_xmp_metadata, regenerate)
return image_embeddings, image_metadata_embeddings
def compute_image_embeddings(image_names, encoder, embeddings_file, batch_size=50, regenerate=False, verbose=0):
image_embeddings = None
def compute_image_embeddings(image_names, encoder, embeddings_file, batch_size=50, regenerate=False):
# Load pre-computed image embeddings from file if exists
if resolve_absolute_path(embeddings_file).exists() and not regenerate:
image_embeddings = torch.load(embeddings_file)
if verbose > 0:
print(f"Loaded pre-computed embeddings from {embeddings_file}")
logger.info(f"Loaded {len(image_embeddings)} image embeddings from {embeddings_file}")
# Else compute the image embeddings from scratch, which can take a while
elif image_embeddings is None:
else:
image_embeddings = []
for index in trange(0, len(image_names), batch_size):
images = [Image.open(image_name) for image_name in image_names[index:index+batch_size]]
images = []
for image_name in image_names[index:index+batch_size]:
image = Image.open(image_name)
# Resize images to max width of 640px for faster processing
image.thumbnail((640, image.height))
images += [image]
image_embeddings += encoder.encode(
images,
convert_to_tensor=True,
@ -82,8 +89,7 @@ def compute_image_embeddings(image_names, encoder, embeddings_file, batch_size=5
# Save computed image embeddings to file
torch.save(image_embeddings, embeddings_file)
if verbose > 0:
print(f"Saved computed embeddings to {embeddings_file}")
logger.info(f"Saved computed embeddings to {embeddings_file}")
return image_embeddings
@ -94,8 +100,7 @@ def compute_metadata_embeddings(image_names, encoder, embeddings_file, batch_siz
# Load pre-computed image metadata embedding file if exists
if use_xmp_metadata and resolve_absolute_path(f"{embeddings_file}_metadata").exists() and not regenerate:
image_metadata_embeddings = torch.load(f"{embeddings_file}_metadata")
if verbose > 0:
print(f"Loaded pre-computed embeddings from {embeddings_file}_metadata")
logger.info(f"Loaded pre-computed embeddings from {embeddings_file}_metadata")
# Else compute the image metadata embeddings from scratch, which can take a while
if use_xmp_metadata and image_metadata_embeddings is None:
@ -108,16 +113,15 @@ def compute_metadata_embeddings(image_names, encoder, embeddings_file, batch_siz
convert_to_tensor=True,
batch_size=min(len(image_metadata), batch_size))
except RuntimeError as e:
print(f"Error encoding metadata for images starting from\n\tindex: {index},\n\timages: {image_names[index:index+batch_size]}\nException: {e}")
logger.error(f"Error encoding metadata for images starting from\n\tindex: {index},\n\timages: {image_names[index:index+batch_size]}\nException: {e}")
continue
torch.save(image_metadata_embeddings, f"{embeddings_file}_metadata")
if verbose > 0:
print(f"Saved computed metadata embeddings to {embeddings_file}_metadata")
logger.info(f"Saved computed metadata embeddings to {embeddings_file}_metadata")
return image_metadata_embeddings
def extract_metadata(image_name, verbose=0):
def extract_metadata(image_name):
with exiftool.ExifTool() as et:
image_metadata = et.get_tags(["XMP:Subject", "XMP:Description"], str(image_name))
image_metadata_subjects = set([subject.split(":")[1] for subject in image_metadata.get("XMP:Subject", "") if ":" in subject])
@ -126,8 +130,7 @@ def extract_metadata(image_name, verbose=0):
if len(image_metadata_subjects) > 0:
image_processed_metadata += ". " + ", ".join(image_metadata_subjects)
if verbose > 2:
print(f"{image_name}:\t{image_processed_metadata}")
logger.debug(f"{image_name}:\t{image_processed_metadata}")
return image_processed_metadata
@ -137,26 +140,34 @@ def query(raw_query, count, model: ImageSearchModel):
if pathlib.Path(raw_query).is_file():
query_imagepath = resolve_absolute_path(pathlib.Path(raw_query), strict=True)
query = copy.deepcopy(Image.open(query_imagepath))
if model.verbose > 0:
print(f"Find Images similar to Image at {query_imagepath}")
query.thumbnail((640, query.height)) # scale down image for faster processing
logger.info(f"Find Images similar to Image at {query_imagepath}")
else:
query = raw_query
if model.verbose > 0:
print(f"Find Images by Text: {query}")
logger.info(f"Find Images by Text: {query}")
# Now we encode the query (which can either be an image or a text string)
start = time.time()
query_embedding = model.image_encoder.encode([query], convert_to_tensor=True, show_progress_bar=False)
end = time.time()
logger.debug(f"Query Encode Time: {end - start:.3f} seconds")
# Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings.
start = time.time()
image_hits = {result['corpus_id']: {'image_score': result['score'], 'score': result['score']}
for result
in util.semantic_search(query_embedding, model.image_embeddings, top_k=count)[0]}
end = time.time()
logger.debug(f"Search Time: {end - start:.3f} seconds")
# Compute top_k ranked images based on cosine-similarity b/w query and all image metadata embeddings.
if model.image_metadata_embeddings:
start = time.time()
metadata_hits = {result['corpus_id']: result['score']
for result
in util.semantic_search(query_embedding, model.image_metadata_embeddings, top_k=count)[0]}
end = time.time()
logger.debug(f"Metadata Search Time: {end - start:.3f} seconds")
# Sum metadata, image scores of the highest ranked images
for corpus_id, score in metadata_hits.items():
@ -219,7 +230,7 @@ def collate_results(hits, image_names, output_directory, image_files_url, count=
return results
def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenerate: bool, verbose: bool=False) -> ImageSearchModel:
def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenerate: bool) -> ImageSearchModel:
# Initialize Model
encoder = initialize_model(search_config)
@ -227,9 +238,13 @@ def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenera
absolute_image_files, filtered_image_files = set(), set()
if config.input_directories:
image_directories = [resolve_absolute_path(directory, strict=True) for directory in config.input_directories]
absolute_image_files = set(extract_entries(image_directories, verbose))
absolute_image_files = set(extract_entries(image_directories))
if config.input_filter:
filtered_image_files = set(glob.glob(get_absolute_path(config.input_filter)))
filtered_image_files = {
filtered_file
for input_filter in config.input_filter
for filtered_file in glob.glob(get_absolute_path(input_filter))
}
all_image_files = sorted(list(absolute_image_files | filtered_image_files))
@ -241,38 +256,9 @@ def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenera
embeddings_file,
batch_size=config.batch_size,
regenerate=regenerate,
use_xmp_metadata=config.use_xmp_metadata,
verbose=verbose)
use_xmp_metadata=config.use_xmp_metadata)
return ImageSearchModel(all_image_files,
image_embeddings,
image_metadata_embeddings,
encoder,
verbose)
if __name__ == '__main__':
# Setup Argument Parser
parser = argparse.ArgumentParser(description="Semantic Search on Images")
parser.add_argument('--image-directory', '-i', required=True, type=pathlib.Path, help="Image directory to query")
parser.add_argument('--embeddings-file', '-e', default='image_embeddings.pt', type=pathlib.Path, help="File to save/load model embeddings to/from. Default: ./embeddings.pt")
parser.add_argument('--regenerate', action='store_true', default=False, help="Regenerate embeddings of Images in Image Directory . Default: false")
parser.add_argument('--results-count', '-n', default=5, type=int, help="Number of results to render. Default: 5")
parser.add_argument('--interactive', action='store_true', default=False, help="Interactive mode allows user to run queries on the model. Default: true")
parser.add_argument('--verbose', action='count', default=0, help="Show verbose conversion logs. Default: 0")
args = parser.parse_args()
image_names, image_embeddings, image_metadata_embeddings, model = setup(args.image_directory, args.embeddings_file, regenerate=args.regenerate)
# Run User Queries on Entries in Interactive Mode
while args.interactive:
# get query from user
user_query = input("Enter your query: ")
if user_query == "exit":
exit(0)
# query images
hits = query(user_query, image_embeddings, image_metadata_embeddings, model, args.results_count, args.verbose)
# render results
render_results(hits, image_names, args.image_directory, count=args.results_count)
encoder)
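
For context, a rough sketch of how the reworked image search pieces fit together now that verbose flags are gone and input-filter accepts a list. The directories, model paths and query text are assumptions borrowed from the test fixtures below:

from src.search_type import image_search
from src.utils.rawconfig import ImageContentConfig, ImageSearchConfig

search_config = ImageSearchConfig(
    encoder='sentence-transformers/clip-ViT-B-32',
    model_directory='~/.khoj/search/image/')
content_config = ImageContentConfig(
    input_directories=['tests/data/images'],
    embeddings_file='tests/data/embeddings/image_embeddings.pt',
    batch_size=1,
    use_xmp_metadata=False)

# Index the images, then search them by text (or by passing a path to a similar image)
model = image_search.setup(content_config, search_config, regenerate=False)
hits = image_search.query('sunset by the beach', count=3, model=model)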

View file

@ -1,21 +1,23 @@
# Standard Packages
import argparse
import pathlib
from copy import deepcopy
import logging
import time
# External Packages
import torch
from sentence_transformers import SentenceTransformer, CrossEncoder, util
from src.search_filter.base_filter import BaseFilter
# Internal Packages
from src.utils import state
from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_model
from src.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model
from src.utils.config import TextSearchModel
from src.utils.rawconfig import TextSearchConfig, TextContentConfig
from src.utils.jsonl import load_jsonl
logger = logging.getLogger(__name__)
def initialize_model(search_config: TextSearchConfig):
"Initialize model for semantic search on text"
torch.set_num_threads(4)
@ -46,73 +48,85 @@ def initialize_model(search_config: TextSearchConfig):
return bi_encoder, cross_encoder, top_k
def extract_entries(jsonl_file, verbose=0):
def extract_entries(jsonl_file):
"Load entries from compressed jsonl"
return [{'compiled': f'{entry["compiled"]}', 'raw': f'{entry["raw"]}'}
for entry
in load_jsonl(jsonl_file, verbose=verbose)]
return load_jsonl(jsonl_file)
def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, verbose=0):
def compute_embeddings(entries_with_ids, bi_encoder, embeddings_file, regenerate=False):
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
# Load pre-computed embeddings from file if exists
new_entries = []
# Load pre-computed embeddings from file if exists and update them if required
if embeddings_file.exists() and not regenerate:
corpus_embeddings = torch.load(get_absolute_path(embeddings_file), map_location=state.device)
if verbose > 0:
print(f"Loaded embeddings from {embeddings_file}")
logger.info(f"Loaded embeddings from {embeddings_file}")
else: # Else compute the corpus_embeddings from scratch, which can take a while
corpus_embeddings = bi_encoder.encode([entry['compiled'] for entry in entries], convert_to_tensor=True, device=state.device, show_progress_bar=True)
# Encode any new entries in the corpus and update corpus embeddings
new_entries = [entry['compiled'] for id, entry in entries_with_ids if id is None]
if new_entries:
new_embeddings = bi_encoder.encode(new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True)
existing_entry_ids = [id for id, _ in entries_with_ids if id is not None]
existing_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(existing_entry_ids)) if existing_entry_ids else torch.Tensor()
corpus_embeddings = torch.cat([existing_embeddings, new_embeddings], dim=0)
# Else compute the corpus embeddings from scratch
else:
new_entries = [entry['compiled'] for _, entry in entries_with_ids]
corpus_embeddings = bi_encoder.encode(new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True)
# Save regenerated or updated embeddings to file
if new_entries:
corpus_embeddings = util.normalize_embeddings(corpus_embeddings)
torch.save(corpus_embeddings, embeddings_file)
if verbose > 0:
print(f"Computed embeddings and saved them to {embeddings_file}")
logger.info(f"Computed embeddings and saved them to {embeddings_file}")
return corpus_embeddings
def query(raw_query: str, model: TextSearchModel, rank_results=False, filters: list = [], verbose=0):
def query(raw_query: str, model: TextSearchModel, rank_results=False):
"Search for entries that answer the query"
query = raw_query
# Use deep copy of original embeddings, entries to filter if query contains filters
start = time.time()
filters_in_query = [filter for filter in filters if filter.can_filter(query)]
if filters_in_query:
corpus_embeddings = deepcopy(model.corpus_embeddings)
entries = deepcopy(model.entries)
else:
corpus_embeddings = model.corpus_embeddings
entries = model.entries
end = time.time()
if verbose > 1:
print(f"Copy Time: {end - start:.3f} seconds")
query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings
# Filter query, entries and embeddings before semantic search
start = time.time()
start_filter = time.time()
included_entry_indices = set(range(len(entries)))
filters_in_query = [filter for filter in model.filters if filter.can_filter(query)]
for filter in filters_in_query:
query, entries, corpus_embeddings = filter.filter(query, entries, corpus_embeddings)
query, included_entry_indices_by_filter = filter.apply(query, entries)
included_entry_indices.intersection_update(included_entry_indices_by_filter)
# Get entries (and associated embeddings) satisfying all filters
if not included_entry_indices:
return [], []
else:
start = time.time()
entries = [entries[id] for id in included_entry_indices]
corpus_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(list(included_entry_indices)))
end = time.time()
if verbose > 1:
print(f"Filter Time: {end - start:.3f} seconds")
logger.debug(f"Keep entries satisfying all filters: {end - start} seconds")
end_filter = time.time()
logger.debug(f"Total Filter Time: {end_filter - start_filter:.3f} seconds")
if entries is None or len(entries) == 0:
return [], []
# If query only had filters it'll be empty now. So short-circuit and return results.
if query.strip() == "":
hits = [{"corpus_id": id, "score": 1.0} for id, _ in enumerate(entries)]
return hits, entries
# Encode the query using the bi-encoder
start = time.time()
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
question_embedding = util.normalize_embeddings(question_embedding)
end = time.time()
if verbose > 1:
print(f"Query Encode Time: {end - start:.3f} seconds on device: {state.device}")
logger.debug(f"Query Encode Time: {end - start:.3f} seconds on device: {state.device}")
# Find relevant entries for the query
start = time.time()
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score)[0]
end = time.time()
if verbose > 1:
print(f"Search Time: {end - start:.3f} seconds on device: {state.device}")
logger.debug(f"Search Time: {end - start:.3f} seconds on device: {state.device}")
# Score all retrieved entries using the cross-encoder
if rank_results:
@ -120,8 +134,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False, filters: l
cross_inp = [[query, entries[hit['corpus_id']]['compiled']] for hit in hits]
cross_scores = model.cross_encoder.predict(cross_inp)
end = time.time()
if verbose > 1:
print(f"Cross-Encoder Predict Time: {end - start:.3f} seconds on device: {state.device}")
logger.debug(f"Cross-Encoder Predict Time: {end - start:.3f} seconds on device: {state.device}")
# Store cross-encoder scores in results dictionary for ranking
for idx in range(len(cross_scores)):
@ -133,8 +146,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False, filters: l
if rank_results:
hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score
end = time.time()
if verbose > 1:
print(f"Rank Time: {end - start:.3f} seconds on device: {state.device}")
logger.debug(f"Rank Time: {end - start:.3f} seconds on device: {state.device}")
return hits, entries
@ -167,50 +179,26 @@ def collate_results(hits, entries, count=5):
in hits[0:count]]
def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool, verbose: bool=False) -> TextSearchModel:
def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool, filters: list[BaseFilter] = []) -> TextSearchModel:
# Initialize Model
bi_encoder, cross_encoder, top_k = initialize_model(search_config)
# Map notes in text files to (compressed) JSONL formatted file
config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl)
if not config.compressed_jsonl.exists() or regenerate:
text_to_jsonl(config.input_files, config.input_filter, config.compressed_jsonl, verbose)
previous_entries = extract_entries(config.compressed_jsonl) if config.compressed_jsonl.exists() and not regenerate else None
entries_with_indices = text_to_jsonl(config, previous_entries)
# Extract Entries
entries = extract_entries(config.compressed_jsonl, verbose)
# Extract Updated Entries
entries = extract_entries(config.compressed_jsonl)
if is_none_or_empty(entries):
raise ValueError(f"No valid entries found in specified files: {config.input_files} or {config.input_filter}")
top_k = min(len(entries), top_k) # top_k hits can't be more than the total entries in corpus
# Compute or Load Embeddings
config.embeddings_file = resolve_absolute_path(config.embeddings_file)
corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate, verbose=verbose)
corpus_embeddings = compute_embeddings(entries_with_indices, bi_encoder, config.embeddings_file, regenerate=regenerate)
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k, verbose=verbose)
for filter in filters:
filter.load(entries, regenerate=regenerate)
if __name__ == '__main__':
# Setup Argument Parser
parser = argparse.ArgumentParser(description="Map Text files into (compressed) JSONL format")
parser.add_argument('--input-files', '-i', nargs='*', help="List of Text files to process")
parser.add_argument('--input-filter', type=str, default=None, help="Regex filter for Text files to process")
parser.add_argument('--compressed-jsonl', '-j', type=pathlib.Path, default=pathlib.Path("text.jsonl.gz"), help="Compressed JSONL to compute embeddings from")
parser.add_argument('--embeddings', '-e', type=pathlib.Path, default=pathlib.Path("text_embeddings.pt"), help="File to save/load model embeddings to/from")
parser.add_argument('--regenerate', action='store_true', default=False, help="Regenerate embeddings from text files. Default: false")
parser.add_argument('--results-count', '-n', default=5, type=int, help="Number of results to render. Default: 5")
parser.add_argument('--interactive', action='store_true', default=False, help="Interactive mode allows user to run queries on the model. Default: true")
parser.add_argument('--verbose', action='count', default=0, help="Show verbose conversion logs. Default: 0")
args = parser.parse_args()
entries, corpus_embeddings, bi_encoder, cross_encoder, top_k = setup(args.input_files, args.input_filter, args.compressed_jsonl, args.embeddings, args.regenerate, args.verbose)
# Run User Queries on Entries in Interactive Mode
while args.interactive:
# get query from user
user_query = input("Enter your query: ")
if user_query == "exit":
exit(0)
# query notes
hits = query(user_query, corpus_embeddings, entries, bi_encoder, cross_encoder, top_k)
# render results
render_results(hits, entries, count=args.results_count)
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k)
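
Tying the text search changes together, a minimal end-to-end sketch with the new filters parameter; the config values are assumptions mirroring the test fixtures below:

from src.processor.org_mode.org_to_jsonl import org_to_jsonl
from src.search_filter.date_filter import DateFilter
from src.search_filter.file_filter import FileFilter
from src.search_filter.word_filter import WordFilter
from src.search_type import text_search
from src.utils.rawconfig import TextContentConfig, TextSearchConfig

content_config = TextContentConfig(
    input_filter=['tests/data/org/*.org'],
    compressed_jsonl='tests/data/embeddings/notes.jsonl.gz',
    embeddings_file='tests/data/embeddings/note_embeddings.pt')
search_config = TextSearchConfig(
    encoder='sentence-transformers/multi-qa-MiniLM-L6-cos-v1',
    cross_encoder='cross-encoder/ms-marco-MiniLM-L-6-v2',
    model_directory='~/.khoj/search/asymmetric/')

# Index the org notes with the filters pre-loaded, then run a filtered, ranked query
filters = [DateFilter(), WordFilter(), FileFilter()]
model = text_search.setup(org_to_jsonl, content_config, search_config, regenerate=False, filters=filters)
hits, entries = text_search.query('how to install khoj +"emacs"', model, rank_results=True)
results = text_search.collate_results(hits, entries, count=5)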

View file

@ -1,6 +1,7 @@
# Standard Packages
import argparse
import pathlib
from importlib.metadata import version
# Internal Packages
from src.utils.helpers import resolve_absolute_path
@ -16,9 +17,15 @@ def cli(args=None):
parser.add_argument('--host', type=str, default='127.0.0.1', help="Host address of the server. Default: 127.0.0.1")
parser.add_argument('--port', '-p', type=int, default=8000, help="Port of the server. Default: 8000")
parser.add_argument('--socket', type=pathlib.Path, help="Path to UNIX socket for server. Use to run server behind reverse proxy. Default: /tmp/uvicorn.sock")
parser.add_argument('--version', '-V', action='store_true', help="Print the installed Khoj version and exit")
args = parser.parse_args(args)
if args.version:
# Show version of khoj installed and exit
print(version('khoj-assistant'))
exit(0)
# Normalize config_file path to absolute path
args.config_file = resolve_absolute_path(args.config_file)
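
The new flag can also be exercised programmatically, e.g. (a trivial sketch):

from src.utils.cli import cli

cli(['--version'])  # prints the installed khoj-assistant version and exits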

View file

@ -5,6 +5,7 @@ from pathlib import Path
# Internal Packages
from src.utils.rawconfig import ConversationProcessorConfig
from src.search_filter.base_filter import BaseFilter
class SearchType(str, Enum):
@ -21,23 +22,22 @@ class ProcessorType(str, Enum):
class TextSearchModel():
def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, top_k, verbose):
def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, filters: list[BaseFilter], top_k):
self.entries = entries
self.corpus_embeddings = corpus_embeddings
self.bi_encoder = bi_encoder
self.cross_encoder = cross_encoder
self.filters = filters
self.top_k = top_k
self.verbose = verbose
class ImageSearchModel():
def __init__(self, image_names, image_embeddings, image_metadata_embeddings, image_encoder, verbose):
def __init__(self, image_names, image_embeddings, image_metadata_embeddings, image_encoder):
self.image_encoder = image_encoder
self.image_names = image_names
self.image_embeddings = image_embeddings
self.image_metadata_embeddings = image_metadata_embeddings
self.image_encoder = image_encoder
self.verbose = verbose
@dataclass
@ -51,12 +51,11 @@ class SearchModels():
class ConversationProcessorConfigModel():
def __init__(self, processor_config: ConversationProcessorConfig, verbose: bool):
def __init__(self, processor_config: ConversationProcessorConfig):
self.openai_api_key = processor_config.openai_api_key
self.conversation_logfile = Path(processor_config.conversation_logfile)
self.chat_session = ''
self.meta_log = []
self.verbose = verbose
self.meta_log: dict = {}
@dataclass

View file

@ -2,7 +2,7 @@ from pathlib import Path
app_root_directory = Path(__file__).parent.parent.parent
web_directory = app_root_directory / 'src/interface/web/'
empty_escape_sequences = r'\n|\r\t '
empty_escape_sequences = '\n|\r|\t| '
# default app config to use
default_config = {
@ -11,7 +11,8 @@ default_config = {
'input-files': None,
'input-filter': None,
'compressed-jsonl': '~/.khoj/content/org/org.jsonl.gz',
'embeddings-file': '~/.khoj/content/org/org_embeddings.pt'
'embeddings-file': '~/.khoj/content/org/org_embeddings.pt',
'index_heading_entries': False
},
'markdown': {
'input-files': None,

View file

@ -1,7 +1,11 @@
# Standard Packages
import pathlib
from pathlib import Path
import sys
import time
import hashlib
from os.path import join
from collections import OrderedDict
from typing import Optional, Union
def is_none_or_empty(item):
@ -12,12 +16,12 @@ def to_snake_case_from_dash(item: str):
return item.replace('_', '-')
def get_absolute_path(filepath):
return str(pathlib.Path(filepath).expanduser().absolute())
def get_absolute_path(filepath: Union[str, Path]) -> str:
return str(Path(filepath).expanduser().absolute())
def resolve_absolute_path(filepath, strict=False):
return pathlib.Path(filepath).expanduser().absolute().resolve(strict=strict)
def resolve_absolute_path(filepath: Union[str, Optional[Path]], strict=False) -> Path:
return Path(filepath).expanduser().absolute().resolve(strict=strict)
def get_from_dict(dictionary, *args):
@ -61,3 +65,56 @@ def load_model(model_name, model_dir, model_type, device:str=None):
def is_pyinstaller_app():
"Returns true if the app is running from Native GUI created by PyInstaller"
return getattr(sys, 'frozen', False) and hasattr(sys, '_MEIPASS')
class LRU(OrderedDict):
def __init__(self, *args, capacity=128, **kwargs):
self.capacity = capacity
super().__init__(*args, **kwargs)
def __getitem__(self, key):
value = super().__getitem__(key)
self.move_to_end(key)
return value
def __setitem__(self, key, value):
super().__setitem__(key, value)
if len(self) > self.capacity:
oldest = next(iter(self))
del self[oldest]
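
A quick sketch of the eviction behaviour this LRU provides for the new query and filter caches; the capacity and keys are arbitrary:

from src.utils.helpers import LRU

cache = LRU(capacity=2)
cache['a'] = 1
cache['b'] = 2
_ = cache['a']   # touching 'a' marks it as most recently used
cache['c'] = 3   # exceeds capacity, so the least recently used key 'b' is evicted
assert 'b' not in cache and 'a' in cache and 'c' in cache
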
def mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=None):
# Hash all current and previous entries to identify new entries
start = time.time()
current_entry_hashes = list(map(lambda e: hashlib.md5(bytes(e[key], encoding='utf-8')).hexdigest(), current_entries))
previous_entry_hashes = list(map(lambda e: hashlib.md5(bytes(e[key], encoding='utf-8')).hexdigest(), previous_entries))
end = time.time()
logger.debug(f"Hash previous, current entries: {end - start} seconds")
start = time.time()
hash_to_current_entries = dict(zip(current_entry_hashes, current_entries))
hash_to_previous_entries = dict(zip(previous_entry_hashes, previous_entries))
# All entries that did not exist in the previous set are to be added
new_entry_hashes = set(current_entry_hashes) - set(previous_entry_hashes)
# All entries that exist in both current and previous sets are kept
existing_entry_hashes = set(current_entry_hashes) & set(previous_entry_hashes)
# Mark new entries with no ids for later embeddings generation
new_entries = [
(None, hash_to_current_entries[entry_hash])
for entry_hash in new_entry_hashes
]
# Set id of existing entries to their previous ids to reuse their existing encoded embeddings
existing_entries = [
(previous_entry_hashes.index(entry_hash), hash_to_previous_entries[entry_hash])
for entry_hash in existing_entry_hashes
]
existing_entries_sorted = sorted(existing_entries, key=lambda e: e[0])
entries_with_ids = existing_entries_sorted + new_entries
end = time.time()
logger.debug(f"Identify, Mark, Combine new, existing entries: {end - start} seconds")
return entries_with_ids
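
To illustrate the incremental indexing bookkeeping above, a small sketch with invented entries: unchanged entries keep their previous index so their embeddings can be reused, new entries are tagged with None, and deleted entries simply drop out:

import logging

from src.utils.helpers import mark_entries_for_update

logger = logging.getLogger(__name__)

previous_entries = [{'compiled': 'entry A', 'raw': 'entry A'}, {'compiled': 'entry B', 'raw': 'entry B'}]
current_entries  = [{'compiled': 'entry B', 'raw': 'entry B'}, {'compiled': 'entry C', 'raw': 'entry C'}]

# Entries are matched by the md5 hash of their 'compiled' text
entries_with_ids = mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
# entries_with_ids == [(1, {'compiled': 'entry B', ...}), (None, {'compiled': 'entry C', ...})]
# 'entry A' was deleted upstream, so it no longer appears in the combined list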

View file

@ -1,13 +1,17 @@
# Standard Packages
import json
import gzip
import logging
# Internal Packages
from src.utils.constants import empty_escape_sequences
from src.utils.helpers import get_absolute_path
def load_jsonl(input_path, verbose=0):
logger = logging.getLogger(__name__)
def load_jsonl(input_path):
"Read List of JSON objects from JSON line file"
# Initialize Variables
data = []
@ -27,13 +31,12 @@ def load_jsonl(input_path, verbose=0):
jsonl_file.close()
# Log JSONL entries loaded
if verbose > 0:
print(f'Loaded {len(data)} records from {input_path}')
logger.info(f'Loaded {len(data)} records from {input_path}')
return data
def dump_jsonl(jsonl_data, output_path, verbose=0):
def dump_jsonl(jsonl_data, output_path):
"Write List of JSON objects to JSON line file"
# Create output directory, if it doesn't exist
output_path.parent.mkdir(parents=True, exist_ok=True)
@ -41,16 +44,14 @@ def dump_jsonl(jsonl_data, output_path, verbose=0):
with open(output_path, 'w', encoding='utf-8') as f:
f.write(jsonl_data)
if verbose > 0:
print(f'Wrote {len(jsonl_data)} lines to jsonl at {output_path}')
logger.info(f'Wrote jsonl data to {output_path}')
def compress_jsonl_data(jsonl_data, output_path, verbose=0):
def compress_jsonl_data(jsonl_data, output_path):
# Create output directory, if it doesn't exist
output_path.parent.mkdir(parents=True, exist_ok=True)
with gzip.open(output_path, 'wt') as gzip_file:
gzip_file.write(jsonl_data)
if verbose > 0:
print(f'Wrote {len(jsonl_data)} lines to gzip compressed jsonl at {output_path}')
logger.info(f'Wrote jsonl data to gzip compressed jsonl at {output_path}')

View file

@ -6,7 +6,7 @@ from typing import List, Optional
from pydantic import BaseModel, validator
# Internal Packages
from src.utils.helpers import to_snake_case_from_dash
from src.utils.helpers import to_snake_case_from_dash, is_none_or_empty
class ConfigBase(BaseModel):
class Config:
@ -15,26 +15,27 @@ class ConfigBase(BaseModel):
class TextContentConfig(ConfigBase):
input_files: Optional[List[Path]]
input_filter: Optional[str]
input_filter: Optional[List[str]]
compressed_jsonl: Path
embeddings_file: Path
index_heading_entries: Optional[bool] = False
@validator('input_filter')
def input_filter_or_files_required(cls, input_filter, values, **kwargs):
if input_filter is None and ('input_files' not in values or values["input_files"] is None):
if is_none_or_empty(input_filter) and ('input_files' not in values or values["input_files"] is None):
raise ValueError("Either input_filter or input_files required in all content-type.<text_search> section of Khoj config file")
return input_filter
class ImageContentConfig(ConfigBase):
input_directories: Optional[List[Path]]
input_filter: Optional[str]
input_filter: Optional[List[str]]
embeddings_file: Path
use_xmp_metadata: bool
batch_size: int
@validator('input_filter')
def input_filter_or_directories_required(cls, input_filter, values, **kwargs):
if input_filter is None and ('input_directories' not in values or values["input_directories"] is None):
if is_none_or_empty(input_filter) and ('input_directories' not in values or values["input_directories"] is None):
raise ValueError("Either input_filter or input_directories required in all content-type.image section of Khoj config file")
return input_filter

View file

@ -1,22 +1,25 @@
# Standard Packages
from packaging import version
# External Packages
import torch
from pathlib import Path
# Internal Packages
from src.utils.config import SearchModels, ProcessorConfigModel
from src.utils.helpers import LRU
from src.utils.rawconfig import FullConfig
# Application Global State
config = FullConfig()
model = SearchModels()
processor_config = ProcessorConfigModel()
config_file: Path = ""
config_file: Path = None
verbose: int = 0
host: str = None
port: int = None
cli_args = None
cli_args: list[str] = None
query_cache = LRU()
if torch.cuda.is_available():
# Use CUDA GPU

View file

@ -5,12 +5,13 @@ from pathlib import Path
import yaml
# Internal Packages
from src.utils.helpers import get_absolute_path, resolve_absolute_path
from src.utils.rawconfig import FullConfig
# Do not emit tags when dumping to YAML
yaml.emitter.Emitter.process_tag = lambda self, *args, **kwargs: None
def save_config_to_file(yaml_config: dict, yaml_config_file: Path):
"Write config to YML file"
# Create output directory, if it doesn't exist

View file

@ -1,78 +1,65 @@
# Standard Packages
# External Packages
import pytest
# Internal Packages
from src.search_type import image_search, text_search
from src.utils.config import SearchType
from src.utils.helpers import resolve_absolute_path
from src.utils.rawconfig import ContentConfig, TextContentConfig, ImageContentConfig, SearchConfig, TextSearchConfig, ImageSearchConfig
from src.processor.org_mode.org_to_jsonl import org_to_jsonl
from src.utils import state
from src.search_filter.date_filter import DateFilter
from src.search_filter.word_filter import WordFilter
from src.search_filter.file_filter import FileFilter
@pytest.fixture(scope='session')
def search_config(tmp_path_factory):
model_dir = tmp_path_factory.mktemp('data')
def search_config() -> SearchConfig:
model_dir = resolve_absolute_path('~/.khoj/search')
model_dir.mkdir(parents=True, exist_ok=True)
search_config = SearchConfig()
search_config.symmetric = TextSearchConfig(
encoder = "sentence-transformers/all-MiniLM-L6-v2",
cross_encoder = "cross-encoder/ms-marco-MiniLM-L-6-v2",
model_directory = model_dir
model_directory = model_dir / 'symmetric/'
)
search_config.asymmetric = TextSearchConfig(
encoder = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
cross_encoder = "cross-encoder/ms-marco-MiniLM-L-6-v2",
model_directory = model_dir
model_directory = model_dir / 'asymmetric/'
)
search_config.image = ImageSearchConfig(
encoder = "sentence-transformers/clip-ViT-B-32",
model_directory = model_dir
model_directory = model_dir / 'image/'
)
return search_config
@pytest.fixture(scope='session')
def model_dir(search_config):
model_dir = search_config.asymmetric.model_directory
def content_config(tmp_path_factory, search_config: SearchConfig):
content_dir = tmp_path_factory.mktemp('content')
# Generate Image Embeddings from Test Images
content_config = ContentConfig()
content_config.image = ImageContentConfig(
input_directories = ['tests/data/images'],
embeddings_file = model_dir.joinpath('image_embeddings.pt'),
batch_size = 10,
embeddings_file = content_dir.joinpath('image_embeddings.pt'),
batch_size = 1,
use_xmp_metadata = False)
image_search.setup(content_config.image, search_config.image, regenerate=False, verbose=True)
image_search.setup(content_config.image, search_config.image, regenerate=False)
# Generate Notes Embeddings from Test Notes
content_config.org = TextContentConfig(
input_files = None,
input_filter = 'tests/data/org/*.org',
compressed_jsonl = model_dir.joinpath('notes.jsonl.gz'),
embeddings_file = model_dir.joinpath('note_embeddings.pt'))
input_filter = ['tests/data/org/*.org'],
compressed_jsonl = content_dir.joinpath('notes.jsonl.gz'),
embeddings_file = content_dir.joinpath('note_embeddings.pt'))
text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, verbose=True)
return model_dir
@pytest.fixture(scope='session')
def content_config(model_dir):
content_config = ContentConfig()
content_config.org = TextContentConfig(
input_files = None,
input_filter = 'tests/data/org/*.org',
compressed_jsonl = model_dir.joinpath('notes.jsonl.gz'),
embeddings_file = model_dir.joinpath('note_embeddings.pt'))
content_config.image = ImageContentConfig(
input_directories = ['tests/data/images'],
embeddings_file = model_dir.joinpath('image_embeddings.pt'),
batch_size = 1,
use_xmp_metadata = False)
filters = [DateFilter(), WordFilter(), FileFilter()]
text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters)
return content_config

View file

@ -1,9 +1,10 @@
content-type:
org:
input-files: [ "~/first_from_config.org", "~/second_from_config.org" ]
input-filter: "*.org"
input-filter: ["*.org", "~/notes/*.org"]
compressed-jsonl: ".notes.json.gz"
embeddings-file: ".note_embeddings.pt"
index-heading-entries: true
search-type:
asymmetric:

View file

@ -0,0 +1,112 @@
# Standard Packages
import json
# Internal Packages
from src.processor.ledger.beancount_to_jsonl import extract_beancount_transactions, convert_transactions_to_maps, convert_transaction_maps_to_jsonl, get_beancount_files
def test_no_transactions_in_file(tmp_path):
"Handle file with no transactions."
# Arrange
entry = f'''
- Bullet point 1
- Bullet point 2
'''
beancount_file = create_file(tmp_path, entry)
# Act
# Extract Entries from specified Beancount files
entry_nodes, file_to_entries = extract_beancount_transactions(beancount_files=[beancount_file])
# Process Each Entry from All Beancount Files
jsonl_string = convert_transaction_maps_to_jsonl(convert_transactions_to_maps(entry_nodes, file_to_entries))
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
# Assert
assert len(jsonl_data) == 0
def test_single_beancount_transaction_to_jsonl(tmp_path):
"Convert transaction from single file to jsonl."
# Arrange
entry = f'''
1984-04-01 * "Payee" "Narration"
Expenses:Test:Test 1.00 KES
Assets:Test:Test -1.00 KES
'''
beancount_file = create_file(tmp_path, entry)
# Act
# Extract Entries from specified Beancount files
entries, entry_to_file_map = extract_beancount_transactions(beancount_files=[beancount_file])
# Process Each Entry from All Beancount Files
jsonl_string = convert_transaction_maps_to_jsonl(convert_transactions_to_maps(entries, entry_to_file_map))
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
# Assert
assert len(jsonl_data) == 1
def test_multiple_transactions_to_jsonl(tmp_path):
"Convert multiple transactions from single file to jsonl."
# Arrange
entry = f'''
1984-04-01 * "Payee" "Narration"
Expenses:Test:Test 1.00 KES
Assets:Test:Test -1.00 KES
\t\r
1984-04-01 * "Payee" "Narration"
Expenses:Test:Test 1.00 KES
Assets:Test:Test -1.00 KES
'''
beancount_file = create_file(tmp_path, entry)
# Act
# Extract Entries from specified Beancount files
entries, entry_to_file_map = extract_beancount_transactions(beancount_files=[beancount_file])
# Process Each Entry from All Beancount Files
jsonl_string = convert_transaction_maps_to_jsonl(convert_transactions_to_maps(entries, entry_to_file_map))
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
# Assert
assert len(jsonl_data) == 2
def test_get_beancount_files(tmp_path):
"Ensure Beancount files specified via input-filter, input-files extracted"
# Arrange
# Include via input-filter globs
group1_file1 = create_file(tmp_path, filename="group1-file1.bean")
group1_file2 = create_file(tmp_path, filename="group1-file2.bean")
group2_file1 = create_file(tmp_path, filename="group2-file1.beancount")
group2_file2 = create_file(tmp_path, filename="group2-file2.beancount")
# Include via input-file field
file1 = create_file(tmp_path, filename="ledger.bean")
# Not included by any filter
create_file(tmp_path, filename="not-included-ledger.bean")
create_file(tmp_path, filename="not-included-text.txt")
expected_files = sorted(map(str, [group1_file1, group1_file2, group2_file1, group2_file2, file1]))
# Setup input-files, input-filters
input_files = [tmp_path / 'ledger.bean']
input_filter = [tmp_path / 'group1*.bean', tmp_path / 'group2*.beancount']
# Act
extracted_org_files = get_beancount_files(input_files, input_filter)
# Assert
assert len(extracted_org_files) == 5
assert extracted_org_files == expected_files
# Helper Functions
def create_file(tmp_path, entry=None, filename="ledger.beancount"):
beancount_file = tmp_path / filename
beancount_file.touch()
if entry:
beancount_file.write_text(entry)
return beancount_file

View file

@ -2,9 +2,6 @@
from pathlib import Path
from random import random
# External Modules
import pytest
# Internal Packages
from src.utils.cli import cli
from src.utils.helpers import resolve_absolute_path

View file

@ -1,17 +1,20 @@
# Standard Modules
from io import BytesIO
from PIL import Image
from urllib.parse import quote
# External Packages
from fastapi.testclient import TestClient
import pytest
# Internal Packages
from src.main import app
from src.utils.state import model, config
from src.search_type import text_search, image_search
from src.utils.rawconfig import ContentConfig, SearchConfig
from src.processor.org_mode import org_to_jsonl
from src.processor.org_mode.org_to_jsonl import org_to_jsonl
from src.search_filter.word_filter import WordFilter
from src.search_filter.file_filter import FileFilter
# Arrange
@ -22,7 +25,7 @@ client = TestClient(app)
# ----------------------------------------------------------------------------------------------------
def test_search_with_invalid_content_type():
# Arrange
user_query = "How to call Khoj from Emacs?"
user_query = quote("How to call Khoj from Emacs?")
# Act
response = client.get(f"/search?q={user_query}&t=invalid_content_type")
@ -116,7 +119,7 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig
def test_notes_search(content_config: ContentConfig, search_config: SearchConfig):
# Arrange
model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
user_query = "How to git install application?"
user_query = quote("How to git install application?")
# Act
response = client.get(f"/search?q={user_query}&n=1&t=org&r=true")
@ -129,17 +132,35 @@ def test_notes_search(content_config: ContentConfig, search_config: SearchConfig
# ----------------------------------------------------------------------------------------------------
def test_notes_search_with_include_filter(content_config: ContentConfig, search_config: SearchConfig):
def test_notes_search_with_only_filters(content_config: ContentConfig, search_config: SearchConfig):
# Arrange
model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
user_query = "How to git install application? +Emacs"
filters = [WordFilter(), FileFilter()]
model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters)
user_query = quote('+"Emacs" file:"*.org"')
# Act
response = client.get(f"/search?q={user_query}&n=1&t=org")
# Assert
assert response.status_code == 200
# assert actual_data contains explicitly included word "Emacs"
# assert actual_data contains the word "Emacs"
search_result = response.json()[0]["entry"]
assert "Emacs" in search_result
# ----------------------------------------------------------------------------------------------------
def test_notes_search_with_include_filter(content_config: ContentConfig, search_config: SearchConfig):
# Arrange
filters = [WordFilter()]
model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters)
user_query = quote('How to git install application? +"Emacs"')
# Act
response = client.get(f"/search?q={user_query}&n=1&t=org")
# Assert
assert response.status_code == 200
# assert actual_data contains the word "Emacs"
search_result = response.json()[0]["entry"]
assert "Emacs" in search_result
@ -147,14 +168,15 @@ def test_notes_search_with_include_filter(content_config: ContentConfig, search_
# ----------------------------------------------------------------------------------------------------
def test_notes_search_with_exclude_filter(content_config: ContentConfig, search_config: SearchConfig):
# Arrange
model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
user_query = "How to git install application? -clone"
filters = [WordFilter()]
model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters)
user_query = quote('How to git install application? -"clone"')
# Act
response = client.get(f"/search?q={user_query}&n=1&t=org")
# Assert
assert response.status_code == 200
# assert actual_data does not contains explicitly excluded word "Emacs"
# assert actual_data does not contain the excluded word "clone"
search_result = response.json()[0]["entry"]
assert "clone" not in search_result

View file

@ -18,40 +18,34 @@ def test_date_filter():
{'compiled': '', 'raw': 'Entry with date:1984-04-02'}]
q_with_no_date_filter = 'head tail'
ret_query, ret_entries, ret_emb = DateFilter().filter(q_with_no_date_filter, entries.copy(), embeddings)
ret_query, entry_indices = DateFilter().apply(q_with_no_date_filter, entries)
assert ret_query == 'head tail'
assert len(ret_emb) == 3
assert ret_entries == entries
assert entry_indices == {0, 1, 2}
q_with_dtrange_non_overlapping_at_boundary = 'head dt>"1984-04-01" dt<"1984-04-02" tail'
ret_query, ret_entries, ret_emb = DateFilter().filter(q_with_dtrange_non_overlapping_at_boundary, entries.copy(), embeddings)
ret_query, entry_indices = DateFilter().apply(q_with_dtrange_non_overlapping_at_boundary, entries)
assert ret_query == 'head tail'
assert len(ret_emb) == 0
assert ret_entries == []
assert entry_indices == set()
query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<"1984-04-03" tail'
ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
assert ret_query == 'head tail'
assert ret_entries == [entries[2]]
assert len(ret_emb) == 1
assert entry_indices == {2}
query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<"1984-04-02" tail'
ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
assert ret_query == 'head tail'
assert ret_entries == [entries[1]]
assert len(ret_emb) == 1
assert entry_indices == {1}
query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<="1984-04-02" tail'
ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
assert ret_query == 'head tail'
assert ret_entries == [entries[2]]
assert len(ret_emb) == 1
assert entry_indices == {2}
query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<="1984-04-02" tail'
ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
assert ret_query == 'head tail'
assert ret_entries == [entries[1], entries[2]]
assert len(ret_emb) == 2
assert entry_indices == {1, 2}
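
For readers skimming the assertions above, here is a minimal sketch of a date filter with the same apply(query, entries) -> (cleaned query, entry indices) contract. It only illustrates the dt-range semantics the test encodes; the class and helper names are made up, and the actual DateFilter may parse dates and cache results differently.

```python
import operator
import re
from datetime import datetime


class SketchDateFilter:
    "Illustrative date filter: strip dt constraints from query, keep entries whose date satisfies them."
    query_regex = re.compile(r'dt(>=|<=|>|<)"(\d{4}-\d{2}-\d{2})"\s*')
    entry_regex = re.compile(r'date:(\d{4}-\d{2}-\d{2})')
    operators = {'>': operator.gt, '>=': operator.ge, '<': operator.lt, '<=': operator.le}

    def apply(self, query, entries):
        constraints = self.query_regex.findall(query)
        cleaned_query = self.query_regex.sub('', query).strip()
        if not constraints:
            return cleaned_query, set(range(len(entries)))
        entry_indices = set()
        for index, entry in enumerate(entries):
            match = self.entry_regex.search(entry['raw'])
            if not match:
                continue  # entries without a date never satisfy a date constraint
            entry_date = datetime.strptime(match.group(1), '%Y-%m-%d').date()
            if all(self.operators[op](entry_date, datetime.strptime(bound, '%Y-%m-%d').date())
                   for op, bound in constraints):
                entry_indices.add(index)
        return cleaned_query, entry_indices
```
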
def test_extract_date_range():

112
tests/test_file_filter.py Normal file
View file

@@ -0,0 +1,112 @@
# External Packages
import torch
# Application Packages
from src.search_filter.file_filter import FileFilter
def test_no_file_filter():
# Arrange
file_filter = FileFilter()
embeddings, entries = arrange_content()
q_with_no_filter = 'head tail'
# Act
can_filter = file_filter.can_filter(q_with_no_filter)
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
# Assert
assert can_filter == False
assert ret_query == 'head tail'
assert entry_indices == {0, 1, 2, 3}
def test_file_filter_with_non_existent_file():
# Arrange
file_filter = FileFilter()
embeddings, entries = arrange_content()
q_with_no_filter = 'head file:"nonexistent.org" tail'
# Act
can_filter = file_filter.can_filter(q_with_no_filter)
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
# Assert
assert can_filter == True
assert ret_query == 'head tail'
assert entry_indices == set()
def test_single_file_filter():
# Arrange
file_filter = FileFilter()
embeddings, entries = arrange_content()
q_with_no_filter = 'head file:"file 1.org" tail'
# Act
can_filter = file_filter.can_filter(q_with_no_filter)
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
# Assert
assert can_filter == True
assert ret_query == 'head tail'
assert entry_indices == {0, 2}
def test_file_filter_with_partial_match():
# Arrange
file_filter = FileFilter()
embeddings, entries = arrange_content()
q_with_no_filter = 'head file:"1.org" tail'
# Act
can_filter = file_filter.can_filter(q_with_no_filter)
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
# Assert
assert can_filter == True
assert ret_query == 'head tail'
assert entry_indices == {0, 2}
def test_file_filter_with_regex_match():
# Arrange
file_filter = FileFilter()
embeddings, entries = arrange_content()
q_with_no_filter = 'head file:"*.org" tail'
# Act
can_filter = file_filter.can_filter(q_with_no_filter)
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
# Assert
assert can_filter == True
assert ret_query == 'head tail'
assert entry_indices == {0, 1, 2, 3}
def test_multiple_file_filter():
# Arrange
file_filter = FileFilter()
embeddings, entries = arrange_content()
q_with_no_filter = 'head tail file:"file 1.org" file:"file2.org"'
# Act
can_filter = file_filter.can_filter(q_with_no_filter)
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
# Assert
assert can_filter == True
assert ret_query == 'head tail'
assert entry_indices == {0, 1, 2, 3}
def arrange_content():
embeddings = torch.randn(4, 10)
entries = [
{'compiled': '', 'raw': 'First Entry', 'file': 'file 1.org'},
{'compiled': '', 'raw': 'Second Entry', 'file': 'file2.org'},
{'compiled': '', 'raw': 'Third Entry', 'file': 'file 1.org'},
{'compiled': '', 'raw': 'Fourth Entry', 'file': 'file2.org'}]
return embeddings, entries
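
The expected index sets above follow from treating each file:"..." term as a pattern matched against the entry's 'file' field, with '*' acting as a wildcard and plain terms matching as substrings. A minimal sketch with that behaviour is shown below; it is illustrative only (the names are made up), and the actual FileFilter may differ, for example in how it compiles patterns or caches results.

```python
import re


class SketchFileFilter:
    "Illustrative file filter with the (query, entries) -> (cleaned query, entry indices) contract."
    file_filter_regex = re.compile(r'file:"(.+?)" ?')

    def can_filter(self, query):
        return self.file_filter_regex.search(query) is not None

    def apply(self, query, entries):
        files_to_search = self.file_filter_regex.findall(query)
        cleaned_query = self.file_filter_regex.sub('', query).strip()
        if not files_to_search:
            return cleaned_query, set(range(len(entries)))
        # Treat each file term as a glob-ish pattern searched against the entry's source file
        patterns = [re.compile(term.replace('.', r'\.').replace('*', '.*')) for term in files_to_search]
        entry_indices = {index for index, entry in enumerate(entries)
                         if any(pattern.search(entry['file']) for pattern in patterns)}
        return cleaned_query, entry_indices
```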

View file

@@ -28,3 +28,18 @@ def test_merge_dicts():
# do not override existing key in priority_dict with default dict
assert helpers.merge_dicts(priority_dict={'a': 1}, default_dict={'a': 2}) == {'a': 1}
def test_lru_cache():
# Test initializing cache
cache = helpers.LRU({'a': 1, 'b': 2}, capacity=2)
assert cache == {'a': 1, 'b': 2}
# Test capacity overflow
cache['c'] = 3
assert cache == {'b': 2, 'c': 3}
# Test delete least recently used item from LRU cache on capacity overflow
cache['b'] # accessing 'b' makes it the most recently used item
cache['d'] = 4 # so 'c' is deleted from the cache instead of 'b'
assert cache == {'b': 2, 'd': 4}
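
For reference, the behaviour this test pins down (a capacity-bounded mapping that evicts the least recently used key, with reads refreshing recency) can be expressed on top of OrderedDict in a few lines. This is only a sketch of those semantics, not necessarily how the LRU class in src.utils.helpers is written.

```python
from collections import OrderedDict


class SketchLRU(OrderedDict):
    "Mapping that evicts its least recently used key once capacity is exceeded."
    def __init__(self, *args, capacity=128, **kwargs):
        self.capacity = capacity
        super().__init__(*args, **kwargs)

    def __getitem__(self, key):
        value = super().__getitem__(key)
        self.move_to_end(key)  # reading a key marks it as most recently used
        return value

    def __setitem__(self, key, value):
        super().__setitem__(key, value)
        self.move_to_end(key)
        if len(self) > self.capacity:
            self.popitem(last=False)  # drop the least recently used entry
```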

View file

@@ -48,8 +48,13 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig
image_files_url='/static/images',
count=1)
actual_image = Image.open(output_directory.joinpath(Path(results[0]["entry"]).name))
actual_image_path = output_directory.joinpath(Path(results[0]["entry"]).name)
actual_image = Image.open(actual_image_path)
expected_image = Image.open(content_config.image.input_directories[0].joinpath(expected_image_name))
# Assert
assert expected_image == actual_image
# Cleanup
# Delete the image files copied to results directory
actual_image_path.unlink()

View file

@@ -0,0 +1,109 @@
# Standard Packages
import json
# Internal Packages
from src.processor.markdown.markdown_to_jsonl import extract_markdown_entries, convert_markdown_maps_to_jsonl, convert_markdown_entries_to_maps, get_markdown_files
def test_markdown_file_with_no_headings_to_jsonl(tmp_path):
"Convert files with no heading to jsonl."
# Arrange
entry = f'''
- Bullet point 1
- Bullet point 2
'''
markdownfile = create_file(tmp_path, entry)
# Act
# Extract Entries from specified Markdown files
entry_nodes, file_to_entries = extract_markdown_entries(markdown_files=[markdownfile])
# Process Each Entry from All Notes Files
jsonl_string = convert_markdown_maps_to_jsonl(convert_markdown_entries_to_maps(entry_nodes, file_to_entries))
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
# Assert
assert len(jsonl_data) == 1
def test_single_markdown_entry_to_jsonl(tmp_path):
"Convert markdown entry from single file to jsonl."
# Arrange
entry = f'''### Heading
\t\r
Body Line 1
'''
markdownfile = create_file(tmp_path, entry)
# Act
# Extract Entries from specified Markdown files
entries, entry_to_file_map = extract_markdown_entries(markdown_files=[markdownfile])
# Process Each Entry from All Notes Files
jsonl_string = convert_markdown_maps_to_jsonl(convert_markdown_entries_to_maps(entries, entry_to_file_map))
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
# Assert
assert len(jsonl_data) == 1
def test_multiple_markdown_entries_to_jsonl(tmp_path):
"Convert multiple markdown entries from single file to jsonl."
# Arrange
entry = f'''
### Heading 1
\t\r
Heading 1 Body Line 1
### Heading 2
\t\r
Heading 2 Body Line 2
'''
markdownfile = create_file(tmp_path, entry)
# Act
# Extract Entries from specified Markdown files
entries, entry_to_file_map = extract_markdown_entries(markdown_files=[markdownfile])
# Process Each Entry from All Notes Files
jsonl_string = convert_markdown_maps_to_jsonl(convert_markdown_entries_to_maps(entries, entry_to_file_map))
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
# Assert
assert len(jsonl_data) == 2
def test_get_markdown_files(tmp_path):
"Ensure Markdown files specified via input-filter, input-files extracted"
# Arrange
# Include via input-filter globs
group1_file1 = create_file(tmp_path, filename="group1-file1.md")
group1_file2 = create_file(tmp_path, filename="group1-file2.md")
group2_file1 = create_file(tmp_path, filename="group2-file1.markdown")
group2_file2 = create_file(tmp_path, filename="group2-file2.markdown")
# Include via input-file field
file1 = create_file(tmp_path, filename="notes.md")
# Not included by any filter
create_file(tmp_path, filename="not-included-markdown.md")
create_file(tmp_path, filename="not-included-text.txt")
expected_files = sorted(map(str, [group1_file1, group1_file2, group2_file1, group2_file2, file1]))
# Setup input-files, input-filters
input_files = [tmp_path / 'notes.md']
input_filter = [tmp_path / 'group1*.md', tmp_path / 'group2*.markdown']
# Act
extracted_markdown_files = get_markdown_files(input_files, input_filter)
# Assert
assert len(extracted_markdown_files) == 5
assert extracted_markdown_files == expected_files
# Helper Functions
def create_file(tmp_path, entry=None, filename="test.md"):
markdown_file = tmp_path / filename
markdown_file.touch()
if entry:
markdown_file.write_text(entry)
return markdown_file

View file

@@ -1,32 +1,37 @@
# Standard Packages
import json
from posixpath import split
# Internal Packages
from src.processor.org_mode.org_to_jsonl import convert_org_entries_to_jsonl, extract_org_entries
from src.processor.org_mode.org_to_jsonl import convert_org_entries_to_jsonl, convert_org_nodes_to_entries, extract_org_entries, get_org_files
from src.utils.helpers import is_none_or_empty
def test_entry_with_empty_body_line_to_jsonl(tmp_path):
'''Ensure entries with empty body are ignored.
def test_configure_heading_entry_to_jsonl(tmp_path):
'''Ensure entries with empty body are ignored, unless explicitly configured to index heading entries.
Property drawers not considered Body. Ignore control characters for evaluating if Body empty.'''
# Arrange
entry = f'''*** Heading
:PROPERTIES:
:ID: 42-42-42
:END:
\t\r\n
\t \r
'''
orgfile = create_file(tmp_path, entry)
for index_heading_entries in [True, False]:
# Act
# Extract Entries from specified Org files
entries = extract_org_entries(org_files=[orgfile])
# Process Each Entry from All Notes Files
jsonl_data = convert_org_entries_to_jsonl(entries)
# Extract entries into jsonl from specified Org files
jsonl_string = convert_org_entries_to_jsonl(convert_org_nodes_to_entries(
*extract_org_entries(org_files=[orgfile]),
index_heading_entries=index_heading_entries))
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
# Assert
if index_heading_entries:
# Entry with empty body indexed when index_heading_entries set to True
assert len(jsonl_data) == 1
else:
# Entry with empty body ignored when index_heading_entries set to False
assert is_none_or_empty(jsonl_data)
@@ -37,15 +42,38 @@ def test_entry_with_body_to_jsonl(tmp_path):
:PROPERTIES:
:ID: 42-42-42
:END:
\t\r\nBody Line 1\n
\t\r
Body Line 1
'''
orgfile = create_file(tmp_path, entry)
# Act
# Extract Entries from specified Org files
entries = extract_org_entries(org_files=[orgfile])
entries, entry_to_file_map = extract_org_entries(org_files=[orgfile])
# Process Each Entry from All Notes Files
jsonl_string = convert_org_entries_to_jsonl(convert_org_nodes_to_entries(entries, entry_to_file_map))
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
# Assert
assert len(jsonl_data) == 1
def test_file_with_no_headings_to_jsonl(tmp_path):
"Ensure files with no heading, only body text are loaded."
# Arrange
entry = f'''
- Bullet point 1
- Bullet point 2
'''
orgfile = create_file(tmp_path, entry)
# Act
# Extract Entries from specified Org files
entry_nodes, file_to_entries = extract_org_entries(org_files=[orgfile])
# Process Each Entry from All Notes Files
entries = convert_org_nodes_to_entries(entry_nodes, file_to_entries)
jsonl_string = convert_org_entries_to_jsonl(entries)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
@@ -53,10 +81,38 @@ def test_entry_with_body_to_jsonl(tmp_path):
assert len(jsonl_data) == 1
def test_get_org_files(tmp_path):
"Ensure Org files specified via input-filter, input-files extracted"
# Arrange
# Include via input-filter globs
group1_file1 = create_file(tmp_path, filename="group1-file1.org")
group1_file2 = create_file(tmp_path, filename="group1-file2.org")
group2_file1 = create_file(tmp_path, filename="group2-file1.org")
group2_file2 = create_file(tmp_path, filename="group2-file2.org")
# Include via input-file field
orgfile1 = create_file(tmp_path, filename="orgfile1.org")
# Not included by any filter
create_file(tmp_path, filename="orgfile2.org")
create_file(tmp_path, filename="text1.txt")
expected_files = sorted(map(str, [group1_file1, group1_file2, group2_file1, group2_file2, orgfile1]))
# Setup input-files, input-filters
input_files = [tmp_path / 'orgfile1.org']
input_filter = [tmp_path / 'group1*.org', tmp_path / 'group2*.org']
# Act
extracted_org_files = get_org_files(input_files, input_filter)
# Assert
assert len(extracted_org_files) == 5
assert extracted_org_files == expected_files
# Helper Functions
def create_file(tmp_path, entry, filename="test.org"):
org_file = tmp_path / f"notes/{filename}"
org_file.parent.mkdir()
def create_file(tmp_path, entry=None, filename="test.org"):
org_file = tmp_path / filename
org_file.touch()
if entry:
org_file.write_text(entry)
return org_file

View file

@@ -1,13 +1,33 @@
# Standard Packages
import datetime
from os.path import relpath
from pathlib import Path
# Internal Packages
from src.processor.org_mode import orgnode
# Test
# ----------------------------------------------------------------------------------------------------
def test_parse_entry_with_no_headings(tmp_path):
"Test parsing of entry with minimal fields"
# Arrange
entry = f'''Body Line 1'''
orgfile = create_file(tmp_path, entry)
# Act
entries = orgnode.makelist(orgfile)
# Assert
assert len(entries) == 1
assert entries[0].heading == f'{orgfile}'
assert entries[0].tags == list()
assert entries[0].body == "Body Line 1"
assert entries[0].priority == ""
assert entries[0].Property("ID") == ""
assert entries[0].closed == ""
assert entries[0].scheduled == ""
assert entries[0].deadline == ""
# ----------------------------------------------------------------------------------------------------
def test_parse_minimal_entry(tmp_path):
"Test parsing of entry with minimal fields"
@@ -22,14 +42,14 @@ Body Line 1'''
# Assert
assert len(entries) == 1
assert entries[0].Heading() == "Heading"
assert entries[0].Tags() == set()
assert entries[0].Body() == "Body Line 1"
assert entries[0].Priority() == ""
assert entries[0].heading == "Heading"
assert entries[0].tags == list()
assert entries[0].body == "Body Line 1"
assert entries[0].priority == ""
assert entries[0].Property("ID") == ""
assert entries[0].Closed() == ""
assert entries[0].Scheduled() == ""
assert entries[0].Deadline() == ""
assert entries[0].closed == ""
assert entries[0].scheduled == ""
assert entries[0].deadline == ""
# ----------------------------------------------------------------------------------------------------
@@ -55,16 +75,44 @@ Body Line 2'''
# Assert
assert len(entries) == 1
assert entries[0].Heading() == "Heading"
assert entries[0].Todo() == "DONE"
assert entries[0].Tags() == {"Tag1", "TAG2", "tag3"}
assert entries[0].Body() == "- Clocked Log 1\nBody Line 1\nBody Line 2"
assert entries[0].Priority() == "A"
assert entries[0].heading == "Heading"
assert entries[0].todo == "DONE"
assert entries[0].tags == ["Tag1", "TAG2", "tag3"]
assert entries[0].body == "- Clocked Log 1\nBody Line 1\nBody Line 2"
assert entries[0].priority == "A"
assert entries[0].Property("ID") == "id:123-456-789-4234-1231"
assert entries[0].Closed() == datetime.date(1984,4,1)
assert entries[0].Scheduled() == datetime.date(1984,4,1)
assert entries[0].Deadline() == datetime.date(1984,4,1)
assert entries[0].Logbook() == [(datetime.datetime(1984,4,1,9,0,0), datetime.datetime(1984,4,1,12,0,0))]
assert entries[0].closed == datetime.date(1984,4,1)
assert entries[0].scheduled == datetime.date(1984,4,1)
assert entries[0].deadline == datetime.date(1984,4,1)
assert entries[0].logbook == [(datetime.datetime(1984,4,1,9,0,0), datetime.datetime(1984,4,1,12,0,0))]
# ----------------------------------------------------------------------------------------------------
def test_render_entry_with_property_drawer_and_empty_body(tmp_path):
"Render heading entry with property drawer"
# Arrange
entry_to_render = f'''
*** [#A] Heading1 :tag1:
:PROPERTIES:
:ID: 111-111-111-1111-1111
:END:
\t\r \n
'''
orgfile = create_file(tmp_path, entry_to_render)
expected_entry = f'''*** [#A] Heading1 :tag1:
:PROPERTIES:
:LINE: file:{orgfile}::2
:ID: id:111-111-111-1111-1111
:SOURCE: [[file:{orgfile}::*Heading1]]
:END:
'''
# Act
parsed_entries = orgnode.makelist(orgfile)
# Assert
assert f'{parsed_entries[0]}' == expected_entry
# ----------------------------------------------------------------------------------------------------
@@ -81,18 +129,17 @@ Body Line 1
Body Line 2
'''
orgfile = create_file(tmp_path, entry)
normalized_orgfile = f'~/{relpath(orgfile, start=Path.home())}'
# Act
entries = orgnode.makelist(orgfile)
# Assert
# SOURCE link rendered with Heading
assert f':SOURCE: [[file:{normalized_orgfile}::*{entries[0].Heading()}]]' in f'{entries[0]}'
assert f':SOURCE: [[file:{orgfile}::*{entries[0].heading}]]' in f'{entries[0]}'
# ID link rendered with ID
assert f':ID: id:123-456-789-4234-1231' in f'{entries[0]}'
# LINE link rendered with line number
assert f':LINE: file:{normalized_orgfile}::2' in f'{entries[0]}'
assert f':LINE: file:{orgfile}::2' in f'{entries[0]}'
# ----------------------------------------------------------------------------------------------------
@@ -113,10 +160,9 @@ Body Line 1'''
# Assert
assert len(entries) == 1
# parsed heading from entry
assert entries[0].Heading() == "Heading[1]"
assert entries[0].heading == "Heading[1]"
# ensure SOURCE link has square brackets in filename, heading escaped in rendered entries
normalized_orgfile = f'~/{relpath(orgfile, start=Path.home())}'
escaped_orgfile = f'{normalized_orgfile}'.replace("[1]", "\\[1\\]")
escaped_orgfile = f'{orgfile}'.replace("[1]", "\\[1\\]")
assert f':SOURCE: [[file:{escaped_orgfile}::*Heading\[1\]' in f'{entries[0]}'
@@ -156,16 +202,86 @@ Body 2
# Assert
assert len(entries) == 2
for index, entry in enumerate(entries):
assert entry.Heading() == f"Heading{index+1}"
assert entry.Todo() == "FAILED" if index == 0 else "CANCELLED"
assert entry.Tags() == {f"tag{index+1}"}
assert entry.Body() == f"- Clocked Log {index+1}\nBody {index+1}\n\n"
assert entry.Priority() == "A"
assert entry.heading == f"Heading{index+1}"
assert entry.todo == "FAILED" if index == 0 else "CANCELLED"
assert entry.tags == [f"tag{index+1}"]
assert entry.body == f"- Clocked Log {index+1}\nBody {index+1}\n\n"
assert entry.priority == "A"
assert entry.Property("ID") == f"id:123-456-789-4234-000{index+1}"
assert entry.Closed() == datetime.date(1984,4,index+1)
assert entry.Scheduled() == datetime.date(1984,4,index+1)
assert entry.Deadline() == datetime.date(1984,4,index+1)
assert entry.Logbook() == [(datetime.datetime(1984,4,index+1,9,0,0), datetime.datetime(1984,4,index+1,12,0,0))]
assert entry.closed == datetime.date(1984,4,index+1)
assert entry.scheduled == datetime.date(1984,4,index+1)
assert entry.deadline == datetime.date(1984,4,index+1)
assert entry.logbook == [(datetime.datetime(1984,4,index+1,9,0,0), datetime.datetime(1984,4,index+1,12,0,0))]
# ----------------------------------------------------------------------------------------------------
def test_parse_entry_with_empty_title(tmp_path):
"Test parsing of entry with minimal fields"
# Arrange
entry = f'''#+TITLE:
Body Line 1'''
orgfile = create_file(tmp_path, entry)
# Act
entries = orgnode.makelist(orgfile)
# Assert
assert len(entries) == 1
assert entries[0].heading == f'{orgfile}'
assert entries[0].tags == list()
assert entries[0].body == "Body Line 1"
assert entries[0].priority == ""
assert entries[0].Property("ID") == ""
assert entries[0].closed == ""
assert entries[0].scheduled == ""
assert entries[0].deadline == ""
# ----------------------------------------------------------------------------------------------------
def test_parse_entry_with_title_and_no_headings(tmp_path):
"Test parsing of entry with minimal fields"
# Arrange
entry = f'''#+TITLE: test
Body Line 1'''
orgfile = create_file(tmp_path, entry)
# Act
entries = orgnode.makelist(orgfile)
# Assert
assert len(entries) == 1
assert entries[0].heading == 'test'
assert entries[0].tags == list()
assert entries[0].body == "Body Line 1"
assert entries[0].priority == ""
assert entries[0].Property("ID") == ""
assert entries[0].closed == ""
assert entries[0].scheduled == ""
assert entries[0].deadline == ""
# ----------------------------------------------------------------------------------------------------
def test_parse_entry_with_multiple_titles_and_no_headings(tmp_path):
"Test parsing of entry with minimal fields"
# Arrange
entry = f'''#+TITLE: title1
Body Line 1
#+TITLE: title2 '''
orgfile = create_file(tmp_path, entry)
# Act
entries = orgnode.makelist(orgfile)
# Assert
assert len(entries) == 1
assert entries[0].heading == 'title1 title2'
assert entries[0].tags == list()
assert entries[0].body == "Body Line 1\n"
assert entries[0].priority == ""
assert entries[0].Property("ID") == ""
assert entries[0].closed == ""
assert entries[0].scheduled == ""
assert entries[0].deadline == ""
# Helper Functions

View file

@@ -1,6 +1,10 @@
# System Packages
from copy import deepcopy
from pathlib import Path
# External Packages
import pytest
# Internal Packages
from src.utils.state import model
from src.search_type import text_search
@ -9,6 +13,39 @@ from src.processor.org_mode.org_to_jsonl import org_to_jsonl
# Test
# ----------------------------------------------------------------------------------------------------
def test_asymmetric_setup_with_missing_file_raises_error(content_config: ContentConfig, search_config: SearchConfig):
# Arrange
file_to_index = Path(content_config.org.input_filter[0]).parent / "new_file_to_index.org"
new_org_content_config = deepcopy(content_config.org)
new_org_content_config.input_files = [f'{file_to_index}']
new_org_content_config.input_filter = None
# Act
# Generate notes embeddings during asymmetric setup
with pytest.raises(FileNotFoundError):
text_search.setup(org_to_jsonl, new_org_content_config, search_config.asymmetric, regenerate=True)
# ----------------------------------------------------------------------------------------------------
def test_asymmetric_setup_with_empty_file_raises_error(content_config: ContentConfig, search_config: SearchConfig):
# Arrange
file_to_index = Path(content_config.org.input_filter[0]).parent / "new_file_to_index.org"
file_to_index.touch()
new_org_content_config = deepcopy(content_config.org)
new_org_content_config.input_files = [f'{file_to_index}']
new_org_content_config.input_filter = None
# Act
# Generate notes embeddings during asymmetric setup
with pytest.raises(ValueError, match=r'^No valid entries found'):
text_search.setup(org_to_jsonl, new_org_content_config, search_config.asymmetric, regenerate=True)
# Cleanup
# delete created test file
file_to_index.unlink()
# ----------------------------------------------------------------------------------------------------
def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchConfig):
# Act
@@ -23,7 +60,7 @@ def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchCo
# ----------------------------------------------------------------------------------------------------
def test_asymmetric_search(content_config: ContentConfig, search_config: SearchConfig):
# Arrange
model.notes_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
model.notes_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=True)
query = "How to git install application?"
# Act
@@ -51,7 +88,7 @@ def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchC
assert len(initial_notes_model.entries) == 10
assert len(initial_notes_model.corpus_embeddings) == 10
file_to_add_on_reload = Path(content_config.org.input_filter).parent / "reload.org"
file_to_add_on_reload = Path(content_config.org.input_filter[0]).parent / "reload.org"
content_config.org.input_files = [f'{file_to_add_on_reload}']
# append Org-Mode Entry to first Org Input File in Config
@@ -77,3 +114,32 @@ def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchC
# delete reload test file added
content_config.org.input_files = []
file_to_add_on_reload.unlink()
# ----------------------------------------------------------------------------------------------------
def test_incremental_update(content_config: ContentConfig, search_config: SearchConfig):
# Arrange
initial_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=True)
assert len(initial_notes_model.entries) == 10
assert len(initial_notes_model.corpus_embeddings) == 10
file_to_add_on_update = Path(content_config.org.input_filter[0]).parent / "update.org"
content_config.org.input_files = [f'{file_to_add_on_update}']
# append Org-Mode Entry to first Org Input File in Config
with open(file_to_add_on_update, "w") as f:
f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n")
# Act
# update embeddings, entries with the newly added note
initial_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
# verify new entry added in updated embeddings, entries
assert len(initial_notes_model.entries) == 11
assert len(initial_notes_model.corpus_embeddings) == 11
# Cleanup
# delete file added for update testing
content_config.org.input_files = []
file_to_add_on_update.unlink()
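
The assertions above only require that a second setup() call with regenerate=False embeds the delta instead of re-encoding the whole corpus. One plausible shape for that update step, purely as an illustration (the function and parameter names are assumptions, and the bi-encoder is assumed to expose a sentence-transformers style encode()):

```python
import torch


def update_embeddings(old_entries, old_embeddings, new_entries, bi_encoder):
    "Illustration: re-embed only entries that were not present in the previous index."
    # Deduplicate on the compiled entry text; anything unseen needs a fresh embedding
    existing = {entry['compiled'] for entry in old_entries}
    added = [entry for entry in new_entries if entry['compiled'] not in existing]
    if not added:
        return old_entries, old_embeddings
    added_embeddings = bi_encoder.encode([entry['compiled'] for entry in added], convert_to_tensor=True)
    return old_entries + added, torch.cat([old_embeddings, added_embeddings])
```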

77
tests/test_word_filter.py Normal file
View file

@@ -0,0 +1,77 @@
# Application Packages
from src.search_filter.word_filter import WordFilter
from src.utils.config import SearchType
def test_no_word_filter():
# Arrange
word_filter = WordFilter()
entries = arrange_content()
q_with_no_filter = 'head tail'
# Act
can_filter = word_filter.can_filter(q_with_no_filter)
ret_query, entry_indices = word_filter.apply(q_with_no_filter, entries)
# Assert
assert can_filter == False
assert ret_query == 'head tail'
assert entry_indices == {0, 1, 2, 3}
def test_word_exclude_filter():
# Arrange
word_filter = WordFilter()
entries = arrange_content()
q_with_exclude_filter = 'head -"exclude_word" tail'
# Act
can_filter = word_filter.can_filter(q_with_exclude_filter)
ret_query, entry_indices = word_filter.apply(q_with_exclude_filter, entries)
# Assert
assert can_filter == True
assert ret_query == 'head tail'
assert entry_indices == {0, 2}
def test_word_include_filter():
# Arrange
word_filter = WordFilter()
entries = arrange_content()
query_with_include_filter = 'head +"include_word" tail'
# Act
can_filter = word_filter.can_filter(query_with_include_filter)
ret_query, entry_indices = word_filter.apply(query_with_include_filter, entries)
# Assert
assert can_filter == True
assert ret_query == 'head tail'
assert entry_indices == {2, 3}
def test_word_include_and_exclude_filter():
# Arrange
word_filter = WordFilter()
entries = arrange_content()
query_with_include_and_exclude_filter = 'head +"include_word" -"exclude_word" tail'
# Act
can_filter = word_filter.can_filter(query_with_include_and_exclude_filter)
ret_query, entry_indices = word_filter.apply(query_with_include_and_exclude_filter, entries)
# Assert
assert can_filter == True
assert ret_query == 'head tail'
assert entry_indices == {2}
def arrange_content():
entries = [
{'compiled': '', 'raw': 'Minimal Entry'},
{'compiled': '', 'raw': 'Entry with exclude_word'},
{'compiled': '', 'raw': 'Entry with include_word'},
{'compiled': '', 'raw': 'Entry with include_word and exclude_word'}]
return entries
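
To close out, the required/blocked word semantics these assertions encode fit in a few lines. The sketch below mirrors the (query, entries) -> (cleaned query, entry indices) contract used throughout the filter tests; it is illustrative only (names are made up), and the real WordFilter may differ, for example by precomputing a word-to-entry index.

```python
import re


class SketchWordFilter:
    'Illustrative word filter: +"word" must be present in an entry, -"word" must be absent.'
    required_regex = re.compile(r'\+"(\S+)" ?')
    blocked_regex = re.compile(r'-"(\S+)" ?')

    def can_filter(self, query):
        return bool(self.required_regex.search(query) or self.blocked_regex.search(query))

    def apply(self, query, entries):
        required_words = {word.lower() for word in self.required_regex.findall(query)}
        blocked_words = {word.lower() for word in self.blocked_regex.findall(query)}
        cleaned_query = self.blocked_regex.sub('', self.required_regex.sub('', query)).strip()
        entry_indices = set()
        for index, entry in enumerate(entries):
            entry_words = set(re.split(r'\W+', entry['raw'].lower()))
            if required_words and not required_words.issubset(entry_words):
                continue  # missing a word that was explicitly required
            if blocked_words and blocked_words & entry_words:
                continue  # contains a word that was explicitly excluded
            entry_indices.add(index)
        return cleaned_query, entry_indices
```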