diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index ad56e8ac..b8c2cf43 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -11,6 +11,10 @@ on: - Dockerfile - docker-compose.yml - .github/workflows/build.yml + workflow_dispatch: + +env: + DOCKER_IMAGE_TAG: ${{ github.ref == 'refs/heads/master' && 'latest' || github.ref }} jobs: build: @@ -36,6 +40,6 @@ jobs: context: . file: Dockerfile push: true - tags: ghcr.io/${{ github.repository }}:latest + tags: ghcr.io/${{ github.repository }}:${{ env.DOCKER_IMAGE_TAG }} build-args: | PORT=8000 \ No newline at end of file diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index a664f12e..9d642f2b 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -1,15 +1,15 @@ name: release on: + push: + tags: + - v* workflow_dispatch: inputs: version: description: 'Version Number' required: true type: string - push: - tags: - - v* jobs: publish: diff --git a/.gitignore b/.gitignore index 9d33a849..a2e89c26 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,4 @@ src/.data /dist/ /khoj_assistant.egg-info/ /config/khoj*.yml +.pytest_cache diff --git a/Dockerfile b/Dockerfile index 88febebe..2607bf07 100644 --- a/Dockerfile +++ b/Dockerfile @@ -4,7 +4,7 @@ LABEL org.opencontainers.image.source https://github.com/debanjum/khoj # Install System Dependencies RUN apt-get update -y && \ - apt-get -y install libimage-exiftool-perl + apt-get -y install libimage-exiftool-perl python3-pyqt5 # Copy Application to Container COPY . /app diff --git a/Readme.md b/Readme.md index 76182560..c47e07a5 100644 --- a/Readme.md +++ b/Readme.md @@ -2,7 +2,6 @@ [![build](https://github.com/debanjum/khoj/actions/workflows/build.yml/badge.svg)](https://github.com/debanjum/khoj/actions/workflows/build.yml) [![test](https://github.com/debanjum/khoj/actions/workflows/test.yml/badge.svg)](https://github.com/debanjum/khoj/actions/workflows/test.yml) [![publish](https://github.com/debanjum/khoj/actions/workflows/publish.yml/badge.svg)](https://github.com/debanjum/khoj/actions/workflows/publish.yml) -[![release](https://github.com/debanjum/khoj/actions/workflows/release.yml/badge.svg)](https://github.com/debanjum/khoj/actions/workflows/release.yml) *A natural language search engine for your personal notes, transactions and images* @@ -107,7 +106,7 @@ pip install --upgrade khoj-assistant ## Troubleshoot - Symptom: Errors out complaining about Tensors mismatch, null etc - - Mitigation: Disable `image` search on the desktop GUI + - Mitigation: Disable `image` search using the desktop GUI - Symptom: Errors out with \"Killed\" in error message in Docker - Fix: Increase RAM available to Docker Containers in Docker Settings - Refer: [StackOverflow Solution](https://stackoverflow.com/a/50770267), [Configure Resources on Docker for Mac](https://docs.docker.com/desktop/mac/#resources) @@ -125,14 +124,14 @@ pip install --upgrade khoj-assistant - Semantic search using the bi-encoder is fairly fast at \<50 ms - Reranking using the cross-encoder is slower at \<2s on 15 results. Tweak `top_k` to tradeoff speed for accuracy of results -- Applying explicit filters is very slow currently at \~6s. This is because the filters are rudimentary. 
Considerable speed-ups can be achieved using indexes etc +- Filters in query (e.g by file, word or date) usually add \<20ms to query latency ### Indexing performance - Indexing is more strongly impacted by the size of the source data -- Indexing 100K+ line corpus of notes takes 6 minutes +- Indexing 100K+ line corpus of notes takes about 10 minutes - Indexing 4000+ images takes about 15 minutes and more than 8Gb of RAM -- Once is implemented, it should only take this long on first run +- Note: *It should only take this long on the first run* as the index is incrementally updated ### Miscellaneous diff --git a/docker-compose.yml b/docker-compose.yml index 3961f5c8..aebb7c55 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -28,4 +28,4 @@ services: - ./tests/data/embeddings/:/data/embeddings/ - ./tests/data/models/:/data/models/ # Use 0.0.0.0 to explicitly set the host ip for the service on the container. https://pythonspeed.com/articles/docker-connection-refused/ - command: config/khoj_docker.yml --host="0.0.0.0" --port=8000 -vv + command: --no-gui -c=config/khoj_docker.yml --host="0.0.0.0" --port=8000 -vv diff --git a/setup.py b/setup.py index bf9f3ce9..3de0fb75 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ this_directory = Path(__file__).parent setup( name='khoj-assistant', - version='0.1.6', + version='0.1.10', description="A natural language search engine for your personal notes, transactions and images", long_description=(this_directory / "Readme.md").read_text(encoding="utf-8"), long_description_content_type="text/markdown", diff --git a/src/configure.py b/src/configure.py index 1a86d251..0006c9fe 100644 --- a/src/configure.py +++ b/src/configure.py @@ -1,8 +1,8 @@ # System Packages import sys +import logging # External Packages -import torch import json # Internal Packages @@ -13,8 +13,14 @@ from src.processor.org_mode.org_to_jsonl import org_to_jsonl from src.search_type import image_search, text_search from src.utils.config import SearchType, SearchModels, ProcessorConfigModel, ConversationProcessorConfigModel from src.utils import state -from src.utils.helpers import get_absolute_path +from src.utils.helpers import LRU, resolve_absolute_path from src.utils.rawconfig import FullConfig, ProcessorConfig +from src.search_filter.date_filter import DateFilter +from src.search_filter.word_filter import WordFilter +from src.search_filter.file_filter import FileFilter + + +logger = logging.getLogger(__name__) def configure_server(args, required=False): @@ -28,27 +34,42 @@ def configure_server(args, required=False): state.config = args.config # Initialize the search model from Config - state.model = configure_search(state.model, state.config, args.regenerate, verbose=state.verbose) + state.model = configure_search(state.model, state.config, args.regenerate) # Initialize Processor from Config - state.processor_config = configure_processor(args.config.processor, verbose=state.verbose) + state.processor_config = configure_processor(args.config.processor) -def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, t: SearchType = None, verbose: int = 0): +def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, t: SearchType = None): # Initialize Org Notes Search if (t == SearchType.Org or t == None) and config.content_type.org: # Extract Entries, Generate Notes Embeddings - model.orgmode_search = text_search.setup(org_to_jsonl, config.content_type.org, search_config=config.search_type.asymmetric, regenerate=regenerate, 
verbose=verbose) + model.orgmode_search = text_search.setup( + org_to_jsonl, + config.content_type.org, + search_config=config.search_type.asymmetric, + regenerate=regenerate, + filters=[DateFilter(), WordFilter(), FileFilter()]) # Initialize Org Music Search if (t == SearchType.Music or t == None) and config.content_type.music: # Extract Entries, Generate Music Embeddings - model.music_search = text_search.setup(org_to_jsonl, config.content_type.music, search_config=config.search_type.asymmetric, regenerate=regenerate, verbose=verbose) + model.music_search = text_search.setup( + org_to_jsonl, + config.content_type.music, + search_config=config.search_type.asymmetric, + regenerate=regenerate, + filters=[DateFilter(), WordFilter()]) # Initialize Markdown Search if (t == SearchType.Markdown or t == None) and config.content_type.markdown: # Extract Entries, Generate Markdown Embeddings - model.markdown_search = text_search.setup(markdown_to_jsonl, config.content_type.markdown, search_config=config.search_type.asymmetric, regenerate=regenerate, verbose=verbose) + model.markdown_search = text_search.setup( + markdown_to_jsonl, + config.content_type.markdown, + search_config=config.search_type.asymmetric, + regenerate=regenerate, + filters=[DateFilter(), WordFilter(), FileFilter()]) # Initialize Panchayat Search if (t == SearchType.Panchayat or t == None) and config.content_type.panchayat: @@ -58,17 +79,28 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, # Initialize Ledger Search if (t == SearchType.Ledger or t == None) and config.content_type.ledger: # Extract Entries, Generate Ledger Embeddings - model.ledger_search = text_search.setup(beancount_to_jsonl, config.content_type.ledger, search_config=config.search_type.symmetric, regenerate=regenerate, verbose=verbose) + model.ledger_search = text_search.setup( + beancount_to_jsonl, + config.content_type.ledger, + search_config=config.search_type.symmetric, + regenerate=regenerate, + filters=[DateFilter(), WordFilter(), FileFilter()]) # Initialize Image Search if (t == SearchType.Image or t == None) and config.content_type.image: # Extract Entries, Generate Image Embeddings - model.image_search = image_search.setup(config.content_type.image, search_config=config.search_type.image, regenerate=regenerate, verbose=verbose) + model.image_search = image_search.setup( + config.content_type.image, + search_config=config.search_type.image, + regenerate=regenerate) + + # Invalidate Query Cache + state.query_cache = LRU() return model -def configure_processor(processor_config: ProcessorConfig, verbose: int): +def configure_processor(processor_config: ProcessorConfig): if not processor_config: return @@ -76,27 +108,23 @@ def configure_processor(processor_config: ProcessorConfig, verbose: int): # Initialize Conversation Processor if processor_config.conversation: - processor.conversation = configure_conversation_processor(processor_config.conversation, verbose) + processor.conversation = configure_conversation_processor(processor_config.conversation) return processor -def configure_conversation_processor(conversation_processor_config, verbose: int): - conversation_processor = ConversationProcessorConfigModel(conversation_processor_config, verbose) +def configure_conversation_processor(conversation_processor_config): + conversation_processor = ConversationProcessorConfigModel(conversation_processor_config) + conversation_logfile = resolve_absolute_path(conversation_processor.conversation_logfile) - conversation_logfile = 
conversation_processor.conversation_logfile - if conversation_processor.verbose: - print('INFO:\tLoading conversation logs from disk...') - - if conversation_logfile.expanduser().absolute().is_file(): + if conversation_logfile.is_file(): # Load Metadata Logs from Conversation Logfile - with open(get_absolute_path(conversation_logfile), 'r') as f: + with conversation_logfile.open('r') as f: conversation_processor.meta_log = json.load(f) - - print('INFO:\tConversation logs loaded from disk.') + logger.info('Conversation logs loaded from disk.') else: # Initialize Conversation Logs conversation_processor.meta_log = {} conversation_processor.chat_session = "" - return conversation_processor \ No newline at end of file + return conversation_processor diff --git a/src/interface/desktop/main_window.py b/src/interface/desktop/main_window.py index 97bbb54a..25287d27 100644 --- a/src/interface/desktop/main_window.py +++ b/src/interface/desktop/main_window.py @@ -92,7 +92,7 @@ class MainWindow(QtWidgets.QMainWindow): search_type_layout = QtWidgets.QVBoxLayout(search_type_settings) enable_search_type = SearchCheckBox(f"Search {search_type.name}", search_type) # Add file browser to set input files for given search type - input_files = FileBrowser(file_input_text, search_type, current_content_files) + input_files = FileBrowser(file_input_text, search_type, current_content_files or []) # Set enabled/disabled based on checkbox state enable_search_type.setChecked(current_content_files is not None and len(current_content_files) > 0) diff --git a/src/interface/desktop/system_tray.py b/src/interface/desktop/system_tray.py index 7b230c54..26df260f 100644 --- a/src/interface/desktop/system_tray.py +++ b/src/interface/desktop/system_tray.py @@ -6,9 +6,10 @@ from PyQt6 import QtGui, QtWidgets # Internal Packages from src.utils import constants, state +from src.interface.desktop.main_window import MainWindow -def create_system_tray(gui: QtWidgets.QApplication, main_window: QtWidgets.QMainWindow): +def create_system_tray(gui: QtWidgets.QApplication, main_window: MainWindow): """Create System Tray with Menu. Menu contain options to 1. Open Search Page on the Web Interface 2. Open App Configuration Screen diff --git a/src/interface/emacs/khoj.el b/src/interface/emacs/khoj.el index 5ab00d0f..c34fb88e 100644 --- a/src/interface/emacs/khoj.el +++ b/src/interface/emacs/khoj.el @@ -5,7 +5,7 @@ ;; Author: Debanjum Singh Solanky ;; Description: Natural, Incremental Search for your Second Brain ;; Keywords: search, org-mode, outlines, markdown, beancount, ledger, image -;; Version: 0.1.6 +;; Version: 0.1.9 ;; Package-Requires: ((emacs "27.1")) ;; URL: http://github.com/debanjum/khoj/interface/emacs diff --git a/src/interface/web/index.html b/src/interface/web/index.html index bd03feab..4bf3e82e 100644 --- a/src/interface/web/index.html +++ b/src/interface/web/index.html @@ -61,12 +61,26 @@ } function search(rerank=false) { - query = document.getElementById("query").value; + // Extract required fields for search from form + query = document.getElementById("query").value.trim(); type = document.getElementById("type").value; - console.log(query, type); + results_count = document.getElementById("results-count").value || 6; + console.log(`Query: ${query}, Type: ${type}`); + + // Short circuit on empty query + if (query.length === 0) + return; + + // If set query field in url query param on rerank + if (rerank) + setQueryFieldInUrl(query); + + // Generate Backend API URL to execute Search url = type === "image" - ? 
`/search?q=${query}&t=${type}&n=6` - : `/search?q=${query}&t=${type}&n=6&r=${rerank}`; + ? `/search?q=${encodeURIComponent(query)}&t=${type}&n=${results_count}` + : `/search?q=${encodeURIComponent(query)}&t=${type}&n=${results_count}&r=${rerank}`; + + // Execute Search and Render Results fetch(url) .then(response => response.json()) .then(data => { @@ -78,9 +92,9 @@ }); } - function regenerate() { + function updateIndex() { type = document.getElementById("type").value; - fetch(`/regenerate?t=${type}`) + fetch(`/reload?t=${type}`) .then(response => response.json()) .then(data => { console.log(data); @@ -89,7 +103,7 @@ }); } - function incremental_search(event) { + function incrementalSearch(event) { type = document.getElementById("type").value; // Search with reranking on 'Enter' if (event.key === 'Enter') { @@ -121,10 +135,33 @@ }); } + function setTypeFieldInUrl(type) { + var url = new URL(window.location.href); + url.searchParams.set("t", type.value); + window.history.pushState({}, "", url.href); + } + + function setCountFieldInUrl(results_count) { + var url = new URL(window.location.href); + url.searchParams.set("n", results_count.value); + window.history.pushState({}, "", url.href); + } + + function setQueryFieldInUrl(query) { + var url = new URL(window.location.href); + url.searchParams.set("q", query); + window.history.pushState({}, "", url.href); + } + window.onload = function () { // Dynamically populate type dropdown based on enabled search types and type passed as URL query parameter populate_type_dropdown(); + // Set results count field with value passed in URL query parameters, if any. + var results_count = new URLSearchParams(window.location.search).get("n"); + if (results_count) + document.getElementById("results-count").value = results_count; + // Fill query field with value passed in URL query parameters, if any. var query_via_url = new URLSearchParams(window.location.search).get("q"); if (query_via_url) @@ -136,15 +173,18 @@

Khoj

[index.html markup hunk not preserved in this excerpt: the search controls are updated to use the renamed incrementalSearch and updateIndex functions, a results-count number input is added, and the setTypeFieldInUrl/setCountFieldInUrl helpers are wired to the type and results-count controls]
@@ -194,7 +234,7 @@ #options { padding: 0; display: grid; - grid-template-columns: 1fr 1fr; + grid-template-columns: 1fr 1fr minmax(70px, 0.5fr); } #options > * { padding: 15px; @@ -202,10 +242,10 @@ border: 1px solid #ccc; } #options > select { - margin-right: 5px; + margin-right: 10px; } #options > button { - margin-left: 5px; + margin-right: 10px; } #query { diff --git a/src/main.py b/src/main.py index 88a98c37..dc3b12c3 100644 --- a/src/main.py +++ b/src/main.py @@ -2,8 +2,13 @@ import os import signal import sys +import logging +import warnings from platform import system +# Ignore non-actionable warnings +warnings.filterwarnings("ignore", message=r'snapshot_download.py has been made private', category=FutureWarning) + # External Packages import uvicorn from fastapi import FastAPI @@ -25,6 +30,34 @@ app = FastAPI() app.mount("/static", StaticFiles(directory=constants.web_directory), name="static") app.include_router(router) +logger = logging.getLogger('src') + + +class CustomFormatter(logging.Formatter): + + blue = "\x1b[1;34m" + green = "\x1b[1;32m" + grey = "\x1b[38;20m" + yellow = "\x1b[33;20m" + red = "\x1b[31;20m" + bold_red = "\x1b[31;1m" + reset = "\x1b[0m" + format_str = "%(levelname)s: %(asctime)s: %(name)s | %(message)s" + + FORMATS = { + logging.DEBUG: blue + format_str + reset, + logging.INFO: green + format_str + reset, + logging.WARNING: yellow + format_str + reset, + logging.ERROR: red + format_str + reset, + logging.CRITICAL: bold_red + format_str + reset + } + + def format(self, record): + log_fmt = self.FORMATS.get(record.levelno) + formatter = logging.Formatter(log_fmt) + return formatter.format(record) + + def run(): # Turn Tokenizers Parallelism Off. App does not support it. os.environ["TOKENIZERS_PARALLELISM"] = 'false' @@ -34,6 +67,29 @@ def run(): args = cli(state.cli_args) set_state(args) + # Create app directory, if it doesn't exist + state.config_file.parent.mkdir(parents=True, exist_ok=True) + + # Setup Logger + if args.verbose == 0: + logger.setLevel(logging.WARN) + elif args.verbose == 1: + logger.setLevel(logging.INFO) + elif args.verbose >= 2: + logger.setLevel(logging.DEBUG) + + # Set Log Format + ch = logging.StreamHandler() + ch.setFormatter(CustomFormatter()) + logger.addHandler(ch) + + # Set Log File + fh = logging.FileHandler(state.config_file.parent / 'khoj.log') + fh.setLevel(logging.DEBUG) + logger.addHandler(fh) + + logger.info("Starting Khoj...") + if args.no_gui: # Start Server configure_server(args, required=True) diff --git a/src/processor/ledger/beancount_to_jsonl.py b/src/processor/ledger/beancount_to_jsonl.py index 861f8620..7b8b9bba 100644 --- a/src/processor/ledger/beancount_to_jsonl.py +++ b/src/processor/ledger/beancount_to_jsonl.py @@ -2,108 +2,127 @@ # Standard Packages import json -import argparse -import pathlib import glob import re +import logging +import time # Internal Packages -from src.utils.helpers import get_absolute_path, is_none_or_empty +from src.utils.helpers import get_absolute_path, is_none_or_empty, mark_entries_for_update from src.utils.constants import empty_escape_sequences from src.utils.jsonl import dump_jsonl, compress_jsonl_data +from src.utils.rawconfig import TextContentConfig + + +logger = logging.getLogger(__name__) # Define Functions -def beancount_to_jsonl(beancount_files, beancount_file_filter, output_file, verbose=0): +def beancount_to_jsonl(config: TextContentConfig, previous_entries=None): + # Extract required fields from config + beancount_files, beancount_file_filter, output_file = 
config.input_files, config.input_filter, config.compressed_jsonl + # Input Validation if is_none_or_empty(beancount_files) and is_none_or_empty(beancount_file_filter): print("At least one of beancount-files or beancount-file-filter is required to be specified") exit(1) # Get Beancount Files to Process - beancount_files = get_beancount_files(beancount_files, beancount_file_filter, verbose) + beancount_files = get_beancount_files(beancount_files, beancount_file_filter) # Extract Entries from specified Beancount files - entries = extract_beancount_entries(beancount_files) + start = time.time() + current_entries = convert_transactions_to_maps(*extract_beancount_transactions(beancount_files)) + end = time.time() + logger.debug(f"Parse transactions from Beancount files into dictionaries: {end - start} seconds") + + # Identify, mark and merge any new entries with previous entries + start = time.time() + if not previous_entries: + entries_with_ids = list(enumerate(current_entries)) + else: + entries_with_ids = mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger) + end = time.time() + logger.debug(f"Identify new or updated transaction: {end - start} seconds") # Process Each Entry from All Notes Files - jsonl_data = convert_beancount_entries_to_jsonl(entries, verbose=verbose) + start = time.time() + entries = list(map(lambda entry: entry[1], entries_with_ids)) + jsonl_data = convert_transaction_maps_to_jsonl(entries) # Compress JSONL formatted Data if output_file.suffix == ".gz": - compress_jsonl_data(jsonl_data, output_file, verbose=verbose) + compress_jsonl_data(jsonl_data, output_file) elif output_file.suffix == ".jsonl": - dump_jsonl(jsonl_data, output_file, verbose=verbose) + dump_jsonl(jsonl_data, output_file) + end = time.time() + logger.debug(f"Write transactions to JSONL file: {end - start} seconds") - return entries + return entries_with_ids -def get_beancount_files(beancount_files=None, beancount_file_filter=None, verbose=0): +def get_beancount_files(beancount_files=None, beancount_file_filters=None): "Get Beancount files to process" absolute_beancount_files, filtered_beancount_files = set(), set() if beancount_files: absolute_beancount_files = {get_absolute_path(beancount_file) for beancount_file in beancount_files} - if beancount_file_filter: - filtered_beancount_files = set(glob.glob(get_absolute_path(beancount_file_filter))) + if beancount_file_filters: + filtered_beancount_files = { + filtered_file + for beancount_file_filter in beancount_file_filters + for filtered_file in glob.glob(get_absolute_path(beancount_file_filter)) + } - all_beancount_files = absolute_beancount_files | filtered_beancount_files + all_beancount_files = sorted(absolute_beancount_files | filtered_beancount_files) - files_with_non_beancount_extensions = {beancount_file - for beancount_file - in all_beancount_files - if not beancount_file.endswith(".bean") and not beancount_file.endswith(".beancount")} + files_with_non_beancount_extensions = { + beancount_file + for beancount_file + in all_beancount_files + if not beancount_file.endswith(".bean") and not beancount_file.endswith(".beancount") + } if any(files_with_non_beancount_extensions): print(f"[Warning] There maybe non beancount files in the input set: {files_with_non_beancount_extensions}") - if verbose > 0: - print(f'Processing files: {all_beancount_files}') + logger.info(f'Processing files: {all_beancount_files}') return all_beancount_files -def extract_beancount_entries(beancount_files): +def 
extract_beancount_transactions(beancount_files): "Extract entries from specified Beancount files" # Initialize Regex for extracting Beancount Entries transaction_regex = r'^\n?\d{4}-\d{2}-\d{2} [\*|\!] ' - empty_newline = f'^[{empty_escape_sequences}]*$' + empty_newline = f'^[\n\r\t\ ]*$' entries = [] + transaction_to_file_map = [] for beancount_file in beancount_files: with open(beancount_file) as f: ledger_content = f.read() - entries.extend([entry.strip(empty_escape_sequences) + transactions_per_file = [entry.strip(empty_escape_sequences) for entry in re.split(empty_newline, ledger_content, flags=re.MULTILINE) - if re.match(transaction_regex, entry)]) - - return entries + if re.match(transaction_regex, entry)] + transaction_to_file_map += zip(transactions_per_file, [beancount_file]*len(transactions_per_file)) + entries.extend(transactions_per_file) + return entries, dict(transaction_to_file_map) -def convert_beancount_entries_to_jsonl(entries, verbose=0): - "Convert each Beancount transaction to JSON and collate as JSONL" - jsonl = '' +def convert_transactions_to_maps(entries: list[str], transaction_to_file_map) -> list[dict]: + "Convert each Beancount transaction into a dictionary" + entry_maps = [] for entry in entries: - entry_dict = {'compiled': entry, 'raw': entry} - # Convert Dictionary to JSON and Append to JSONL string - jsonl += f'{json.dumps(entry_dict, ensure_ascii=False)}\n' + entry_maps.append({'compiled': entry, 'raw': entry, 'file': f'{transaction_to_file_map[entry]}'}) - if verbose > 0: - print(f"Converted {len(entries)} to jsonl format") + logger.info(f"Converted {len(entries)} transactions to dictionaries") - return jsonl + return entry_maps -if __name__ == '__main__': - # Setup Argument Parser - parser = argparse.ArgumentParser(description="Map Beancount transactions into (compressed) JSONL format") - parser.add_argument('--output-file', '-o', type=pathlib.Path, required=True, help="Output file for (compressed) JSONL formatted transactions. 
Expected file extensions: jsonl or jsonl.gz") - parser.add_argument('--input-files', '-i', nargs='*', help="List of beancount files to process") - parser.add_argument('--input-filter', type=str, default=None, help="Regex filter for beancount files to process") - parser.add_argument('--verbose', '-v', action='count', default=0, help="Show verbose conversion logs, Default: 0") - args = parser.parse_args() - - # Map transactions in beancount files to (compressed) JSONL formatted file - beancount_to_jsonl(args.input_files, args.input_filter, args.output_file, args.verbose) +def convert_transaction_maps_to_jsonl(entries: list[dict]) -> str: + "Convert each Beancount transaction dictionary to JSON and collate as JSONL" + return ''.join([f'{json.dumps(entry_dict, ensure_ascii=False)}\n' for entry_dict in entries]) diff --git a/src/processor/markdown/markdown_to_jsonl.py b/src/processor/markdown/markdown_to_jsonl.py index c2133ede..22f5ea17 100644 --- a/src/processor/markdown/markdown_to_jsonl.py +++ b/src/processor/markdown/markdown_to_jsonl.py @@ -2,51 +2,78 @@ # Standard Packages import json -import argparse -import pathlib import glob import re +import logging +import time # Internal Packages -from src.utils.helpers import get_absolute_path, is_none_or_empty +from src.utils.helpers import get_absolute_path, is_none_or_empty, mark_entries_for_update from src.utils.constants import empty_escape_sequences from src.utils.jsonl import dump_jsonl, compress_jsonl_data +from src.utils.rawconfig import TextContentConfig + + +logger = logging.getLogger(__name__) # Define Functions -def markdown_to_jsonl(markdown_files, markdown_file_filter, output_file, verbose=0): +def markdown_to_jsonl(config: TextContentConfig, previous_entries=None): + # Extract required fields from config + markdown_files, markdown_file_filter, output_file = config.input_files, config.input_filter, config.compressed_jsonl + # Input Validation if is_none_or_empty(markdown_files) and is_none_or_empty(markdown_file_filter): print("At least one of markdown-files or markdown-file-filter is required to be specified") exit(1) # Get Markdown Files to Process - markdown_files = get_markdown_files(markdown_files, markdown_file_filter, verbose) + markdown_files = get_markdown_files(markdown_files, markdown_file_filter) # Extract Entries from specified Markdown files - entries = extract_markdown_entries(markdown_files) + start = time.time() + current_entries = convert_markdown_entries_to_maps(*extract_markdown_entries(markdown_files)) + end = time.time() + logger.debug(f"Parse entries from Markdown files into dictionaries: {end - start} seconds") + + # Identify, mark and merge any new entries with previous entries + start = time.time() + if not previous_entries: + entries_with_ids = list(enumerate(current_entries)) + else: + entries_with_ids = mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger) + end = time.time() + logger.debug(f"Identify new or updated entries: {end - start} seconds") # Process Each Entry from All Notes Files - jsonl_data = convert_markdown_entries_to_jsonl(entries, verbose=verbose) + start = time.time() + entries = list(map(lambda entry: entry[1], entries_with_ids)) + jsonl_data = convert_markdown_maps_to_jsonl(entries) # Compress JSONL formatted Data if output_file.suffix == ".gz": - compress_jsonl_data(jsonl_data, output_file, verbose=verbose) + compress_jsonl_data(jsonl_data, output_file) elif output_file.suffix == ".jsonl": - dump_jsonl(jsonl_data, output_file, verbose=verbose) + 
dump_jsonl(jsonl_data, output_file) + end = time.time() + logger.debug(f"Write markdown entries to JSONL file: {end - start} seconds") - return entries + return entries_with_ids -def get_markdown_files(markdown_files=None, markdown_file_filter=None, verbose=0): +def get_markdown_files(markdown_files=None, markdown_file_filters=None): "Get Markdown files to process" absolute_markdown_files, filtered_markdown_files = set(), set() if markdown_files: absolute_markdown_files = {get_absolute_path(markdown_file) for markdown_file in markdown_files} - if markdown_file_filter: - filtered_markdown_files = set(glob.glob(get_absolute_path(markdown_file_filter))) + if markdown_file_filters: + filtered_markdown_files = { + filtered_file + for markdown_file_filter in markdown_file_filters + for filtered_file in glob.glob(get_absolute_path(markdown_file_filter)) + } - all_markdown_files = absolute_markdown_files | filtered_markdown_files + all_markdown_files = sorted(absolute_markdown_files | filtered_markdown_files) files_with_non_markdown_extensions = { md_file @@ -56,10 +83,9 @@ def get_markdown_files(markdown_files=None, markdown_file_filter=None, verbose=0 } if any(files_with_non_markdown_extensions): - print(f"[Warning] There maybe non markdown-mode files in the input set: {files_with_non_markdown_extensions}") + logger.warn(f"[Warning] There maybe non markdown-mode files in the input set: {files_with_non_markdown_extensions}") - if verbose > 0: - print(f'Processing files: {all_markdown_files}') + logger.info(f'Processing files: {all_markdown_files}') return all_markdown_files @@ -71,38 +97,31 @@ def extract_markdown_entries(markdown_files): markdown_heading_regex = r'^#' entries = [] + entry_to_file_map = [] for markdown_file in markdown_files: with open(markdown_file) as f: markdown_content = f.read() - entries.extend([f'#{entry.strip(empty_escape_sequences)}' + markdown_entries_per_file = [f'#{entry.strip(empty_escape_sequences)}' for entry - in re.split(markdown_heading_regex, markdown_content, flags=re.MULTILINE)]) + in re.split(markdown_heading_regex, markdown_content, flags=re.MULTILINE) + if entry.strip(empty_escape_sequences) != ''] + entry_to_file_map += zip(markdown_entries_per_file, [markdown_file]*len(markdown_entries_per_file)) + entries.extend(markdown_entries_per_file) - return entries + return entries, dict(entry_to_file_map) -def convert_markdown_entries_to_jsonl(entries, verbose=0): - "Convert each Markdown entries to JSON and collate as JSONL" - jsonl = '' +def convert_markdown_entries_to_maps(entries: list[str], entry_to_file_map) -> list[dict]: + "Convert each Markdown entries into a dictionary" + entry_maps = [] for entry in entries: - entry_dict = {'compiled': entry, 'raw': entry} - # Convert Dictionary to JSON and Append to JSONL string - jsonl += f'{json.dumps(entry_dict, ensure_ascii=False)}\n' + entry_maps.append({'compiled': entry, 'raw': entry, 'file': f'{entry_to_file_map[entry]}'}) - if verbose > 0: - print(f"Converted {len(entries)} to jsonl format") + logger.info(f"Converted {len(entries)} markdown entries to dictionaries") - return jsonl + return entry_maps -if __name__ == '__main__': - # Setup Argument Parser - parser = argparse.ArgumentParser(description="Map Markdown entries into (compressed) JSONL format") - parser.add_argument('--output-file', '-o', type=pathlib.Path, required=True, help="Output file for (compressed) JSONL formatted notes. 
Expected file extensions: jsonl or jsonl.gz") - parser.add_argument('--input-files', '-i', nargs='*', help="List of markdown files to process") - parser.add_argument('--input-filter', type=str, default=None, help="Regex filter for markdown files to process") - parser.add_argument('--verbose', '-v', action='count', default=0, help="Show verbose conversion logs, Default: 0") - args = parser.parse_args() - - # Map notes in Markdown files to (compressed) JSONL formatted file - markdown_to_jsonl(args.input_files, args.input_filter, args.output_file, args.verbose) +def convert_markdown_maps_to_jsonl(entries): + "Convert each Markdown entries to JSON and collate as JSONL" + return ''.join([f'{json.dumps(entry_dict, ensure_ascii=False)}\n' for entry_dict in entries]) diff --git a/src/processor/org_mode/org_to_jsonl.py b/src/processor/org_mode/org_to_jsonl.py index ea2962f5..43f4acef 100644 --- a/src/processor/org_mode/org_to_jsonl.py +++ b/src/processor/org_mode/org_to_jsonl.py @@ -1,62 +1,94 @@ #!/usr/bin/env python3 # Standard Packages -import re import json -import argparse -import pathlib import glob +import logging +import time +from typing import Iterable # Internal Packages from src.processor.org_mode import orgnode -from src.utils.helpers import get_absolute_path, is_none_or_empty -from src.utils.constants import empty_escape_sequences +from src.utils.helpers import get_absolute_path, is_none_or_empty, mark_entries_for_update from src.utils.jsonl import dump_jsonl, compress_jsonl_data +from src.utils import state +from src.utils.rawconfig import TextContentConfig + + +logger = logging.getLogger(__name__) # Define Functions -def org_to_jsonl(org_files, org_file_filter, output_file, verbose=0): +def org_to_jsonl(config: TextContentConfig, previous_entries=None): + # Extract required fields from config + org_files, org_file_filter, output_file = config.input_files, config.input_filter, config.compressed_jsonl + index_heading_entries = config.index_heading_entries + # Input Validation if is_none_or_empty(org_files) and is_none_or_empty(org_file_filter): print("At least one of org-files or org-file-filter is required to be specified") exit(1) # Get Org Files to Process - org_files = get_org_files(org_files, org_file_filter, verbose) + start = time.time() + org_files = get_org_files(org_files, org_file_filter) # Extract Entries from specified Org files - entries = extract_org_entries(org_files) + start = time.time() + entry_nodes, file_to_entries = extract_org_entries(org_files) + end = time.time() + logger.debug(f"Parse entries from org files into OrgNode objects: {end - start} seconds") + + start = time.time() + current_entries = convert_org_nodes_to_entries(entry_nodes, file_to_entries, index_heading_entries) + end = time.time() + logger.debug(f"Convert OrgNodes into entry dictionaries: {end - start} seconds") + + # Identify, mark and merge any new entries with previous entries + if not previous_entries: + entries_with_ids = list(enumerate(current_entries)) + else: + entries_with_ids = mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger) # Process Each Entry from All Notes Files - jsonl_data = convert_org_entries_to_jsonl(entries, verbose=verbose) + start = time.time() + entries = map(lambda entry: entry[1], entries_with_ids) + jsonl_data = convert_org_entries_to_jsonl(entries) # Compress JSONL formatted Data if output_file.suffix == ".gz": - compress_jsonl_data(jsonl_data, output_file, verbose=verbose) + compress_jsonl_data(jsonl_data, output_file) elif 
output_file.suffix == ".jsonl": - dump_jsonl(jsonl_data, output_file, verbose=verbose) + dump_jsonl(jsonl_data, output_file) + end = time.time() + logger.debug(f"Write org entries to JSONL file: {end - start} seconds") - return entries + return entries_with_ids -def get_org_files(org_files=None, org_file_filter=None, verbose=0): +def get_org_files(org_files=None, org_file_filters=None): "Get Org files to process" absolute_org_files, filtered_org_files = set(), set() if org_files: - absolute_org_files = {get_absolute_path(org_file) - for org_file - in org_files} - if org_file_filter: - filtered_org_files = set(glob.glob(get_absolute_path(org_file_filter))) + absolute_org_files = { + get_absolute_path(org_file) + for org_file + in org_files + } + if org_file_filters: + filtered_org_files = { + filtered_file + for org_file_filter in org_file_filters + for filtered_file in glob.glob(get_absolute_path(org_file_filter)) + } - all_org_files = absolute_org_files | filtered_org_files + all_org_files = sorted(absolute_org_files | filtered_org_files) files_with_non_org_extensions = {org_file for org_file in all_org_files if not org_file.endswith(".org")} if any(files_with_non_org_extensions): - print(f"[Warning] There maybe non org-mode files in the input set: {files_with_non_org_extensions}") + logger.warn(f"There maybe non org-mode files in the input set: {files_with_non_org_extensions}") - if verbose > 0: - print(f'Processing files: {all_org_files}') + logger.info(f'Processing files: {all_org_files}') return all_org_files @@ -64,69 +96,60 @@ def get_org_files(org_files=None, org_file_filter=None, verbose=0): def extract_org_entries(org_files): "Extract entries from specified Org files" entries = [] + entry_to_file_map = [] for org_file in org_files: - entries.extend( - orgnode.makelist( - str(org_file))) + org_file_entries = orgnode.makelist(str(org_file)) + entry_to_file_map += zip(org_file_entries, [org_file]*len(org_file_entries)) + entries.extend(org_file_entries) - return entries + return entries, dict(entry_to_file_map) -def convert_org_entries_to_jsonl(entries, verbose=0) -> str: - "Convert each Org-Mode entries to JSON and collate as JSONL" - jsonl = '' +def convert_org_nodes_to_entries(entries: list[orgnode.Orgnode], entry_to_file_map, index_heading_entries=False) -> list[dict]: + "Convert Org-Mode entries into list of dictionary" + entry_maps = [] for entry in entries: entry_dict = dict() - # Ignore title notes i.e notes with just headings and empty body - if not entry.Body() or re.sub(r'\n|\t|\r| ', '', entry.Body()) == "": + if not entry.hasBody and not index_heading_entries: + # Ignore title notes i.e notes with just headings and empty body continue - entry_dict["compiled"] = f'{entry.Heading()}.' - if verbose > 2: - print(f"Title: {entry.Heading()}") + entry_dict["compiled"] = f'{entry.heading}.' + if state.verbose > 2: + logger.debug(f"Title: {entry.heading}") - if entry.Tags(): - tags_str = " ".join(entry.Tags()) + if entry.tags: + tags_str = " ".join(entry.tags) entry_dict["compiled"] += f'\t {tags_str}.' - if verbose > 2: - print(f"Tags: {tags_str}") + if state.verbose > 2: + logger.debug(f"Tags: {tags_str}") - if entry.Closed(): - entry_dict["compiled"] += f'\n Closed on {entry.Closed().strftime("%Y-%m-%d")}.' - if verbose > 2: - print(f'Closed: {entry.Closed().strftime("%Y-%m-%d")}') + if entry.closed: + entry_dict["compiled"] += f'\n Closed on {entry.closed.strftime("%Y-%m-%d")}.' 
+ if state.verbose > 2: + logger.debug(f'Closed: {entry.closed.strftime("%Y-%m-%d")}') - if entry.Scheduled(): - entry_dict["compiled"] += f'\n Scheduled for {entry.Scheduled().strftime("%Y-%m-%d")}.' - if verbose > 2: - print(f'Scheduled: {entry.Scheduled().strftime("%Y-%m-%d")}') + if entry.scheduled: + entry_dict["compiled"] += f'\n Scheduled for {entry.scheduled.strftime("%Y-%m-%d")}.' + if state.verbose > 2: + logger.debug(f'Scheduled: {entry.scheduled.strftime("%Y-%m-%d")}') - if entry.Body(): - entry_dict["compiled"] += f'\n {entry.Body()}' - if verbose > 2: - print(f"Body: {entry.Body()}") + if entry.hasBody: + entry_dict["compiled"] += f'\n {entry.body}' + if state.verbose > 2: + logger.debug(f"Body: {entry.body}") if entry_dict: entry_dict["raw"] = f'{entry}' + entry_dict["file"] = f'{entry_to_file_map[entry]}' # Convert Dictionary to JSON and Append to JSONL string - jsonl += f'{json.dumps(entry_dict, ensure_ascii=False)}\n' + entry_maps.append(entry_dict) - if verbose > 0: - print(f"Converted {len(entries)} to jsonl format") - - return jsonl + return entry_maps -if __name__ == '__main__': - # Setup Argument Parser - parser = argparse.ArgumentParser(description="Map Org-Mode notes into (compressed) JSONL format") - parser.add_argument('--output-file', '-o', type=pathlib.Path, required=True, help="Output file for (compressed) JSONL formatted notes. Expected file extensions: jsonl or jsonl.gz") - parser.add_argument('--input-files', '-i', nargs='*', help="List of org-mode files to process") - parser.add_argument('--input-filter', type=str, default=None, help="Regex filter for org-mode files to process") - parser.add_argument('--verbose', '-v', action='count', default=0, help="Show verbose conversion logs, Default: 0") - args = parser.parse_args() - - # Map notes in Org-Mode files to (compressed) JSONL formatted file - org_to_jsonl(args.input_files, args.input_filter, args.output_file, args.verbose) +def convert_org_entries_to_jsonl(entries: Iterable[dict]) -> str: + "Convert each Org-Mode entry to JSON and collate as JSONL" + return ''.join([f'{json.dumps(entry_dict, ensure_ascii=False)}\n' for entry_dict in entries]) diff --git a/src/processor/org_mode/orgnode.py b/src/processor/org_mode/orgnode.py index 39c67731..a5f4cd43 100644 --- a/src/processor/org_mode/orgnode.py +++ b/src/processor/org_mode/orgnode.py @@ -33,16 +33,21 @@ headline and associated text from an org-mode file, and routines for constructing data structures of these classes. 
""" -import re, sys +import re import datetime from pathlib import Path from os.path import relpath -indent_regex = re.compile(r'^\s*') +indent_regex = re.compile(r'^ *') def normalize_filename(filename): - file_relative_to_home = f'~/{relpath(filename, start=Path.home())}' - escaped_filename = f'{file_relative_to_home}'.replace("[","\[").replace("]","\]") + "Normalize and escape filename for rendering" + if not Path(filename).is_absolute(): + # Normalize relative filename to be relative to current directory + normalized_filename = f'~/{relpath(filename, start=Path.home())}' + else: + normalized_filename = filename + escaped_filename = f'{normalized_filename}'.replace("[","\[").replace("]","\]") return escaped_filename def makelist(filename): @@ -52,65 +57,71 @@ def makelist(filename): """ ctr = 0 - try: - f = open(filename, 'r') - except IOError: - print(f"Unable to open file {filename}") - print("Program terminating.") - sys.exit(1) + f = open(filename, 'r') todos = { "TODO": "", "WAITING": "", "ACTIVE": "", "DONE": "", "CANCELLED": "", "FAILED": ""} # populated from #+SEQ_TODO line - level = 0 + level = "" heading = "" bodytext = "" - tags = set() # set of all tags in headline + tags = list() # set of all tags in headline closed_date = '' sched_date = '' deadline_date = '' logbook = list() - nodelist = [] - propdict = dict() + nodelist: list[Orgnode] = list() + property_map = dict() in_properties_drawer = False in_logbook_drawer = False + file_title = f'{filename}' for line in f: ctr += 1 - hdng = re.search(r'^(\*+)\s(.*?)\s*$', line) - if hdng: # we are processing a heading line + heading_search = re.search(r'^(\*+)\s(.*?)\s*$', line) + if heading_search: # we are processing a heading line if heading: # if we have are on second heading, append first heading to headings list thisNode = Orgnode(level, heading, bodytext, tags) if closed_date: - thisNode.setClosed(closed_date) + thisNode.closed = closed_date closed_date = '' if sched_date: - thisNode.setScheduled(sched_date) + thisNode.scheduled = sched_date sched_date = "" if deadline_date: - thisNode.setDeadline(deadline_date) + thisNode.deadline = deadline_date deadline_date = '' if logbook: - thisNode.setLogbook(logbook) + thisNode.logbook = logbook logbook = list() - thisNode.setProperties(propdict) + thisNode.properties = property_map nodelist.append( thisNode ) - propdict = {'LINE': f'file:{normalize_filename(filename)}::{ctr}'} - level = hdng.group(1) - heading = hdng.group(2) + property_map = {'LINE': f'file:{normalize_filename(filename)}::{ctr}'} + level = heading_search.group(1) + heading = heading_search.group(2) bodytext = "" - tags = set() # set of all tags in headline - tagsrch = re.search(r'(.*?)\s*:([a-zA-Z0-9].*?):$',heading) - if tagsrch: - heading = tagsrch.group(1) - parsedtags = tagsrch.group(2) + tags = list() # set of all tags in headline + tag_search = re.search(r'(.*?)\s*:([a-zA-Z0-9].*?):$',heading) + if tag_search: + heading = tag_search.group(1) + parsedtags = tag_search.group(2) if parsedtags: for parsedtag in parsedtags.split(':'): - if parsedtag != '': tags.add(parsedtag) + if parsedtag != '': tags.append(parsedtag) else: # we are processing a non-heading line if line[:10] == '#+SEQ_TODO': kwlist = re.findall(r'([A-Z]+)\(', line) for kw in kwlist: todos[kw] = "" + # Set file title to TITLE property, if it exists + title_search = re.search(r'^#\+TITLE:\s*(.*)$', line) + if title_search and title_search.group(1).strip() != '': + title_text = title_search.group(1).strip() + if file_title == f'{filename}': + 
file_title = title_text + else: + file_title += f' {title_text}' + continue + # Ignore Properties Drawers Completely if re.search(':PROPERTIES:', line): in_properties_drawer=True @@ -137,13 +148,13 @@ def makelist(filename): logbook += [(clocked_in, clocked_out)] line = "" - prop_srch = re.search(r'^\s*:([a-zA-Z0-9]+):\s*(.*?)\s*$', line) - if prop_srch: + property_search = re.search(r'^\s*:([a-zA-Z0-9]+):\s*(.*?)\s*$', line) + if property_search: # Set ID property to an id based org-mode link to the entry - if prop_srch.group(1) == 'ID': - propdict['ID'] = f'id:{prop_srch.group(2)}' + if property_search.group(1) == 'ID': + property_map['ID'] = f'id:{property_search.group(2)}' else: - propdict[prop_srch.group(1)] = prop_srch.group(2) + property_map[property_search.group(1)] = property_search.group(2) continue cd_re = re.search(r'CLOSED:\s*\[([0-9]{4})-([0-9]{2})-([0-9]{2})', line) @@ -167,37 +178,40 @@ def makelist(filename): bodytext = bodytext + line # write out last node - thisNode = Orgnode(level, heading, bodytext, tags) - thisNode.setProperties(propdict) + thisNode = Orgnode(level, heading or file_title, bodytext, tags) + thisNode.properties = property_map if sched_date: - thisNode.setScheduled(sched_date) + thisNode.scheduled = sched_date if deadline_date: - thisNode.setDeadline(deadline_date) + thisNode.deadline = deadline_date if closed_date: - thisNode.setClosed(closed_date) + thisNode.closed = closed_date if logbook: - thisNode.setLogbook(logbook) + thisNode.logbook = logbook nodelist.append( thisNode ) # using the list of TODO keywords found in the file # process the headings searching for TODO keywords for n in nodelist: - h = n.Heading() - todoSrch = re.search(r'([A-Z]+)\s(.*?)$', h) - if todoSrch: - if todoSrch.group(1) in todos: - n.setHeading( todoSrch.group(2) ) - n.setTodo ( todoSrch.group(1) ) + todo_search = re.search(r'([A-Z]+)\s(.*?)$', n.heading) + if todo_search: + if todo_search.group(1) in todos: + n.heading = todo_search.group(2) + n.todo = todo_search.group(1) # extract, set priority from heading, update heading if necessary - prtysrch = re.search(r'^\[\#(A|B|C)\] (.*?)$', n.Heading()) - if prtysrch: - n.setPriority(prtysrch.group(1)) - n.setHeading(prtysrch.group(2)) + priority_search = re.search(r'^\[\#(A|B|C)\] (.*?)$', n.heading) + if priority_search: + n.priority = priority_search.group(1) + n.heading = priority_search.group(2) # Set SOURCE property to a file+heading based org-mode link to the entry - escaped_heading = n.Heading().replace("[","\\[").replace("]","\\]") - n.properties['SOURCE'] = f'[[file:{normalize_filename(filename)}::*{escaped_heading}]]' + if n.level == 0: + n.properties['LINE'] = f'file:{normalize_filename(filename)}::0' + n.properties['SOURCE'] = f'[[file:{normalize_filename(filename)}]]' + else: + escaped_heading = n.heading.replace("[","\\[").replace("]","\\]") + n.properties['SOURCE'] = f'[[file:{normalize_filename(filename)}::*{escaped_heading}]]' return nodelist @@ -214,199 +228,234 @@ class Orgnode(object): first tag. The makelist routine postprocesses the list to identify TODO tags and updates headline and todo fields. 
""" - self.level = len(level) - self.headline = headline - self.body = body - self.tags = set(tags) # All tags in the headline - self.todo = "" - self.prty = "" # empty of A, B or C - self.scheduled = "" # Scheduled date - self.deadline = "" # Deadline date - self.closed = "" # Closed date - self.properties = dict() - self.logbook = list() # List of clock-in, clock-out tuples representing logbook entries + self._level = len(level) + self._heading = headline + self._body = body + self._tags = tags # All tags in the headline + self._todo = "" + self._priority = "" # empty of A, B or C + self._scheduled = "" # Scheduled date + self._deadline = "" # Deadline date + self._closed = "" # Closed date + self._properties = dict() + self._logbook = list() # List of clock-in, clock-out tuples representing logbook entries # Look for priority in headline and transfer to prty field - def Heading(self): + @property + def heading(self): """ Return the Heading text of the node without the TODO tag """ - return self.headline + return self._heading - def setHeading(self, newhdng): + @heading.setter + def heading(self, newhdng): """ Change the heading to the supplied string """ - self.headline = newhdng + self._heading = newhdng - def Body(self): + @property + def body(self): """ Returns all lines of text of the body of this node except the Property Drawer """ - return self.body + return self._body - def Level(self): + @property + def hasBody(self): + """ + Returns True if node has non empty body, else False + """ + return self._body and re.sub(r'\n|\t|\r| ', '', self._body) != '' + + @property + def level(self): """ Returns an integer corresponding to the level of the node. Top level (one asterisk) has a level of 1. """ - return self.level + return self._level - def Priority(self): + @property + def priority(self): """ Returns the priority of this headline: 'A', 'B', 'C' or empty string if priority has not been set. """ - return self.prty + return self._priority - def setPriority(self, newprty): + @priority.setter + def priority(self, new_priority): """ Change the value of the priority of this headline. Values values are '', 'A', 'B', 'C' """ - self.prty = newprty + self._priority = new_priority - def Tags(self): + @property + def tags(self): """ - Returns the set of all tags - For example, :HOME:COMPUTER: would return {'HOME', 'COMPUTER'} + Returns the list of all tags + For example, :HOME:COMPUTER: would return ['HOME', 'COMPUTER'] """ - return self.tags + return self._tags - def hasTag(self, srch): + @tags.setter + def tags(self, newtags): + """ + Store all the tags found in the headline. + """ + self._tags = newtags + + def hasTag(self, tag): """ Returns True if the supplied tag is present in this headline For example, hasTag('COMPUTER') on headling containing :HOME:COMPUTER: would return True. """ - return srch in self.tags + return tag in self._tags - def setTags(self, newtags): - """ - Store all the tags found in the headline. 
- """ - self.tags = set(newtags) - - def Todo(self): + @property + def todo(self): """ Return the value of the TODO tag """ - return self.todo + return self._todo - def setTodo(self, value): + @todo.setter + def todo(self, new_todo): """ Set the value of the TODO tag to the supplied string """ - self.todo = value + self._todo = new_todo - def setProperties(self, dictval): + @property + def properties(self): + """ + Return the dictionary of properties + """ + return self._properties + + @properties.setter + def properties(self, new_properties): """ Sets all properties using the supplied dictionary of name/value pairs """ - self.properties = dictval + self._properties = new_properties - def Property(self, keyval): + def Property(self, property_key): """ Returns the value of the requested property or null if the property does not exist. """ - return self.properties.get(keyval, "") + return self._properties.get(property_key, "") - def setScheduled(self, dateval): + @property + def scheduled(self): """ - Set the scheduled date using the supplied date object + Return the scheduled date """ - self.scheduled = dateval + return self._scheduled - def Scheduled(self): + @scheduled.setter + def scheduled(self, new_scheduled): """ - Return the scheduled date object or null if nonexistent + Set the scheduled date to the scheduled date """ - return self.scheduled + self._scheduled = new_scheduled - def setDeadline(self, dateval): + @property + def deadline(self): """ - Set the deadline (due) date using the supplied date object + Return the deadline date """ - self.deadline = dateval + return self._deadline - def Deadline(self): + @deadline.setter + def deadline(self, new_deadline): """ - Return the deadline date object or null if nonexistent + Set the deadline (due) date to the new deadline date """ - return self.deadline + self._deadline = new_deadline - def setClosed(self, dateval): + @property + def closed(self): """ - Set the closed date using the supplied date object + Return the closed date """ - self.closed = dateval + return self._closed - def Closed(self): + @closed.setter + def closed(self, new_closed): """ - Return the closed date object or null if nonexistent + Set the closed date to the new closed date """ - return self.closed + self._closed = new_closed - def setLogbook(self, logbook): - """ - Set the logbook with list of clocked-in, clocked-out tuples for the entry - """ - self.logbook = logbook - - def Logbook(self): + @property + def logbook(self): """ Return the logbook with all clocked-in, clocked-out date object pairs or empty list if nonexistent """ - return self.logbook + return self._logbook + + @logbook.setter + def logbook(self, new_logbook): + """ + Set the logbook with list of clocked-in, clocked-out tuples for the entry + """ + self._logbook = new_logbook def __repr__(self): """ Print the level, heading text and tag of a node and the body text as used to construct the node. """ - # This method is not completed yet. 
+ # Output heading line n = '' - for _ in range(0, self.level): + for _ in range(0, self._level): n = n + '*' n = n + ' ' - if self.todo: - n = n + self.todo + ' ' - if self.prty: - n = n + '[#' + self.prty + '] ' - n = n + self.headline + if self._todo: + n = n + self._todo + ' ' + if self._priority: + n = n + '[#' + self._priority + '] ' + n = n + self._heading n = "%-60s " % n # hack - tags will start in column 62 closecolon = '' - for t in self.tags: + for t in self._tags: n = n + ':' + t closecolon = ':' n = n + closecolon n = n + "\n" # Get body indentation from first line of body - indent = indent_regex.match(self.body).group() + indent = indent_regex.match(self._body).group() # Output Closed Date, Scheduled Date, Deadline Date - if self.closed or self.scheduled or self.deadline: + if self._closed or self._scheduled or self._deadline: n = n + indent - if self.closed: - n = n + f'CLOSED: [{self.closed.strftime("%Y-%m-%d %a")}] ' - if self.scheduled: - n = n + f'SCHEDULED: <{self.scheduled.strftime("%Y-%m-%d %a")}> ' - if self.deadline: - n = n + f'DEADLINE: <{self.deadline.strftime("%Y-%m-%d %a")}> ' - if self.closed or self.scheduled or self.deadline: + if self._closed: + n = n + f'CLOSED: [{self._closed.strftime("%Y-%m-%d %a")}] ' + if self._scheduled: + n = n + f'SCHEDULED: <{self._scheduled.strftime("%Y-%m-%d %a")}> ' + if self._deadline: + n = n + f'DEADLINE: <{self._deadline.strftime("%Y-%m-%d %a")}> ' + if self._closed or self._scheduled or self._deadline: n = n + '\n' # Ouput Property Drawer n = n + indent + ":PROPERTIES:\n" - for key, value in self.properties.items(): + for key, value in self._properties.items(): n = n + indent + f":{key}: {value}\n" n = n + indent + ":END:\n" - n = n + self.body + # Output Body + if self.hasBody: + n = n + self._body return n diff --git a/src/router.py b/src/router.py index f768865b..c3e65e17 100644 --- a/src/router.py +++ b/src/router.py @@ -2,8 +2,8 @@ import yaml import json import time +import logging from typing import Optional -from functools import lru_cache # External Packages from fastapi import APIRouter @@ -15,16 +15,17 @@ from fastapi.templating import Jinja2Templates from src.configure import configure_search from src.search_type import image_search, text_search from src.processor.conversation.gpt import converse, extract_search_type, message_to_log, message_to_prompt, understand, summarize -from src.search_filter.explicit_filter import ExplicitFilter -from src.search_filter.date_filter import DateFilter from src.utils.rawconfig import FullConfig from src.utils.config import SearchType -from src.utils.helpers import get_absolute_path, get_from_dict +from src.utils.helpers import LRU, get_absolute_path, get_from_dict from src.utils import state, constants -router = APIRouter() +router = APIRouter() templates = Jinja2Templates(directory=constants.web_directory) +logger = logging.getLogger(__name__) +query_cache = LRU() + @router.get("/", response_class=FileResponse) def index(): @@ -47,22 +48,27 @@ async def config_data(updated_config: FullConfig): return state.config @router.get('/search') -@lru_cache(maxsize=100) def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Optional[bool] = False): if q is None or q == '': - print(f'No query param (q) passed in API call to initiate search') + logger.info(f'No query param (q) passed in API call to initiate search') return {} # initialize variables - user_query = q + user_query = q.strip() results_count = n results = {} query_start, query_end, collate_start, 
collate_end = None, None, None, None + # return cached results, if available + query_cache_key = f'{user_query}-{n}-{t}-{r}' + if query_cache_key in state.query_cache: + logger.info(f'Return response from query cache') + return state.query_cache[query_cache_key] + if (t == SearchType.Org or t == None) and state.model.orgmode_search: # query org-mode notes query_start = time.time() - hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r, filters=[DateFilter(), ExplicitFilter()], verbose=state.verbose) + hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r) query_end = time.time() # collate and return results @@ -73,7 +79,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti if (t == SearchType.Music or t == None) and state.model.music_search: # query music library query_start = time.time() - hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r, filters=[DateFilter(), ExplicitFilter()], verbose=state.verbose) + hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r) query_end = time.time() # collate and return results @@ -84,7 +90,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti if (t == SearchType.Markdown or t == None) and state.model.markdown_search: # query markdown files query_start = time.time() - hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r, filters=[ExplicitFilter(), DateFilter()], verbose=state.verbose) + hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r) query_end = time.time() # collate and return results @@ -95,7 +101,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti if (t == SearchType.Ledger or t == None) and state.model.ledger_search: # query transactions query_start = time.time() - hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r, filters=[ExplicitFilter(), DateFilter()], verbose=state.verbose) + hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r) query_end = time.time() # collate and return results @@ -131,11 +137,13 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti count=results_count) collate_end = time.time() - if state.verbose > 1: - if query_start and query_end: - print(f"Query took {query_end - query_start:.3f} seconds") - if collate_start and collate_end: - print(f"Collating results took {collate_end - collate_start:.3f} seconds") + # Cache results + state.query_cache[query_cache_key] = results + + if query_start and query_end: + logger.debug(f"Query took {query_end - query_start:.3f} seconds") + if collate_start and collate_end: + logger.debug(f"Collating results took {collate_end - collate_start:.3f} seconds") return results diff --git a/src/search_filter/base_filter.py b/src/search_filter/base_filter.py new file mode 100644 index 00000000..2550b32e --- /dev/null +++ b/src/search_filter/base_filter.py @@ -0,0 +1,16 @@ +# Standard Packages +from abc import ABC, abstractmethod + + +class BaseFilter(ABC): + @abstractmethod + def load(self, *args, **kwargs): + pass + + @abstractmethod + def can_filter(self, raw_query:str) -> bool: + pass + + @abstractmethod + def apply(self, query:str, raw_entries:list[str]) -> tuple[str, set[int]]: + pass \ No newline at end of file diff --git a/src/search_filter/date_filter.py 
b/src/search_filter/date_filter.py index dc70ca29..22a66068 100644 --- a/src/search_filter/date_filter.py +++ b/src/search_filter/date_filter.py @@ -1,62 +1,98 @@ # Standard Packages import re +import time +import logging +from collections import defaultdict from datetime import timedelta, datetime -from dateutil.relativedelta import relativedelta, MO +from dateutil.relativedelta import relativedelta from math import inf # External Packages -import torch import dateparser as dtparse +# Internal Packages +from src.search_filter.base_filter import BaseFilter +from src.utils.helpers import LRU -class DateFilter: + +logger = logging.getLogger(__name__) + + +class DateFilter(BaseFilter): # Date Range Filter Regexes # Example filter queries: - # - dt>="yesterday" dt<"tomorrow" - # - dt>="last week" - # - dt:"2 years ago" + # - dt>="yesterday" dt<"tomorrow" + # - dt>="last week" + # - dt:"2 years ago" date_regex = r"dt([:><=]{1,2})\"(.*?)\"" + + def __init__(self, entry_key='raw'): + self.entry_key = entry_key + self.date_to_entry_ids = defaultdict(set) + self.cache = LRU() + + + def load(self, entries, **_): + start = time.time() + for id, entry in enumerate(entries): + # Extract dates from entry + for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[self.entry_key]): + # Convert date string in entry to unix timestamp + try: + date_in_entry = datetime.strptime(date_in_entry_string, '%Y-%m-%d').timestamp() + except ValueError: + continue + self.date_to_entry_ids[date_in_entry].add(id) + end = time.time() + logger.debug(f"Created date filter index: {end - start} seconds") + + def can_filter(self, raw_query): "Check if query contains date filters" return self.extract_date_range(raw_query) is not None - def filter(self, query, entries, embeddings, entry_key='raw'): + def apply(self, query, raw_entries): "Find entries containing any dates that fall within date range specified in query" # extract date range specified in date filter of query + start = time.time() query_daterange = self.extract_date_range(query) + end = time.time() + logger.debug(f"Extract date range to filter from query: {end - start} seconds") # if no date in query, return all entries if query_daterange is None: - return query, entries, embeddings + return query, set(range(len(raw_entries))) # remove date range filter from query query = re.sub(rf'\s+{self.date_regex}', ' ', query) query = re.sub(r'\s{2,}', ' ', query).strip() # remove multiple spaces + # return results from cache if exists + cache_key = tuple(query_daterange) + if cache_key in self.cache: + logger.info(f"Return date filter results from cache") + entries_to_include = self.cache[cache_key] + return query, entries_to_include + + if not self.date_to_entry_ids: + self.load(raw_entries) + # find entries containing any dates that fall with date range specified in query + start = time.time() entries_to_include = set() - for id, entry in enumerate(entries): - # Extract dates from entry - for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[entry_key]): - # Convert date string in entry to unix timestamp - try: - date_in_entry = datetime.strptime(date_in_entry_string, '%Y-%m-%d').timestamp() - except ValueError: - continue - # Check if date in entry is within date range specified in query - if query_daterange[0] <= date_in_entry < query_daterange[1]: - entries_to_include.add(id) - break + for date_in_entry in self.date_to_entry_ids.keys(): + # Check if date in entry is within date range specified in query + if query_daterange[0] <= date_in_entry < 
query_daterange[1]: + entries_to_include |= self.date_to_entry_ids[date_in_entry] + end = time.time() + logger.debug(f"Mark entries satisfying filter: {end - start} seconds") - # delete entries (and their embeddings) marked for exclusion - entries_to_exclude = set(range(len(entries))) - entries_to_include - for id in sorted(list(entries_to_exclude), reverse=True): - del entries[id] - embeddings = torch.cat((embeddings[:id], embeddings[id+1:])) + # cache results + self.cache[cache_key] = entries_to_include - return query, entries, embeddings + return query, entries_to_include def extract_date_range(self, query): diff --git a/src/search_filter/explicit_filter.py b/src/search_filter/explicit_filter.py deleted file mode 100644 index b7bb6754..00000000 --- a/src/search_filter/explicit_filter.py +++ /dev/null @@ -1,57 +0,0 @@ -# Standard Packages -import re - -# External Packages -import torch - - -class ExplicitFilter: - def can_filter(self, raw_query): - "Check if query contains explicit filters" - # Extract explicit query portion with required, blocked words to filter from natural query - required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")]) - blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")]) - - return len(required_words) != 0 or len(blocked_words) != 0 - - - def filter(self, raw_query, entries, embeddings, entry_key='raw'): - "Find entries containing required and not blocked words specified in query" - # Separate natural query from explicit required, blocked words filters - query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")]) - required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")]) - blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")]) - - if len(required_words) == 0 and len(blocked_words) == 0: - return query, entries, embeddings - - # convert each entry to a set of words - # split on fullstop, comma, colon, tab, newline or any brackets - entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\t|\n|\:' - entries_by_word_set = [set(word.lower() - for word - in re.split(entry_splitter, entry[entry_key]) - if word != "") - for entry in entries] - - # track id of entries to exclude - entries_to_exclude = set() - - # mark entries that do not contain all required_words for exclusion - if len(required_words) > 0: - for id, words_in_entry in enumerate(entries_by_word_set): - if not required_words.issubset(words_in_entry): - entries_to_exclude.add(id) - - # mark entries that contain any blocked_words for exclusion - if len(blocked_words) > 0: - for id, words_in_entry in enumerate(entries_by_word_set): - if words_in_entry.intersection(blocked_words): - entries_to_exclude.add(id) - - # delete entries (and their embeddings) marked for exclusion - for id in sorted(list(entries_to_exclude), reverse=True): - del entries[id] - embeddings = torch.cat((embeddings[:id], embeddings[id+1:])) - - return query, entries, embeddings diff --git a/src/search_filter/file_filter.py b/src/search_filter/file_filter.py new file mode 100644 index 00000000..41f80274 --- /dev/null +++ b/src/search_filter/file_filter.py @@ -0,0 +1,79 @@ +# Standard Packages +import re +import fnmatch +import time +import logging +from collections import defaultdict + +# Internal Packages +from src.search_filter.base_filter import BaseFilter +from src.utils.helpers import LRU + + +logger = logging.getLogger(__name__) + + +class 
FileFilter(BaseFilter): + file_filter_regex = r'file:"(.+?)" ?' + + def __init__(self, entry_key='file'): + self.entry_key = entry_key + self.file_to_entry_map = defaultdict(set) + self.cache = LRU() + + def load(self, entries, *args, **kwargs): + start = time.time() + for id, entry in enumerate(entries): + self.file_to_entry_map[entry[self.entry_key]].add(id) + end = time.time() + logger.debug(f"Created file filter index: {end - start} seconds") + + def can_filter(self, raw_query): + return re.search(self.file_filter_regex, raw_query) is not None + + def apply(self, raw_query, raw_entries): + # Extract file filters from raw query + start = time.time() + raw_files_to_search = re.findall(self.file_filter_regex, raw_query) + if not raw_files_to_search: + return raw_query, set(range(len(raw_entries))) + + # Convert simple file filters with no path separator or wildcard into fnmatch glob patterns + # e.g. file:"notes.org" -> "*notes.org" + files_to_search = [] + for file in sorted(raw_files_to_search): + if '/' not in file and '\\' not in file and '*' not in file: + files_to_search += [f'*{file}'] + else: + files_to_search += [file] + end = time.time() + logger.debug(f"Extract files_to_search from query: {end - start} seconds") + + # Return item from cache if exists + query = re.sub(self.file_filter_regex, '', raw_query).strip() + cache_key = tuple(files_to_search) + if cache_key in self.cache: + logger.info(f"Return file filter results from cache") + included_entry_indices = self.cache[cache_key] + return query, included_entry_indices + + if not self.file_to_entry_map: + self.load(raw_entries, regenerate=False) + + # Mark entries whose file matches any of the specified file filters for inclusion + start = time.time() + + included_entry_indices = set.union(*[self.file_to_entry_map[entry_file] + for entry_file in self.file_to_entry_map.keys() + for search_file in files_to_search + if fnmatch.fnmatch(entry_file, search_file)], set()) + if not included_entry_indices: + return query, set() + + end = time.time() + logger.debug(f"Mark entries satisfying filter: {end - start} seconds") + + # Cache results + self.cache[cache_key] = included_entry_indices + + return query, included_entry_indices diff --git a/src/search_filter/word_filter.py b/src/search_filter/word_filter.py new file mode 100644 index 00000000..e040ceee --- /dev/null +++ b/src/search_filter/word_filter.py @@ -0,0 +1,96 @@ +# Standard Packages +import re +import time +import logging +from collections import defaultdict + +# Internal Packages +from src.search_filter.base_filter import BaseFilter +from src.utils.helpers import LRU + + +logger = logging.getLogger(__name__) + + +class WordFilter(BaseFilter): + # Filter Regex + required_regex = r'\+"([a-zA-Z0-9_-]+)" ?' + blocked_regex = r'\-"([a-zA-Z0-9_-]+)" ?'
+ + def __init__(self, entry_key='raw'): + self.entry_key = entry_key + self.word_to_entry_index = defaultdict(set) + self.cache = LRU() + + + def load(self, entries, regenerate=False): + start = time.time() + self.cache = {} # Clear cache on filter (re-)load + entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\<|\>|\t|\n|\:|\;|\?|\!|\(|\)|\&|\^|\$|\@|\%|\+|\=|\/|\\|\||\~|\`|\"|\'' + # Create map of words to entries they exist in + for entry_index, entry in enumerate(entries): + for word in re.split(entry_splitter, entry[self.entry_key].lower()): + if word == '': + continue + self.word_to_entry_index[word].add(entry_index) + end = time.time() + logger.debug(f"Created word filter index: {end - start} seconds") + + return self.word_to_entry_index + + + def can_filter(self, raw_query): + "Check if query contains word filters" + required_words = re.findall(self.required_regex, raw_query) + blocked_words = re.findall(self.blocked_regex, raw_query) + + return len(required_words) != 0 or len(blocked_words) != 0 + + + def apply(self, raw_query, raw_entries): + "Find entries containing required and not blocked words specified in query" + # Separate natural query from required, blocked words filters + start = time.time() + + required_words = set([word.lower() for word in re.findall(self.required_regex, raw_query)]) + blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, raw_query)]) + query = re.sub(self.blocked_regex, '', re.sub(self.required_regex, '', raw_query)).strip() + + end = time.time() + logger.debug(f"Extract required, blocked filters from query: {end - start} seconds") + + if len(required_words) == 0 and len(blocked_words) == 0: + return query, set(range(len(raw_entries))) + + # Return item from cache if exists + cache_key = tuple(sorted(required_words)), tuple(sorted(blocked_words)) + if cache_key in self.cache: + logger.info(f"Return word filter results from cache") + included_entry_indices = self.cache[cache_key] + return query, included_entry_indices + + if not self.word_to_entry_index: + self.load(raw_entries, regenerate=False) + + start = time.time() + + # mark entries that contain all required_words for inclusion + entries_with_all_required_words = set(range(len(raw_entries))) + if len(required_words) > 0: + entries_with_all_required_words = set.intersection(*[self.word_to_entry_index.get(word, set()) for word in required_words]) + + # mark entries that contain any blocked_words for exclusion + entries_with_any_blocked_words = set() + if len(blocked_words) > 0: + entries_with_any_blocked_words = set.union(*[self.word_to_entry_index.get(word, set()) for word in blocked_words]) + + end = time.time() + logger.debug(f"Mark entries satisfying filter: {end - start} seconds") + + # get entries satisfying inclusion and exclusion filters + included_entry_indices = entries_with_all_required_words - entries_with_any_blocked_words + + # Cache results + self.cache[cache_key] = included_entry_indices + + return query, included_entry_indices diff --git a/src/search_type/image_search.py b/src/search_type/image_search.py index b57d4f20..c9dcdd6b 100644 --- a/src/search_type/image_search.py +++ b/src/search_type/image_search.py @@ -1,9 +1,10 @@ # Standard Packages -import argparse import glob import pathlib import copy import shutil +import time +import logging # External Packages from sentence_transformers import SentenceTransformer, util @@ -18,6 +19,10 @@ from src.utils.config import ImageSearchModel from src.utils.rawconfig import ImageContentConfig, ImageSearchConfig +# Create 
Logger +logger = logging.getLogger(__name__) + + def initialize_model(search_config: ImageSearchConfig): # Initialize Model torch.set_num_threads(4) @@ -37,41 +42,43 @@ def initialize_model(search_config: ImageSearchConfig): return encoder -def extract_entries(image_directories, verbose=0): +def extract_entries(image_directories): image_names = [] for image_directory in image_directories: image_directory = resolve_absolute_path(image_directory, strict=True) image_names.extend(list(image_directory.glob('*.jpg'))) image_names.extend(list(image_directory.glob('*.jpeg'))) - if verbose > 0: + if logger.level >= logging.INFO: image_directory_names = ', '.join([str(image_directory) for image_directory in image_directories]) - print(f'Found {len(image_names)} images in {image_directory_names}') + logger.info(f'Found {len(image_names)} images in {image_directory_names}') return sorted(image_names) -def compute_embeddings(image_names, encoder, embeddings_file, batch_size=50, use_xmp_metadata=False, regenerate=False, verbose=0): +def compute_embeddings(image_names, encoder, embeddings_file, batch_size=50, use_xmp_metadata=False, regenerate=False): "Compute (and Save) Embeddings or Load Pre-Computed Embeddings" - image_embeddings = compute_image_embeddings(image_names, encoder, embeddings_file, batch_size, regenerate, verbose) - image_metadata_embeddings = compute_metadata_embeddings(image_names, encoder, embeddings_file, batch_size, use_xmp_metadata, regenerate, verbose) + image_embeddings = compute_image_embeddings(image_names, encoder, embeddings_file, batch_size, regenerate) + image_metadata_embeddings = compute_metadata_embeddings(image_names, encoder, embeddings_file, batch_size, use_xmp_metadata, regenerate) return image_embeddings, image_metadata_embeddings -def compute_image_embeddings(image_names, encoder, embeddings_file, batch_size=50, regenerate=False, verbose=0): - image_embeddings = None - +def compute_image_embeddings(image_names, encoder, embeddings_file, batch_size=50, regenerate=False): # Load pre-computed image embeddings from file if exists if resolve_absolute_path(embeddings_file).exists() and not regenerate: image_embeddings = torch.load(embeddings_file) - if verbose > 0: - print(f"Loaded pre-computed embeddings from {embeddings_file}") + logger.info(f"Loaded {len(image_embeddings)} image embeddings from {embeddings_file}") # Else compute the image embeddings from scratch, which can take a while - elif image_embeddings is None: + else: image_embeddings = [] for index in trange(0, len(image_names), batch_size): - images = [Image.open(image_name) for image_name in image_names[index:index+batch_size]] + images = [] + for image_name in image_names[index:index+batch_size]: + image = Image.open(image_name) + # Resize images to max width of 640px for faster processing + image.thumbnail((640, image.height)) + images += [image] image_embeddings += encoder.encode( images, convert_to_tensor=True, @@ -82,8 +89,7 @@ def compute_image_embeddings(image_names, encoder, embeddings_file, batch_size=5 # Save computed image embeddings to file torch.save(image_embeddings, embeddings_file) - if verbose > 0: - print(f"Saved computed embeddings to {embeddings_file}") + logger.info(f"Saved computed embeddings to {embeddings_file}") return image_embeddings @@ -94,8 +100,7 @@ def compute_metadata_embeddings(image_names, encoder, embeddings_file, batch_siz # Load pre-computed image metadata embedding file if exists if use_xmp_metadata and resolve_absolute_path(f"{embeddings_file}_metadata").exists() and 
not regenerate: image_metadata_embeddings = torch.load(f"{embeddings_file}_metadata") - if verbose > 0: - print(f"Loaded pre-computed embeddings from {embeddings_file}_metadata") + logger.info(f"Loaded pre-computed embeddings from {embeddings_file}_metadata") # Else compute the image metadata embeddings from scratch, which can take a while if use_xmp_metadata and image_metadata_embeddings is None: @@ -108,16 +113,15 @@ def compute_metadata_embeddings(image_names, encoder, embeddings_file, batch_siz convert_to_tensor=True, batch_size=min(len(image_metadata), batch_size)) except RuntimeError as e: - print(f"Error encoding metadata for images starting from\n\tindex: {index},\n\timages: {image_names[index:index+batch_size]}\nException: {e}") + logger.error(f"Error encoding metadata for images starting from\n\tindex: {index},\n\timages: {image_names[index:index+batch_size]}\nException: {e}") continue torch.save(image_metadata_embeddings, f"{embeddings_file}_metadata") - if verbose > 0: - print(f"Saved computed metadata embeddings to {embeddings_file}_metadata") + logger.info(f"Saved computed metadata embeddings to {embeddings_file}_metadata") return image_metadata_embeddings -def extract_metadata(image_name, verbose=0): +def extract_metadata(image_name): with exiftool.ExifTool() as et: image_metadata = et.get_tags(["XMP:Subject", "XMP:Description"], str(image_name)) image_metadata_subjects = set([subject.split(":")[1] for subject in image_metadata.get("XMP:Subject", "") if ":" in subject]) @@ -126,8 +130,7 @@ def extract_metadata(image_name, verbose=0): if len(image_metadata_subjects) > 0: image_processed_metadata += ". " + ", ".join(image_metadata_subjects) - if verbose > 2: - print(f"{image_name}:\t{image_processed_metadata}") + logger.debug(f"{image_name}:\t{image_processed_metadata}") return image_processed_metadata @@ -137,26 +140,34 @@ def query(raw_query, count, model: ImageSearchModel): if pathlib.Path(raw_query).is_file(): query_imagepath = resolve_absolute_path(pathlib.Path(raw_query), strict=True) query = copy.deepcopy(Image.open(query_imagepath)) - if model.verbose > 0: - print(f"Find Images similar to Image at {query_imagepath}") + query.thumbnail((640, query.height)) # scale down image for faster processing + logger.info(f"Find Images similar to Image at {query_imagepath}") else: query = raw_query - if model.verbose > 0: - print(f"Find Images by Text: {query}") + logger.info(f"Find Images by Text: {query}") # Now we encode the query (which can either be an image or a text string) + start = time.time() query_embedding = model.image_encoder.encode([query], convert_to_tensor=True, show_progress_bar=False) + end = time.time() + logger.debug(f"Query Encode Time: {end - start:.3f} seconds") # Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings. + start = time.time() image_hits = {result['corpus_id']: {'image_score': result['score'], 'score': result['score']} for result in util.semantic_search(query_embedding, model.image_embeddings, top_k=count)[0]} + end = time.time() + logger.debug(f"Search Time: {end - start:.3f} seconds") # Compute top_k ranked images based on cosine-similarity b/w query and all image metadata embeddings. 
if model.image_metadata_embeddings: + start = time.time() metadata_hits = {result['corpus_id']: result['score'] for result in util.semantic_search(query_embedding, model.image_metadata_embeddings, top_k=count)[0]} + end = time.time() + logger.debug(f"Metadata Search Time: {end - start:.3f} seconds") # Sum metadata, image scores of the highest ranked images for corpus_id, score in metadata_hits.items(): @@ -219,7 +230,7 @@ def collate_results(hits, image_names, output_directory, image_files_url, count= return results -def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenerate: bool, verbose: bool=False) -> ImageSearchModel: +def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenerate: bool) -> ImageSearchModel: # Initialize Model encoder = initialize_model(search_config) @@ -227,9 +238,13 @@ def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenera absolute_image_files, filtered_image_files = set(), set() if config.input_directories: image_directories = [resolve_absolute_path(directory, strict=True) for directory in config.input_directories] - absolute_image_files = set(extract_entries(image_directories, verbose)) + absolute_image_files = set(extract_entries(image_directories)) if config.input_filter: - filtered_image_files = set(glob.glob(get_absolute_path(config.input_filter))) + filtered_image_files = { + filtered_file + for input_filter in config.input_filter + for filtered_file in glob.glob(get_absolute_path(input_filter)) + } all_image_files = sorted(list(absolute_image_files | filtered_image_files)) @@ -241,38 +256,9 @@ def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenera embeddings_file, batch_size=config.batch_size, regenerate=regenerate, - use_xmp_metadata=config.use_xmp_metadata, - verbose=verbose) + use_xmp_metadata=config.use_xmp_metadata) return ImageSearchModel(all_image_files, image_embeddings, image_metadata_embeddings, - encoder, - verbose) - - -if __name__ == '__main__': - # Setup Argument Parser - parser = argparse.ArgumentParser(description="Semantic Search on Images") - parser.add_argument('--image-directory', '-i', required=True, type=pathlib.Path, help="Image directory to query") - parser.add_argument('--embeddings-file', '-e', default='image_embeddings.pt', type=pathlib.Path, help="File to save/load model embeddings to/from. Default: ./embeddings.pt") - parser.add_argument('--regenerate', action='store_true', default=False, help="Regenerate embeddings of Images in Image Directory . Default: false") - parser.add_argument('--results-count', '-n', default=5, type=int, help="Number of results to render. Default: 5") - parser.add_argument('--interactive', action='store_true', default=False, help="Interactive mode allows user to run queries on the model. Default: true") - parser.add_argument('--verbose', action='count', default=0, help="Show verbose conversion logs. 
Default: 0") - args = parser.parse_args() - - image_names, image_embeddings, image_metadata_embeddings, model = setup(args.image_directory, args.embeddings_file, regenerate=args.regenerate) - - # Run User Queries on Entries in Interactive Mode - while args.interactive: - # get query from user - user_query = input("Enter your query: ") - if user_query == "exit": - exit(0) - - # query images - hits = query(user_query, image_embeddings, image_metadata_embeddings, model, args.results_count, args.verbose) - - # render results - render_results(hits, image_names, args.image_directory, count=args.results_count) + encoder) diff --git a/src/search_type/text_search.py b/src/search_type/text_search.py index 2b47eabe..d4d8a9d4 100644 --- a/src/search_type/text_search.py +++ b/src/search_type/text_search.py @@ -1,21 +1,23 @@ # Standard Packages -import argparse -import pathlib -from copy import deepcopy +import logging import time # External Packages import torch from sentence_transformers import SentenceTransformer, CrossEncoder, util +from src.search_filter.base_filter import BaseFilter # Internal Packages from src.utils import state -from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_model +from src.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model from src.utils.config import TextSearchModel from src.utils.rawconfig import TextSearchConfig, TextContentConfig from src.utils.jsonl import load_jsonl +logger = logging.getLogger(__name__) + + def initialize_model(search_config: TextSearchConfig): "Initialize model for semantic search on text" torch.set_num_threads(4) @@ -46,73 +48,85 @@ def initialize_model(search_config: TextSearchConfig): return bi_encoder, cross_encoder, top_k -def extract_entries(jsonl_file, verbose=0): +def extract_entries(jsonl_file): "Load entries from compressed jsonl" - return [{'compiled': f'{entry["compiled"]}', 'raw': f'{entry["raw"]}'} - for entry - in load_jsonl(jsonl_file, verbose=verbose)] + return load_jsonl(jsonl_file) -def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, verbose=0): +def compute_embeddings(entries_with_ids, bi_encoder, embeddings_file, regenerate=False): "Compute (and Save) Embeddings or Load Pre-Computed Embeddings" - # Load pre-computed embeddings from file if exists + new_entries = [] + # Load pre-computed embeddings from file if exists and update them if required if embeddings_file.exists() and not regenerate: corpus_embeddings = torch.load(get_absolute_path(embeddings_file), map_location=state.device) - if verbose > 0: - print(f"Loaded embeddings from {embeddings_file}") + logger.info(f"Loaded embeddings from {embeddings_file}") - else: # Else compute the corpus_embeddings from scratch, which can take a while - corpus_embeddings = bi_encoder.encode([entry['compiled'] for entry in entries], convert_to_tensor=True, device=state.device, show_progress_bar=True) + # Encode any new entries in the corpus and update corpus embeddings + new_entries = [entry['compiled'] for id, entry in entries_with_ids if id is None] + if new_entries: + new_embeddings = bi_encoder.encode(new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True) + existing_entry_ids = [id for id, _ in entries_with_ids if id is not None] + existing_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(existing_entry_ids)) if existing_entry_ids else torch.Tensor() + corpus_embeddings = torch.cat([existing_embeddings, new_embeddings], dim=0) + # Else compute the 
corpus embeddings from scratch + else: + new_entries = [entry['compiled'] for _, entry in entries_with_ids] + corpus_embeddings = bi_encoder.encode(new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True) + + # Save regenerated or updated embeddings to file + if new_entries: corpus_embeddings = util.normalize_embeddings(corpus_embeddings) torch.save(corpus_embeddings, embeddings_file) - if verbose > 0: - print(f"Computed embeddings and saved them to {embeddings_file}") + logger.info(f"Computed embeddings and saved them to {embeddings_file}") return corpus_embeddings -def query(raw_query: str, model: TextSearchModel, rank_results=False, filters: list = [], verbose=0): +def query(raw_query: str, model: TextSearchModel, rank_results=False): "Search for entries that answer the query" - query = raw_query - - # Use deep copy of original embeddings, entries to filter if query contains filters - start = time.time() - filters_in_query = [filter for filter in filters if filter.can_filter(query)] - if filters_in_query: - corpus_embeddings = deepcopy(model.corpus_embeddings) - entries = deepcopy(model.entries) - else: - corpus_embeddings = model.corpus_embeddings - entries = model.entries - end = time.time() - if verbose > 1: - print(f"Copy Time: {end - start:.3f} seconds") + query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings # Filter query, entries and embeddings before semantic search - start = time.time() + start_filter = time.time() + included_entry_indices = set(range(len(entries))) + filters_in_query = [filter for filter in model.filters if filter.can_filter(query)] for filter in filters_in_query: - query, entries, corpus_embeddings = filter.filter(query, entries, corpus_embeddings) - end = time.time() - if verbose > 1: - print(f"Filter Time: {end - start:.3f} seconds") + query, included_entry_indices_by_filter = filter.apply(query, entries) + included_entry_indices.intersection_update(included_entry_indices_by_filter) + + # Get entries (and associated embeddings) satisfying all filters + if not included_entry_indices: + return [], [] + else: + start = time.time() + entries = [entries[id] for id in included_entry_indices] + corpus_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(list(included_entry_indices))) + end = time.time() + logger.debug(f"Keep entries satisfying all filters: {end - start} seconds") + + end_filter = time.time() + logger.debug(f"Total Filter Time: {end_filter - start_filter:.3f} seconds") if entries is None or len(entries) == 0: return [], [] + # If query only had filters it'll be empty now. So short-circuit and return results. 
+ if query.strip() == "": + hits = [{"corpus_id": id, "score": 1.0} for id, _ in enumerate(entries)] + return hits, entries + # Encode the query using the bi-encoder start = time.time() question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device) question_embedding = util.normalize_embeddings(question_embedding) end = time.time() - if verbose > 1: - print(f"Query Encode Time: {end - start:.3f} seconds on device: {state.device}") + logger.debug(f"Query Encode Time: {end - start:.3f} seconds on device: {state.device}") # Find relevant entries for the query start = time.time() hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score)[0] end = time.time() - if verbose > 1: - print(f"Search Time: {end - start:.3f} seconds on device: {state.device}") + logger.debug(f"Search Time: {end - start:.3f} seconds on device: {state.device}") # Score all retrieved entries using the cross-encoder if rank_results: @@ -120,8 +134,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False, filters: l cross_inp = [[query, entries[hit['corpus_id']]['compiled']] for hit in hits] cross_scores = model.cross_encoder.predict(cross_inp) end = time.time() - if verbose > 1: - print(f"Cross-Encoder Predict Time: {end - start:.3f} seconds on device: {state.device}") + logger.debug(f"Cross-Encoder Predict Time: {end - start:.3f} seconds on device: {state.device}") # Store cross-encoder scores in results dictionary for ranking for idx in range(len(cross_scores)): @@ -133,8 +146,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False, filters: l if rank_results: hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score end = time.time() - if verbose > 1: - print(f"Rank Time: {end - start:.3f} seconds on device: {state.device}") + logger.debug(f"Rank Time: {end - start:.3f} seconds on device: {state.device}") return hits, entries @@ -167,50 +179,26 @@ def collate_results(hits, entries, count=5): in hits[0:count]] -def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool, verbose: bool=False) -> TextSearchModel: +def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool, filters: list[BaseFilter] = []) -> TextSearchModel: # Initialize Model bi_encoder, cross_encoder, top_k = initialize_model(search_config) # Map notes in text files to (compressed) JSONL formatted file config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl) - if not config.compressed_jsonl.exists() or regenerate: - text_to_jsonl(config.input_files, config.input_filter, config.compressed_jsonl, verbose) + previous_entries = extract_entries(config.compressed_jsonl) if config.compressed_jsonl.exists() and not regenerate else None + entries_with_indices = text_to_jsonl(config, previous_entries) - # Extract Entries - entries = extract_entries(config.compressed_jsonl, verbose) + # Extract Updated Entries + entries = extract_entries(config.compressed_jsonl) + if is_none_or_empty(entries): + raise ValueError(f"No valid entries found in specified files: {config.input_files} or {config.input_filter}") top_k = min(len(entries), top_k) # top_k hits can't be more than the total entries in corpus # Compute or Load Embeddings config.embeddings_file = resolve_absolute_path(config.embeddings_file) - corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate, 
verbose=verbose) + corpus_embeddings = compute_embeddings(entries_with_indices, bi_encoder, config.embeddings_file, regenerate=regenerate) - return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k, verbose=verbose) + for filter in filters: + filter.load(entries, regenerate=regenerate) - -if __name__ == '__main__': - # Setup Argument Parser - parser = argparse.ArgumentParser(description="Map Text files into (compressed) JSONL format") - parser.add_argument('--input-files', '-i', nargs='*', help="List of Text files to process") - parser.add_argument('--input-filter', type=str, default=None, help="Regex filter for Text files to process") - parser.add_argument('--compressed-jsonl', '-j', type=pathlib.Path, default=pathlib.Path("text.jsonl.gz"), help="Compressed JSONL to compute embeddings from") - parser.add_argument('--embeddings', '-e', type=pathlib.Path, default=pathlib.Path("text_embeddings.pt"), help="File to save/load model embeddings to/from") - parser.add_argument('--regenerate', action='store_true', default=False, help="Regenerate embeddings from text files. Default: false") - parser.add_argument('--results-count', '-n', default=5, type=int, help="Number of results to render. Default: 5") - parser.add_argument('--interactive', action='store_true', default=False, help="Interactive mode allows user to run queries on the model. Default: true") - parser.add_argument('--verbose', action='count', default=0, help="Show verbose conversion logs. Default: 0") - args = parser.parse_args() - - entries, corpus_embeddings, bi_encoder, cross_encoder, top_k = setup(args.input_files, args.input_filter, args.compressed_jsonl, args.embeddings, args.regenerate, args.verbose) - - # Run User Queries on Entries in Interactive Mode - while args.interactive: - # get query from user - user_query = input("Enter your query: ") - if user_query == "exit": - exit(0) - - # query notes - hits = query(user_query, corpus_embeddings, entries, bi_encoder, cross_encoder, top_k) - - # render results - render_results(hits, entries, count=args.results_count) \ No newline at end of file + return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k) diff --git a/src/utils/cli.py b/src/utils/cli.py index 7c2f8ea4..89704d5a 100644 --- a/src/utils/cli.py +++ b/src/utils/cli.py @@ -1,6 +1,7 @@ # Standard Packages import argparse import pathlib +from importlib.metadata import version # Internal Packages from src.utils.helpers import resolve_absolute_path @@ -16,9 +17,15 @@ def cli(args=None): parser.add_argument('--host', type=str, default='127.0.0.1', help="Host address of the server. Default: 127.0.0.1") parser.add_argument('--port', '-p', type=int, default=8000, help="Port of the server. Default: 8000") parser.add_argument('--socket', type=pathlib.Path, help="Path to UNIX socket for server. Use to run server behind reverse proxy. 
Default: /tmp/uvicorn.sock") + parser.add_argument('--version', '-V', action='store_true', help="Print the installed Khoj version and exit") args = parser.parse_args(args) + if args.version: + # Show version of khoj installed and exit + print(version('khoj-assistant')) + exit(0) + # Normalize config_file path to absolute path args.config_file = resolve_absolute_path(args.config_file) diff --git a/src/utils/config.py b/src/utils/config.py index 2bbea692..316a3d64 100644 --- a/src/utils/config.py +++ b/src/utils/config.py @@ -5,6 +5,7 @@ from pathlib import Path # Internal Packages from src.utils.rawconfig import ConversationProcessorConfig +from src.search_filter.base_filter import BaseFilter class SearchType(str, Enum): @@ -21,23 +22,22 @@ class ProcessorType(str, Enum): class TextSearchModel(): - def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, top_k, verbose): + def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, filters: list[BaseFilter], top_k): self.entries = entries self.corpus_embeddings = corpus_embeddings self.bi_encoder = bi_encoder self.cross_encoder = cross_encoder + self.filters = filters self.top_k = top_k - self.verbose = verbose class ImageSearchModel(): - def __init__(self, image_names, image_embeddings, image_metadata_embeddings, image_encoder, verbose): + def __init__(self, image_names, image_embeddings, image_metadata_embeddings, image_encoder): self.image_encoder = image_encoder self.image_names = image_names self.image_embeddings = image_embeddings self.image_metadata_embeddings = image_metadata_embeddings self.image_encoder = image_encoder - self.verbose = verbose @dataclass @@ -51,12 +51,11 @@ class SearchModels(): class ConversationProcessorConfigModel(): - def __init__(self, processor_config: ConversationProcessorConfig, verbose: bool): + def __init__(self, processor_config: ConversationProcessorConfig): self.openai_api_key = processor_config.openai_api_key self.conversation_logfile = Path(processor_config.conversation_logfile) self.chat_session = '' - self.meta_log = [] - self.verbose = verbose + self.meta_log: dict = {} @dataclass diff --git a/src/utils/constants.py b/src/utils/constants.py index 84c3dfbb..e4840134 100644 --- a/src/utils/constants.py +++ b/src/utils/constants.py @@ -2,7 +2,7 @@ from pathlib import Path app_root_directory = Path(__file__).parent.parent.parent web_directory = app_root_directory / 'src/interface/web/' -empty_escape_sequences = r'\n|\r\t ' +empty_escape_sequences = '\n|\r|\t| ' # default app config to use default_config = { @@ -11,7 +11,8 @@ default_config = { 'input-files': None, 'input-filter': None, 'compressed-jsonl': '~/.khoj/content/org/org.jsonl.gz', - 'embeddings-file': '~/.khoj/content/org/org_embeddings.pt' + 'embeddings-file': '~/.khoj/content/org/org_embeddings.pt', + 'index_heading_entries': False }, 'markdown': { 'input-files': None, diff --git a/src/utils/helpers.py b/src/utils/helpers.py index 52ebc330..df1899f9 100644 --- a/src/utils/helpers.py +++ b/src/utils/helpers.py @@ -1,7 +1,11 @@ # Standard Packages -import pathlib +from pathlib import Path import sys +import time +import hashlib from os.path import join +from collections import OrderedDict +from typing import Optional, Union def is_none_or_empty(item): @@ -12,12 +16,12 @@ def to_snake_case_from_dash(item: str): return item.replace('_', '-') -def get_absolute_path(filepath): - return str(pathlib.Path(filepath).expanduser().absolute()) +def get_absolute_path(filepath: Union[str, Path]) -> str: + return 
str(Path(filepath).expanduser().absolute()) -def resolve_absolute_path(filepath, strict=False): - return pathlib.Path(filepath).expanduser().absolute().resolve(strict=strict) +def resolve_absolute_path(filepath: Union[str, Optional[Path]], strict=False) -> Path: + return Path(filepath).expanduser().absolute().resolve(strict=strict) def get_from_dict(dictionary, *args): @@ -60,4 +64,57 @@ def load_model(model_name, model_dir, model_type, device:str=None): def is_pyinstaller_app(): "Returns true if the app is running from Native GUI created by PyInstaller" - return getattr(sys, 'frozen', False) and hasattr(sys, '_MEIPASS') \ No newline at end of file + return getattr(sys, 'frozen', False) and hasattr(sys, '_MEIPASS') + + +class LRU(OrderedDict): + def __init__(self, *args, capacity=128, **kwargs): + self.capacity = capacity + super().__init__(*args, **kwargs) + + def __getitem__(self, key): + value = super().__getitem__(key) + self.move_to_end(key) + return value + + def __setitem__(self, key, value): + super().__setitem__(key, value) + if len(self) > self.capacity: + oldest = next(iter(self)) + del self[oldest] + + +def mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=None): + # Hash all current and previous entries to identify new entries + start = time.time() + current_entry_hashes = list(map(lambda e: hashlib.md5(bytes(e[key], encoding='utf-8')).hexdigest(), current_entries)) + previous_entry_hashes = list(map(lambda e: hashlib.md5(bytes(e[key], encoding='utf-8')).hexdigest(), previous_entries)) + end = time.time() + logger.debug(f"Hash previous, current entries: {end - start} seconds") + + start = time.time() + hash_to_current_entries = dict(zip(current_entry_hashes, current_entries)) + hash_to_previous_entries = dict(zip(previous_entry_hashes, previous_entries)) + + # All entries that did not exist in the previous set are to be added + new_entry_hashes = set(current_entry_hashes) - set(previous_entry_hashes) + # All entries that exist in both current and previous sets are kept + existing_entry_hashes = set(current_entry_hashes) & set(previous_entry_hashes) + + # Mark new entries with no ids for later embeddings generation + new_entries = [ + (None, hash_to_current_entries[entry_hash]) + for entry_hash in new_entry_hashes + ] + # Set id of existing entries to their previous ids to reuse their existing encoded embeddings + existing_entries = [ + (previous_entry_hashes.index(entry_hash), hash_to_previous_entries[entry_hash]) + for entry_hash in existing_entry_hashes + ] + + existing_entries_sorted = sorted(existing_entries, key=lambda e: e[0]) + entries_with_ids = existing_entries_sorted + new_entries + end = time.time() + logger.debug(f"Identify, Mark, Combine new, existing entries: {end - start} seconds") + + return entries_with_ids \ No newline at end of file diff --git a/src/utils/jsonl.py b/src/utils/jsonl.py index 873cdd39..8a034acd 100644 --- a/src/utils/jsonl.py +++ b/src/utils/jsonl.py @@ -1,13 +1,17 @@ # Standard Packages import json import gzip +import logging # Internal Packages from src.utils.constants import empty_escape_sequences from src.utils.helpers import get_absolute_path -def load_jsonl(input_path, verbose=0): +logger = logging.getLogger(__name__) + + +def load_jsonl(input_path): "Read List of JSON objects from JSON line file" # Initialize Variables data = [] @@ -27,13 +31,12 @@ def load_jsonl(input_path, verbose=0): jsonl_file.close() # Log JSONL entries loaded - if verbose > 0: - print(f'Loaded {len(data)} records from {input_path}') 
+ logger.info(f'Loaded {len(data)} records from {input_path}') return data -def dump_jsonl(jsonl_data, output_path, verbose=0): +def dump_jsonl(jsonl_data, output_path): "Write List of JSON objects to JSON line file" # Create output directory, if it doesn't exist output_path.parent.mkdir(parents=True, exist_ok=True) @@ -41,16 +44,14 @@ def dump_jsonl(jsonl_data, output_path, verbose=0): with open(output_path, 'w', encoding='utf-8') as f: f.write(jsonl_data) - if verbose > 0: - print(f'Wrote {len(jsonl_data)} lines to jsonl at {output_path}') + logger.info(f'Wrote jsonl data to {output_path}') -def compress_jsonl_data(jsonl_data, output_path, verbose=0): +def compress_jsonl_data(jsonl_data, output_path): # Create output directory, if it doesn't exist output_path.parent.mkdir(parents=True, exist_ok=True) with gzip.open(output_path, 'wt') as gzip_file: gzip_file.write(jsonl_data) - if verbose > 0: - print(f'Wrote {len(jsonl_data)} lines to gzip compressed jsonl at {output_path}') \ No newline at end of file + logger.info(f'Wrote jsonl data to gzip compressed jsonl at {output_path}') \ No newline at end of file diff --git a/src/utils/rawconfig.py b/src/utils/rawconfig.py index 58d39678..9c19183b 100644 --- a/src/utils/rawconfig.py +++ b/src/utils/rawconfig.py @@ -6,7 +6,7 @@ from typing import List, Optional from pydantic import BaseModel, validator # Internal Packages -from src.utils.helpers import to_snake_case_from_dash +from src.utils.helpers import to_snake_case_from_dash, is_none_or_empty class ConfigBase(BaseModel): class Config: @@ -15,26 +15,27 @@ class ConfigBase(BaseModel): class TextContentConfig(ConfigBase): input_files: Optional[List[Path]] - input_filter: Optional[str] + input_filter: Optional[List[str]] compressed_jsonl: Path embeddings_file: Path + index_heading_entries: Optional[bool] = False @validator('input_filter') def input_filter_or_files_required(cls, input_filter, values, **kwargs): - if input_filter is None and ('input_files' not in values or values["input_files"] is None): + if is_none_or_empty(input_filter) and ('input_files' not in values or values["input_files"] is None): raise ValueError("Either input_filter or input_files required in all content-type. 
section of Khoj config file") return input_filter class ImageContentConfig(ConfigBase): input_directories: Optional[List[Path]] - input_filter: Optional[str] + input_filter: Optional[List[str]] embeddings_file: Path use_xmp_metadata: bool batch_size: int @validator('input_filter') def input_filter_or_directories_required(cls, input_filter, values, **kwargs): - if input_filter is None and ('input_directories' not in values or values["input_directories"] is None): + if is_none_or_empty(input_filter) and ('input_directories' not in values or values["input_directories"] is None): raise ValueError("Either input_filter or input_directories required in all content-type.image section of Khoj config file") return input_filter diff --git a/src/utils/state.py b/src/utils/state.py index b5c082d6..283d2b5a 100644 --- a/src/utils/state.py +++ b/src/utils/state.py @@ -1,22 +1,25 @@ # Standard Packages from packaging import version + # External Packages import torch from pathlib import Path # Internal Packages from src.utils.config import SearchModels, ProcessorConfigModel +from src.utils.helpers import LRU from src.utils.rawconfig import FullConfig # Application Global State config = FullConfig() model = SearchModels() processor_config = ProcessorConfigModel() -config_file: Path = "" +config_file: Path = None verbose: int = 0 host: str = None port: int = None -cli_args = None +cli_args: list[str] = None +query_cache = LRU() if torch.cuda.is_available(): # Use CUDA GPU diff --git a/src/utils/yaml.py b/src/utils/yaml.py index 46ddb788..a70c6f76 100644 --- a/src/utils/yaml.py +++ b/src/utils/yaml.py @@ -5,12 +5,13 @@ from pathlib import Path import yaml # Internal Packages -from src.utils.helpers import get_absolute_path, resolve_absolute_path from src.utils.rawconfig import FullConfig + # Do not emit tags when dumping to YAML yaml.emitter.Emitter.process_tag = lambda self, *args, **kwargs: None + def save_config_to_file(yaml_config: dict, yaml_config_file: Path): "Write config to YML file" # Create output directory, if it doesn't exist diff --git a/tests/conftest.py b/tests/conftest.py index b70deb87..f6c0a7ea 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,78 +1,65 @@ -# Standard Packages +# External Packages import pytest # Internal Packages from src.search_type import image_search, text_search +from src.utils.config import SearchType +from src.utils.helpers import resolve_absolute_path from src.utils.rawconfig import ContentConfig, TextContentConfig, ImageContentConfig, SearchConfig, TextSearchConfig, ImageSearchConfig from src.processor.org_mode.org_to_jsonl import org_to_jsonl -from src.utils import state +from src.search_filter.date_filter import DateFilter +from src.search_filter.word_filter import WordFilter +from src.search_filter.file_filter import FileFilter @pytest.fixture(scope='session') -def search_config(tmp_path_factory): - model_dir = tmp_path_factory.mktemp('data') - +def search_config() -> SearchConfig: + model_dir = resolve_absolute_path('~/.khoj/search') + model_dir.mkdir(parents=True, exist_ok=True) search_config = SearchConfig() search_config.symmetric = TextSearchConfig( encoder = "sentence-transformers/all-MiniLM-L6-v2", cross_encoder = "cross-encoder/ms-marco-MiniLM-L-6-v2", - model_directory = model_dir + model_directory = model_dir / 'symmetric/' ) search_config.asymmetric = TextSearchConfig( encoder = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1", cross_encoder = "cross-encoder/ms-marco-MiniLM-L-6-v2", - model_directory = model_dir + model_directory = 
model_dir / 'asymmetric/' ) search_config.image = ImageSearchConfig( encoder = "sentence-transformers/clip-ViT-B-32", - model_directory = model_dir + model_directory = model_dir / 'image/' ) return search_config @pytest.fixture(scope='session') -def model_dir(search_config): - model_dir = search_config.asymmetric.model_directory +def content_config(tmp_path_factory, search_config: SearchConfig): + content_dir = tmp_path_factory.mktemp('content') # Generate Image Embeddings from Test Images content_config = ContentConfig() content_config.image = ImageContentConfig( input_directories = ['tests/data/images'], - embeddings_file = model_dir.joinpath('image_embeddings.pt'), - batch_size = 10, + embeddings_file = content_dir.joinpath('image_embeddings.pt'), + batch_size = 1, use_xmp_metadata = False) - image_search.setup(content_config.image, search_config.image, regenerate=False, verbose=True) + image_search.setup(content_config.image, search_config.image, regenerate=False) # Generate Notes Embeddings from Test Notes content_config.org = TextContentConfig( input_files = None, - input_filter = 'tests/data/org/*.org', - compressed_jsonl = model_dir.joinpath('notes.jsonl.gz'), - embeddings_file = model_dir.joinpath('note_embeddings.pt')) + input_filter = ['tests/data/org/*.org'], + compressed_jsonl = content_dir.joinpath('notes.jsonl.gz'), + embeddings_file = content_dir.joinpath('note_embeddings.pt')) - text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, verbose=True) + filters = [DateFilter(), WordFilter(), FileFilter()] + text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters) - return model_dir - - -@pytest.fixture(scope='session') -def content_config(model_dir): - content_config = ContentConfig() - content_config.org = TextContentConfig( - input_files = None, - input_filter = 'tests/data/org/*.org', - compressed_jsonl = model_dir.joinpath('notes.jsonl.gz'), - embeddings_file = model_dir.joinpath('note_embeddings.pt')) - - content_config.image = ImageContentConfig( - input_directories = ['tests/data/images'], - embeddings_file = model_dir.joinpath('image_embeddings.pt'), - batch_size = 1, - use_xmp_metadata = False) - - return content_config \ No newline at end of file + return content_config diff --git a/tests/data/config.yml b/tests/data/config.yml index b002b32a..41603972 100644 --- a/tests/data/config.yml +++ b/tests/data/config.yml @@ -1,9 +1,10 @@ content-type: org: input-files: [ "~/first_from_config.org", "~/second_from_config.org" ] - input-filter: "*.org" + input-filter: ["*.org", "~/notes/*.org"] compressed-jsonl: ".notes.json.gz" embeddings-file: ".note_embeddings.pt" + index-header-entries: true search-type: asymmetric: diff --git a/tests/test_beancount_to_jsonl.py b/tests/test_beancount_to_jsonl.py new file mode 100644 index 00000000..51a4dffd --- /dev/null +++ b/tests/test_beancount_to_jsonl.py @@ -0,0 +1,112 @@ +# Standard Packages +import json + +# Internal Packages +from src.processor.ledger.beancount_to_jsonl import extract_beancount_transactions, convert_transactions_to_maps, convert_transaction_maps_to_jsonl, get_beancount_files + + +def test_no_transactions_in_file(tmp_path): + "Handle file with no transactions." 
+ # Arrange + entry = f''' + - Bullet point 1 + - Bullet point 2 + ''' + beancount_file = create_file(tmp_path, entry) + + # Act + # Extract Entries from specified Beancount files + entry_nodes, file_to_entries = extract_beancount_transactions(beancount_files=[beancount_file]) + + # Process Each Entry from All Beancount Files + jsonl_string = convert_transaction_maps_to_jsonl(convert_transactions_to_maps(entry_nodes, file_to_entries)) + jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()] + + # Assert + assert len(jsonl_data) == 0 + + +def test_single_beancount_transaction_to_jsonl(tmp_path): + "Convert transaction from single file to jsonl." + # Arrange + entry = f''' +1984-04-01 * "Payee" "Narration" +Expenses:Test:Test 1.00 KES +Assets:Test:Test -1.00 KES + ''' + beancount_file = create_file(tmp_path, entry) + + # Act + # Extract Entries from specified Beancount files + entries, entry_to_file_map = extract_beancount_transactions(beancount_files=[beancount_file]) + + # Process Each Entry from All Beancount Files + jsonl_string = convert_transaction_maps_to_jsonl(convert_transactions_to_maps(entries, entry_to_file_map)) + jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()] + + # Assert + assert len(jsonl_data) == 1 + + +def test_multiple_transactions_to_jsonl(tmp_path): + "Convert multiple transactions from single file to jsonl." + # Arrange + entry = f''' +1984-04-01 * "Payee" "Narration" +Expenses:Test:Test 1.00 KES +Assets:Test:Test -1.00 KES +\t\r +1984-04-01 * "Payee" "Narration" +Expenses:Test:Test 1.00 KES +Assets:Test:Test -1.00 KES +''' + + beancount_file = create_file(tmp_path, entry) + + # Act + # Extract Entries from specified Beancount files + entries, entry_to_file_map = extract_beancount_transactions(beancount_files=[beancount_file]) + + # Process Each Entry from All Beancount Files + jsonl_string = convert_transaction_maps_to_jsonl(convert_transactions_to_maps(entries, entry_to_file_map)) + jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()] + + # Assert + assert len(jsonl_data) == 2 + + +def test_get_beancount_files(tmp_path): + "Ensure Beancount files specified via input-filter, input-files extracted" + # Arrange + # Include via input-filter globs + group1_file1 = create_file(tmp_path, filename="group1-file1.bean") + group1_file2 = create_file(tmp_path, filename="group1-file2.bean") + group2_file1 = create_file(tmp_path, filename="group2-file1.beancount") + group2_file2 = create_file(tmp_path, filename="group2-file2.beancount") + # Include via input-file field + file1 = create_file(tmp_path, filename="ledger.bean") + # Not included by any filter + create_file(tmp_path, filename="not-included-ledger.bean") + create_file(tmp_path, filename="not-included-text.txt") + + expected_files = sorted(map(str, [group1_file1, group1_file2, group2_file1, group2_file2, file1])) + + # Setup input-files, input-filters + input_files = [tmp_path / 'ledger.bean'] + input_filter = [tmp_path / 'group1*.bean', tmp_path / 'group2*.beancount'] + + # Act + extracted_org_files = get_beancount_files(input_files, input_filter) + + # Assert + assert len(extracted_org_files) == 5 + assert extracted_org_files == expected_files + + +# Helper Functions +def create_file(tmp_path, entry=None, filename="ledger.beancount"): + beancount_file = tmp_path / filename + beancount_file.touch() + if entry: + beancount_file.write_text(entry) + return beancount_file diff --git a/tests/test_cli.py b/tests/test_cli.py index 
4cbf1209..3c99f424 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -2,9 +2,6 @@ from pathlib import Path from random import random -# External Modules -import pytest - # Internal Packages from src.utils.cli import cli from src.utils.helpers import resolve_absolute_path diff --git a/tests/test_client.py b/tests/test_client.py index 38b98c1f..d405a044 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,17 +1,20 @@ # Standard Modules from io import BytesIO from PIL import Image +from urllib.parse import quote + # External Packages from fastapi.testclient import TestClient -import pytest # Internal Packages from src.main import app from src.utils.state import model, config from src.search_type import text_search, image_search from src.utils.rawconfig import ContentConfig, SearchConfig -from src.processor.org_mode import org_to_jsonl +from src.processor.org_mode.org_to_jsonl import org_to_jsonl +from src.search_filter.word_filter import WordFilter +from src.search_filter.file_filter import FileFilter # Arrange @@ -22,7 +25,7 @@ client = TestClient(app) # ---------------------------------------------------------------------------------------------------- def test_search_with_invalid_content_type(): # Arrange - user_query = "How to call Khoj from Emacs?" + user_query = quote("How to call Khoj from Emacs?") # Act response = client.get(f"/search?q={user_query}&t=invalid_content_type") @@ -116,7 +119,7 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig def test_notes_search(content_config: ContentConfig, search_config: SearchConfig): # Arrange model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False) - user_query = "How to git install application?" + user_query = quote("How to git install application?") # Act response = client.get(f"/search?q={user_query}&n=1&t=org&r=true") @@ -129,17 +132,35 @@ def test_notes_search(content_config: ContentConfig, search_config: SearchConfig # ---------------------------------------------------------------------------------------------------- -def test_notes_search_with_include_filter(content_config: ContentConfig, search_config: SearchConfig): +def test_notes_search_with_only_filters(content_config: ContentConfig, search_config: SearchConfig): # Arrange - model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False) - user_query = "How to git install application? +Emacs" + filters = [WordFilter(), FileFilter()] + model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters) + user_query = quote('+"Emacs" file:"*.org"') # Act response = client.get(f"/search?q={user_query}&n=1&t=org") # Assert assert response.status_code == 200 - # assert actual_data contains explicitly included word "Emacs" + # assert actual_data contains word "Emacs" + search_result = response.json()[0]["entry"] + assert "Emacs" in search_result + + +# ---------------------------------------------------------------------------------------------------- +def test_notes_search_with_include_filter(content_config: ContentConfig, search_config: SearchConfig): + # Arrange + filters = [WordFilter()] + model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters) + user_query = quote('How to git install application? 
+"Emacs"') + + # Act + response = client.get(f"/search?q={user_query}&n=1&t=org") + + # Assert + assert response.status_code == 200 + # assert actual_data contains word "Emacs" search_result = response.json()[0]["entry"] assert "Emacs" in search_result @@ -147,14 +168,15 @@ def test_notes_search_with_include_filter(content_config: ContentConfig, search_ # ---------------------------------------------------------------------------------------------------- def test_notes_search_with_exclude_filter(content_config: ContentConfig, search_config: SearchConfig): # Arrange - model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False) - user_query = "How to git install application? -clone" + filters = [WordFilter()] + model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters) + user_query = quote('How to git install application? -"clone"') # Act response = client.get(f"/search?q={user_query}&n=1&t=org") # Assert assert response.status_code == 200 - # assert actual_data does not contains explicitly excluded word "Emacs" + # assert actual_data does not contains word "Emacs" search_result = response.json()[0]["entry"] assert "clone" not in search_result diff --git a/tests/test_date_filter.py b/tests/test_date_filter.py index 88e31c86..345c5c4f 100644 --- a/tests/test_date_filter.py +++ b/tests/test_date_filter.py @@ -18,40 +18,34 @@ def test_date_filter(): {'compiled': '', 'raw': 'Entry with date:1984-04-02'}] q_with_no_date_filter = 'head tail' - ret_query, ret_entries, ret_emb = DateFilter().filter(q_with_no_date_filter, entries.copy(), embeddings) + ret_query, entry_indices = DateFilter().apply(q_with_no_date_filter, entries) assert ret_query == 'head tail' - assert len(ret_emb) == 3 - assert ret_entries == entries + assert entry_indices == {0, 1, 2} q_with_dtrange_non_overlapping_at_boundary = 'head dt>"1984-04-01" dt<"1984-04-02" tail' - ret_query, ret_entries, ret_emb = DateFilter().filter(q_with_dtrange_non_overlapping_at_boundary, entries.copy(), embeddings) + ret_query, entry_indices = DateFilter().apply(q_with_dtrange_non_overlapping_at_boundary, entries) assert ret_query == 'head tail' - assert len(ret_emb) == 0 - assert ret_entries == [] + assert entry_indices == set() query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<"1984-04-03" tail' - ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings) + ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries) assert ret_query == 'head tail' - assert ret_entries == [entries[2]] - assert len(ret_emb) == 1 + assert entry_indices == {2} query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<"1984-04-02" tail' - ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings) + ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries) assert ret_query == 'head tail' - assert ret_entries == [entries[1]] - assert len(ret_emb) == 1 + assert entry_indices == {1} query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<="1984-04-02" tail' - ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings) + ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries) assert ret_query == 'head tail' - assert ret_entries == [entries[2]] - assert len(ret_emb) == 1 + assert 
entry_indices == {2} query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<="1984-04-02" tail' - ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings) + ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries) assert ret_query == 'head tail' - assert ret_entries == [entries[1], entries[2]] - assert len(ret_emb) == 2 + assert entry_indices == {1, 2} def test_extract_date_range(): diff --git a/tests/test_file_filter.py b/tests/test_file_filter.py new file mode 100644 index 00000000..3f9c22b3 --- /dev/null +++ b/tests/test_file_filter.py @@ -0,0 +1,112 @@ +# External Packages +import torch + +# Application Packages +from src.search_filter.file_filter import FileFilter + + +def test_no_file_filter(): + # Arrange + file_filter = FileFilter() + embeddings, entries = arrange_content() + q_with_no_filter = 'head tail' + + # Act + can_filter = file_filter.can_filter(q_with_no_filter) + ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries) + + # Assert + assert can_filter == False + assert ret_query == 'head tail' + assert entry_indices == {0, 1, 2, 3} + + +def test_file_filter_with_non_existent_file(): + # Arrange + file_filter = FileFilter() + embeddings, entries = arrange_content() + q_with_no_filter = 'head file:"nonexistent.org" tail' + + # Act + can_filter = file_filter.can_filter(q_with_no_filter) + ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries) + + # Assert + assert can_filter == True + assert ret_query == 'head tail' + assert entry_indices == {} + + +def test_single_file_filter(): + # Arrange + file_filter = FileFilter() + embeddings, entries = arrange_content() + q_with_no_filter = 'head file:"file 1.org" tail' + + # Act + can_filter = file_filter.can_filter(q_with_no_filter) + ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries) + + # Assert + assert can_filter == True + assert ret_query == 'head tail' + assert entry_indices == {0, 2} + + +def test_file_filter_with_partial_match(): + # Arrange + file_filter = FileFilter() + embeddings, entries = arrange_content() + q_with_no_filter = 'head file:"1.org" tail' + + # Act + can_filter = file_filter.can_filter(q_with_no_filter) + ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries) + + # Assert + assert can_filter == True + assert ret_query == 'head tail' + assert entry_indices == {0, 2} + + +def test_file_filter_with_regex_match(): + # Arrange + file_filter = FileFilter() + embeddings, entries = arrange_content() + q_with_no_filter = 'head file:"*.org" tail' + + # Act + can_filter = file_filter.can_filter(q_with_no_filter) + ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries) + + # Assert + assert can_filter == True + assert ret_query == 'head tail' + assert entry_indices == {0, 1, 2, 3} + + +def test_multiple_file_filter(): + # Arrange + file_filter = FileFilter() + embeddings, entries = arrange_content() + q_with_no_filter = 'head tail file:"file 1.org" file:"file2.org"' + + # Act + can_filter = file_filter.can_filter(q_with_no_filter) + ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries) + + # Assert + assert can_filter == True + assert ret_query == 'head tail' + assert entry_indices == {0, 1, 2, 3} + + +def arrange_content(): + embeddings = torch.randn(4, 10) + entries = [ + {'compiled': '', 'raw': 'First Entry', 'file': 'file 1.org'}, + {'compiled': '', 'raw': 'Second Entry', 'file': 'file2.org'}, + {'compiled': '', 'raw': 
'Third Entry', 'file': 'file 1.org'}, + {'compiled': '', 'raw': 'Fourth Entry', 'file': 'file2.org'}] + + return embeddings, entries diff --git a/tests/test_helpers.py b/tests/test_helpers.py index d4f06e6d..c9b1cd75 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -28,3 +28,18 @@ def test_merge_dicts(): # do not override existing key in priority_dict with default dict assert helpers.merge_dicts(priority_dict={'a': 1}, default_dict={'a': 2}) == {'a': 1} + + +def test_lru_cache(): + # Test initializing cache + cache = helpers.LRU({'a': 1, 'b': 2}, capacity=2) + assert cache == {'a': 1, 'b': 2} + + # Test capacity overflow + cache['c'] = 3 + assert cache == {'b': 2, 'c': 3} + + # Test delete least recently used item from LRU cache on capacity overflow + cache['b'] # accessing 'b' makes it the most recently used item + cache['d'] = 4 # so 'c' is deleted from the cache instead of 'b' + assert cache == {'b': 2, 'd': 4} diff --git a/tests/test_image_search.py b/tests/test_image_search.py index 80c4fdf6..ad374da1 100644 --- a/tests/test_image_search.py +++ b/tests/test_image_search.py @@ -48,8 +48,13 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig image_files_url='/static/images', count=1) - actual_image = Image.open(output_directory.joinpath(Path(results[0]["entry"]).name)) + actual_image_path = output_directory.joinpath(Path(results[0]["entry"]).name) + actual_image = Image.open(actual_image_path) expected_image = Image.open(content_config.image.input_directories[0].joinpath(expected_image_name)) # Assert assert expected_image == actual_image + + # Cleanup + # Delete the image files copied to results directory + actual_image_path.unlink() diff --git a/tests/test_markdown_to_jsonl.py b/tests/test_markdown_to_jsonl.py new file mode 100644 index 00000000..89c471d8 --- /dev/null +++ b/tests/test_markdown_to_jsonl.py @@ -0,0 +1,109 @@ +# Standard Packages +import json + +# Internal Packages +from src.processor.markdown.markdown_to_jsonl import extract_markdown_entries, convert_markdown_maps_to_jsonl, convert_markdown_entries_to_maps, get_markdown_files + + +def test_markdown_file_with_no_headings_to_jsonl(tmp_path): + "Convert files with no heading to jsonl." + # Arrange + entry = f''' + - Bullet point 1 + - Bullet point 2 + ''' + markdownfile = create_file(tmp_path, entry) + + # Act + # Extract Entries from specified Markdown files + entry_nodes, file_to_entries = extract_markdown_entries(markdown_files=[markdownfile]) + + # Process Each Entry from All Notes Files + jsonl_string = convert_markdown_maps_to_jsonl(convert_markdown_entries_to_maps(entry_nodes, file_to_entries)) + jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()] + + # Assert + assert len(jsonl_data) == 1 + + +def test_single_markdown_entry_to_jsonl(tmp_path): + "Convert markdown entry from single file to jsonl." 
+ # Arrange + entry = f'''### Heading + \t\r + Body Line 1 + ''' + markdownfile = create_file(tmp_path, entry) + + # Act + # Extract Entries from specified Markdown files + entries, entry_to_file_map = extract_markdown_entries(markdown_files=[markdownfile]) + + # Process Each Entry from All Notes Files + jsonl_string = convert_markdown_maps_to_jsonl(convert_markdown_entries_to_maps(entries, entry_to_file_map)) + jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()] + + # Assert + assert len(jsonl_data) == 1 + + +def test_multiple_markdown_entries_to_jsonl(tmp_path): + "Convert multiple markdown entries from single file to jsonl." + # Arrange + entry = f''' +### Heading 1 + \t\r + Heading 1 Body Line 1 +### Heading 2 + \t\r + Heading 2 Body Line 2 + ''' + markdownfile = create_file(tmp_path, entry) + + # Act + # Extract Entries from specified Markdown files + entries, entry_to_file_map = extract_markdown_entries(markdown_files=[markdownfile]) + + # Process Each Entry from All Notes Files + jsonl_string = convert_markdown_maps_to_jsonl(convert_markdown_entries_to_maps(entries, entry_to_file_map)) + jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()] + + # Assert + assert len(jsonl_data) == 2 + + +def test_get_markdown_files(tmp_path): + "Ensure Markdown files specified via input-filter, input-files extracted" + # Arrange + # Include via input-filter globs + group1_file1 = create_file(tmp_path, filename="group1-file1.md") + group1_file2 = create_file(tmp_path, filename="group1-file2.md") + group2_file1 = create_file(tmp_path, filename="group2-file1.markdown") + group2_file2 = create_file(tmp_path, filename="group2-file2.markdown") + # Include via input-file field + file1 = create_file(tmp_path, filename="notes.md") + # Not included by any filter + create_file(tmp_path, filename="not-included-markdown.md") + create_file(tmp_path, filename="not-included-text.txt") + + expected_files = sorted(map(str, [group1_file1, group1_file2, group2_file1, group2_file2, file1])) + + # Setup input-files, input-filters + input_files = [tmp_path / 'notes.md'] + input_filter = [tmp_path / 'group1*.md', tmp_path / 'group2*.markdown'] + + # Act + extracted_org_files = get_markdown_files(input_files, input_filter) + + # Assert + assert len(extracted_org_files) == 5 + assert extracted_org_files == expected_files + + +# Helper Functions +def create_file(tmp_path, entry=None, filename="test.md"): + markdown_file = tmp_path / filename + markdown_file.touch() + if entry: + markdown_file.write_text(entry) + return markdown_file diff --git a/tests/test_org_to_jsonl.py b/tests/test_org_to_jsonl.py index cadd4a6a..8a2f58ba 100644 --- a/tests/test_org_to_jsonl.py +++ b/tests/test_org_to_jsonl.py @@ -1,33 +1,38 @@ # Standard Packages import json -from posixpath import split # Internal Packages -from src.processor.org_mode.org_to_jsonl import convert_org_entries_to_jsonl, extract_org_entries +from src.processor.org_mode.org_to_jsonl import convert_org_entries_to_jsonl, convert_org_nodes_to_entries, extract_org_entries, get_org_files from src.utils.helpers import is_none_or_empty -def test_entry_with_empty_body_line_to_jsonl(tmp_path): - '''Ensure entries with empty body are ignored. +def test_configure_heading_entry_to_jsonl(tmp_path): + '''Ensure entries with empty body are ignored, unless explicitly configured to index heading entries. Property drawers not considered Body. 
Ignore control characters for evaluating if Body empty.''' # Arrange entry = f'''*** Heading :PROPERTIES: :ID: 42-42-42 :END: - \t\r\n + \t \r ''' orgfile = create_file(tmp_path, entry) - # Act - # Extract Entries from specified Org files - entries = extract_org_entries(org_files=[orgfile]) + for index_heading_entries in [True, False]: + # Act + # Extract entries into jsonl from specified Org files + jsonl_string = convert_org_entries_to_jsonl(convert_org_nodes_to_entries( + *extract_org_entries(org_files=[orgfile]), + index_heading_entries=index_heading_entries)) + jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()] - # Process Each Entry from All Notes Files - jsonl_data = convert_org_entries_to_jsonl(entries) - - # Assert - assert is_none_or_empty(jsonl_data) + # Assert + if index_heading_entries: + # Entry with empty body indexed when index_heading_entries set to True + assert len(jsonl_data) == 1 + else: + # Entry with empty body ignored when index_heading_entries set to False + assert is_none_or_empty(jsonl_data) def test_entry_with_body_to_jsonl(tmp_path): @@ -37,15 +42,38 @@ def test_entry_with_body_to_jsonl(tmp_path): :PROPERTIES: :ID: 42-42-42 :END: - \t\r\nBody Line 1\n + \t\r + Body Line 1 ''' orgfile = create_file(tmp_path, entry) # Act # Extract Entries from specified Org files - entries = extract_org_entries(org_files=[orgfile]) + entries, entry_to_file_map = extract_org_entries(org_files=[orgfile]) # Process Each Entry from All Notes Files + jsonl_string = convert_org_entries_to_jsonl(convert_org_nodes_to_entries(entries, entry_to_file_map)) + jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()] + + # Assert + assert len(jsonl_data) == 1 + + +def test_file_with_no_headings_to_jsonl(tmp_path): + "Ensure files with no heading, only body text are loaded." 
+ # Arrange + entry = f''' + - Bullet point 1 + - Bullet point 2 + ''' + orgfile = create_file(tmp_path, entry) + + # Act + # Extract Entries from specified Org files + entry_nodes, file_to_entries = extract_org_entries(org_files=[orgfile]) + + # Process Each Entry from All Notes Files + entries = convert_org_nodes_to_entries(entry_nodes, file_to_entries) jsonl_string = convert_org_entries_to_jsonl(entries) jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()] @@ -53,10 +81,38 @@ def test_entry_with_body_to_jsonl(tmp_path): assert len(jsonl_data) == 1 +def test_get_org_files(tmp_path): + "Ensure Org files specified via input-filter, input-files extracted" + # Arrange + # Include via input-filter globs + group1_file1 = create_file(tmp_path, filename="group1-file1.org") + group1_file2 = create_file(tmp_path, filename="group1-file2.org") + group2_file1 = create_file(tmp_path, filename="group2-file1.org") + group2_file2 = create_file(tmp_path, filename="group2-file2.org") + # Include via input-file field + orgfile1 = create_file(tmp_path, filename="orgfile1.org") + # Not included by any filter + create_file(tmp_path, filename="orgfile2.org") + create_file(tmp_path, filename="text1.txt") + + expected_files = sorted(map(str, [group1_file1, group1_file2, group2_file1, group2_file2, orgfile1])) + + # Setup input-files, input-filters + input_files = [tmp_path / 'orgfile1.org'] + input_filter = [tmp_path / 'group1*.org', tmp_path / 'group2*.org'] + + # Act + extracted_org_files = get_org_files(input_files, input_filter) + + # Assert + assert len(extracted_org_files) == 5 + assert extracted_org_files == expected_files + + # Helper Functions -def create_file(tmp_path, entry, filename="test.org"): - org_file = tmp_path / f"notes/{filename}" - org_file.parent.mkdir() +def create_file(tmp_path, entry=None, filename="test.org"): + org_file = tmp_path / filename org_file.touch() - org_file.write_text(entry) - return org_file \ No newline at end of file + if entry: + org_file.write_text(entry) + return org_file diff --git a/tests/test_orgnode.py b/tests/test_orgnode.py index 186eaaec..d36cca79 100644 --- a/tests/test_orgnode.py +++ b/tests/test_orgnode.py @@ -1,13 +1,33 @@ # Standard Packages import datetime -from os.path import relpath -from pathlib import Path # Internal Packages from src.processor.org_mode import orgnode # Test +# ---------------------------------------------------------------------------------------------------- +def test_parse_entry_with_no_headings(tmp_path): + "Test parsing of entry with minimal fields" + # Arrange + entry = f'''Body Line 1''' + orgfile = create_file(tmp_path, entry) + + # Act + entries = orgnode.makelist(orgfile) + + # Assert + assert len(entries) == 1 + assert entries[0].heading == f'{orgfile}' + assert entries[0].tags == list() + assert entries[0].body == "Body Line 1" + assert entries[0].priority == "" + assert entries[0].Property("ID") == "" + assert entries[0].closed == "" + assert entries[0].scheduled == "" + assert entries[0].deadline == "" + + # ---------------------------------------------------------------------------------------------------- def test_parse_minimal_entry(tmp_path): "Test parsing of entry with minimal fields" @@ -22,14 +42,14 @@ Body Line 1''' # Assert assert len(entries) == 1 - assert entries[0].Heading() == "Heading" - assert entries[0].Tags() == set() - assert entries[0].Body() == "Body Line 1" - assert entries[0].Priority() == "" + assert entries[0].heading == "Heading" + assert entries[0].tags == list() + 
assert entries[0].body == "Body Line 1" + assert entries[0].priority == "" assert entries[0].Property("ID") == "" - assert entries[0].Closed() == "" - assert entries[0].Scheduled() == "" - assert entries[0].Deadline() == "" + assert entries[0].closed == "" + assert entries[0].scheduled == "" + assert entries[0].deadline == "" # ---------------------------------------------------------------------------------------------------- @@ -55,16 +75,44 @@ Body Line 2''' # Assert assert len(entries) == 1 - assert entries[0].Heading() == "Heading" - assert entries[0].Todo() == "DONE" - assert entries[0].Tags() == {"Tag1", "TAG2", "tag3"} - assert entries[0].Body() == "- Clocked Log 1\nBody Line 1\nBody Line 2" - assert entries[0].Priority() == "A" + assert entries[0].heading == "Heading" + assert entries[0].todo == "DONE" + assert entries[0].tags == ["Tag1", "TAG2", "tag3"] + assert entries[0].body == "- Clocked Log 1\nBody Line 1\nBody Line 2" + assert entries[0].priority == "A" assert entries[0].Property("ID") == "id:123-456-789-4234-1231" - assert entries[0].Closed() == datetime.date(1984,4,1) - assert entries[0].Scheduled() == datetime.date(1984,4,1) - assert entries[0].Deadline() == datetime.date(1984,4,1) - assert entries[0].Logbook() == [(datetime.datetime(1984,4,1,9,0,0), datetime.datetime(1984,4,1,12,0,0))] + assert entries[0].closed == datetime.date(1984,4,1) + assert entries[0].scheduled == datetime.date(1984,4,1) + assert entries[0].deadline == datetime.date(1984,4,1) + assert entries[0].logbook == [(datetime.datetime(1984,4,1,9,0,0), datetime.datetime(1984,4,1,12,0,0))] + + +# ---------------------------------------------------------------------------------------------------- +def test_render_entry_with_property_drawer_and_empty_body(tmp_path): + "Render heading entry with property drawer" + # Arrange + entry_to_render = f''' +*** [#A] Heading1 :tag1: + :PROPERTIES: + :ID: 111-111-111-1111-1111 + :END: +\t\r \n +''' + orgfile = create_file(tmp_path, entry_to_render) + + expected_entry = f'''*** [#A] Heading1 :tag1: +:PROPERTIES: +:LINE: file:{orgfile}::2 +:ID: id:111-111-111-1111-1111 +:SOURCE: [[file:{orgfile}::*Heading1]] +:END: +''' + + # Act + parsed_entries = orgnode.makelist(orgfile) + + # Assert + assert f'{parsed_entries[0]}' == expected_entry # ---------------------------------------------------------------------------------------------------- @@ -81,18 +129,17 @@ Body Line 1 Body Line 2 ''' orgfile = create_file(tmp_path, entry) - normalized_orgfile = f'~/{relpath(orgfile, start=Path.home())}' # Act entries = orgnode.makelist(orgfile) # Assert # SOURCE link rendered with Heading - assert f':SOURCE: [[file:{normalized_orgfile}::*{entries[0].Heading()}]]' in f'{entries[0]}' + assert f':SOURCE: [[file:{orgfile}::*{entries[0].heading}]]' in f'{entries[0]}' # ID link rendered with ID assert f':ID: id:123-456-789-4234-1231' in f'{entries[0]}' # LINE link rendered with line number - assert f':LINE: file:{normalized_orgfile}::2' in f'{entries[0]}' + assert f':LINE: file:{orgfile}::2' in f'{entries[0]}' # ---------------------------------------------------------------------------------------------------- @@ -113,10 +160,9 @@ Body Line 1''' # Assert assert len(entries) == 1 # parsed heading from entry - assert entries[0].Heading() == "Heading[1]" + assert entries[0].heading == "Heading[1]" # ensure SOURCE link has square brackets in filename, heading escaped in rendered entries - normalized_orgfile = f'~/{relpath(orgfile, start=Path.home())}' - escaped_orgfile = 
f'{normalized_orgfile}'.replace("[1]", "\\[1\\]") + escaped_orgfile = f'{orgfile}'.replace("[1]", "\\[1\\]") assert f':SOURCE: [[file:{escaped_orgfile}::*Heading\[1\]' in f'{entries[0]}' @@ -156,16 +202,86 @@ Body 2 # Assert assert len(entries) == 2 for index, entry in enumerate(entries): - assert entry.Heading() == f"Heading{index+1}" - assert entry.Todo() == "FAILED" if index == 0 else "CANCELLED" - assert entry.Tags() == {f"tag{index+1}"} - assert entry.Body() == f"- Clocked Log {index+1}\nBody {index+1}\n\n" - assert entry.Priority() == "A" + assert entry.heading == f"Heading{index+1}" + assert entry.todo == "FAILED" if index == 0 else "CANCELLED" + assert entry.tags == [f"tag{index+1}"] + assert entry.body == f"- Clocked Log {index+1}\nBody {index+1}\n\n" + assert entry.priority == "A" assert entry.Property("ID") == f"id:123-456-789-4234-000{index+1}" - assert entry.Closed() == datetime.date(1984,4,index+1) - assert entry.Scheduled() == datetime.date(1984,4,index+1) - assert entry.Deadline() == datetime.date(1984,4,index+1) - assert entry.Logbook() == [(datetime.datetime(1984,4,index+1,9,0,0), datetime.datetime(1984,4,index+1,12,0,0))] + assert entry.closed == datetime.date(1984,4,index+1) + assert entry.scheduled == datetime.date(1984,4,index+1) + assert entry.deadline == datetime.date(1984,4,index+1) + assert entry.logbook == [(datetime.datetime(1984,4,index+1,9,0,0), datetime.datetime(1984,4,index+1,12,0,0))] + + +# ---------------------------------------------------------------------------------------------------- +def test_parse_entry_with_empty_title(tmp_path): + "Test parsing of entry with minimal fields" + # Arrange + entry = f'''#+TITLE: +Body Line 1''' + orgfile = create_file(tmp_path, entry) + + # Act + entries = orgnode.makelist(orgfile) + + # Assert + assert len(entries) == 1 + assert entries[0].heading == f'{orgfile}' + assert entries[0].tags == list() + assert entries[0].body == "Body Line 1" + assert entries[0].priority == "" + assert entries[0].Property("ID") == "" + assert entries[0].closed == "" + assert entries[0].scheduled == "" + assert entries[0].deadline == "" + + +# ---------------------------------------------------------------------------------------------------- +def test_parse_entry_with_title_and_no_headings(tmp_path): + "Test parsing of entry with minimal fields" + # Arrange + entry = f'''#+TITLE: test +Body Line 1''' + orgfile = create_file(tmp_path, entry) + + # Act + entries = orgnode.makelist(orgfile) + + # Assert + assert len(entries) == 1 + assert entries[0].heading == 'test' + assert entries[0].tags == list() + assert entries[0].body == "Body Line 1" + assert entries[0].priority == "" + assert entries[0].Property("ID") == "" + assert entries[0].closed == "" + assert entries[0].scheduled == "" + assert entries[0].deadline == "" + + +# ---------------------------------------------------------------------------------------------------- +def test_parse_entry_with_multiple_titles_and_no_headings(tmp_path): + "Test parsing of entry with minimal fields" + # Arrange + entry = f'''#+TITLE: title1 +Body Line 1 +#+TITLE: title2 ''' + orgfile = create_file(tmp_path, entry) + + # Act + entries = orgnode.makelist(orgfile) + + # Assert + assert len(entries) == 1 + assert entries[0].heading == 'title1 title2' + assert entries[0].tags == list() + assert entries[0].body == "Body Line 1\n" + assert entries[0].priority == "" + assert entries[0].Property("ID") == "" + assert entries[0].closed == "" + assert entries[0].scheduled == "" + assert entries[0].deadline == "" 
# Helper Functions @@ -174,4 +290,4 @@ def create_file(tmp_path, entry, filename="test.org"): org_file.parent.mkdir() org_file.touch() org_file.write_text(entry) - return org_file \ No newline at end of file + return org_file diff --git a/tests/test_text_search.py b/tests/test_text_search.py index 39fed92e..6744566d 100644 --- a/tests/test_text_search.py +++ b/tests/test_text_search.py @@ -1,6 +1,10 @@ # System Packages +from copy import deepcopy from pathlib import Path +# External Packages +import pytest + # Internal Packages from src.utils.state import model from src.search_type import text_search @@ -9,6 +13,39 @@ from src.processor.org_mode.org_to_jsonl import org_to_jsonl # Test +# ---------------------------------------------------------------------------------------------------- +def test_asymmetric_setup_with_missing_file_raises_error(content_config: ContentConfig, search_config: SearchConfig): + # Arrange + file_to_index = Path(content_config.org.input_filter[0]).parent / "new_file_to_index.org" + new_org_content_config = deepcopy(content_config.org) + new_org_content_config.input_files = [f'{file_to_index}'] + new_org_content_config.input_filter = None + + # Act + # Generate notes embeddings during asymmetric setup + with pytest.raises(FileNotFoundError): + text_search.setup(org_to_jsonl, new_org_content_config, search_config.asymmetric, regenerate=True) + + +# ---------------------------------------------------------------------------------------------------- +def test_asymmetric_setup_with_empty_file_raises_error(content_config: ContentConfig, search_config: SearchConfig): + # Arrange + file_to_index = Path(content_config.org.input_filter[0]).parent / "new_file_to_index.org" + file_to_index.touch() + new_org_content_config = deepcopy(content_config.org) + new_org_content_config.input_files = [f'{file_to_index}'] + new_org_content_config.input_filter = None + + # Act + # Generate notes embeddings during asymmetric setup + with pytest.raises(ValueError, match=r'^No valid entries found*'): + text_search.setup(org_to_jsonl, new_org_content_config, search_config.asymmetric, regenerate=True) + + # Cleanup + # delete created test file + file_to_index.unlink() + + # ---------------------------------------------------------------------------------------------------- def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchConfig): # Act @@ -23,7 +60,7 @@ def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchCo # ---------------------------------------------------------------------------------------------------- def test_asymmetric_search(content_config: ContentConfig, search_config: SearchConfig): # Arrange - model.notes_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False) + model.notes_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=True) query = "How to git install application?" 
# Act @@ -51,7 +88,7 @@ def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchC assert len(initial_notes_model.entries) == 10 assert len(initial_notes_model.corpus_embeddings) == 10 - file_to_add_on_reload = Path(content_config.org.input_filter).parent / "reload.org" + file_to_add_on_reload = Path(content_config.org.input_filter[0]).parent / "reload.org" content_config.org.input_files = [f'{file_to_add_on_reload}'] # append Org-Mode Entry to first Org Input File in Config @@ -77,3 +114,32 @@ def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchC # delete reload test file added content_config.org.input_files = [] file_to_add_on_reload.unlink() + + +# ---------------------------------------------------------------------------------------------------- +def test_incremental_update(content_config: ContentConfig, search_config: SearchConfig): + # Arrange + initial_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=True) + + assert len(initial_notes_model.entries) == 10 + assert len(initial_notes_model.corpus_embeddings) == 10 + + file_to_add_on_update = Path(content_config.org.input_filter[0]).parent / "update.org" + content_config.org.input_files = [f'{file_to_add_on_update}'] + + # append Org-Mode Entry to first Org Input File in Config + with open(file_to_add_on_update, "w") as f: + f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n") + + # Act + # update embeddings, entries with the newly added note + initial_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False) + + # verify new entry added in updated embeddings, entries + assert len(initial_notes_model.entries) == 11 + assert len(initial_notes_model.corpus_embeddings) == 11 + + # Cleanup + # delete file added for update testing + content_config.org.input_files = [] + file_to_add_on_update.unlink() diff --git a/tests/test_word_filter.py b/tests/test_word_filter.py new file mode 100644 index 00000000..db23c2c6 --- /dev/null +++ b/tests/test_word_filter.py @@ -0,0 +1,77 @@ +# Application Packages +from src.search_filter.word_filter import WordFilter +from src.utils.config import SearchType + + +def test_no_word_filter(): + # Arrange + word_filter = WordFilter() + entries = arrange_content() + q_with_no_filter = 'head tail' + + # Act + can_filter = word_filter.can_filter(q_with_no_filter) + ret_query, entry_indices = word_filter.apply(q_with_no_filter, entries) + + # Assert + assert can_filter == False + assert ret_query == 'head tail' + assert entry_indices == {0, 1, 2, 3} + + +def test_word_exclude_filter(): + # Arrange + word_filter = WordFilter() + entries = arrange_content() + q_with_exclude_filter = 'head -"exclude_word" tail' + + # Act + can_filter = word_filter.can_filter(q_with_exclude_filter) + ret_query, entry_indices = word_filter.apply(q_with_exclude_filter, entries) + + # Assert + assert can_filter == True + assert ret_query == 'head tail' + assert entry_indices == {0, 2} + + +def test_word_include_filter(): + # Arrange + word_filter = WordFilter() + entries = arrange_content() + query_with_include_filter = 'head +"include_word" tail' + + # Act + can_filter = word_filter.can_filter(query_with_include_filter) + ret_query, entry_indices = word_filter.apply(query_with_include_filter, entries) + + # Assert + assert can_filter == True + assert ret_query == 'head tail' + assert entry_indices == {2, 3} + + +def 
test_word_include_and_exclude_filter(): + # Arrange + word_filter = WordFilter() + entries = arrange_content() + query_with_include_and_exclude_filter = 'head +"include_word" -"exclude_word" tail' + + # Act + can_filter = word_filter.can_filter(query_with_include_and_exclude_filter) + ret_query, entry_indices = word_filter.apply(query_with_include_and_exclude_filter, entries) + + # Assert + assert can_filter == True + assert ret_query == 'head tail' + assert entry_indices == {2} + + +def arrange_content(): + entries = [ + {'compiled': '', 'raw': 'Minimal Entry'}, + {'compiled': '', 'raw': 'Entry with exclude_word'}, + {'compiled': '', 'raw': 'Entry with include_word'}, + {'compiled': '', 'raw': 'Entry with include_word and exclude_word'}] + + return entries
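The filter tests above all exercise one common interface: a filter exposes `can_filter(query)` and `apply(query, entries)`, where `apply` returns the query with its filter terms stripped plus the set of entry indices that survive the filter (the rewritten test_date_filter.py shows the move away from returning filtered entry lists and embedding tensors). The sketch below shows how a caller might compose several such filters; the sample entries, query, and the set-intersection step are illustrative assumptions for this note, not code taken from this change.

    # Sketch only: composes filters the way tests/test_word_filter.py and
    # tests/test_file_filter.py exercise them. Intersecting the returned index
    # sets is an assumed composition strategy, not lifted from this diff.
    from src.search_filter.word_filter import WordFilter
    from src.search_filter.file_filter import FileFilter

    entries = [
        {'compiled': '', 'raw': 'Install Emacs via git clone', 'file': 'file 1.org'},
        {'compiled': '', 'raw': 'Install Khoj via pip', 'file': 'file2.org'},
    ]

    query = 'How to install application? +"Emacs" file:"*.org"'
    included = set(range(len(entries)))  # start with every entry included

    for search_filter in [WordFilter(), FileFilter()]:
        if search_filter.can_filter(query):
            # apply() strips the filter terms from the query and returns the
            # indices of entries that pass this filter
            query, entry_indices = search_filter.apply(query, entries)
            included &= entry_indices  # keep only entries that pass every filter

    results = [entries[i] for i in included]  # candidates left to rank and return

Returning index sets keeps each filter independent of the embedding tensors and makes composition a cheap set intersection, which is consistent with the simpler assertions in the updated test_date_filter.py and with the small per-query latency the filters now claim.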